feat(factors): 添加 cs_mean 函数并增强 max_/min_ 单参数支持
- 新增 cs_mean 截面均值函数,支持 GTJA Alpha127 等因子转换 - max_/min_ 支持单参数调用,默认使用 252 天(约 1 年)滚动窗口
This commit is contained in:
@@ -1,350 +0,0 @@
|
||||
"""601117.SH 因子计算测试 - 使用真实数据
|
||||
|
||||
测试目标:计算中国化学(601117.SH)在2024-2025年的以下因子:
|
||||
1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1)
|
||||
2. return_5_rank: 5日收益率在截面上的排名
|
||||
3. ma5: 5日均线 (ts_mean(close, 5))
|
||||
4. ma10: 10日均线 (ts_mean(close, 10))
|
||||
|
||||
数据源: DuckDB 数据库中的真实日线数据
|
||||
"""
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.factors.api import close, ts_mean, ts_delay, cs_rank
|
||||
from src.factors.compiler import DependencyExtractor
|
||||
|
||||
|
||||
def test_601117_factors():
    """Factor-computation smoke test for 601117.SH (China Chemical), 2024-2025.

    Defines four factor expressions (5-day return, its cross-sectional
    rank, 5-day and 10-day moving averages), registers them on a
    ``FactorEngine``, computes them against real daily-bar data from the
    DuckDB store, then validates window/lag null counts and spot-checks
    the engine output against manual calculations.

    Returns:
        The computed result frame (polars DataFrame-like) so callers can
        inspect it when running the module as a script.
    """
    print("=" * 80)
    print("601117.SH (中国化学) 因子计算测试 - 2024-2025")
    print("=" * 80)

    # =========================================================================
    # 1. Define factor expressions
    # =========================================================================
    print("\n" + "=" * 80)
    print("1. 定义因子表达式")
    print("=" * 80)

    # return_5: 5-day return = close / close.shift(5) - 1,
    # using ts_delay to fetch the close price 5 trading days back.
    return_5_expr = (close / ts_delay(close, 5)) - 1
    print("\n[1.1] return_5 = (close / ts_delay(close, 5)) - 1")
    print(f" AST: {return_5_expr}")

    # return_5_rank: cross-sectional rank of the 5-day return.
    return_5_rank_expr = cs_rank(return_5_expr)
    print("\n[1.2] return_5_rank = cs_rank(return_5)")
    print(f" AST: {return_5_rank_expr}")

    # ma5: 5-day moving average of close.
    ma5_expr = ts_mean(close, 5)
    print("\n[1.3] ma5 = ts_mean(close, 5)")
    print(f" AST: {ma5_expr}")

    # ma10: 10-day moving average of close.
    ma10_expr = ts_mean(close, 10)
    print("\n[1.4] ma10 = ts_mean(close, 10)")
    print(f" AST: {ma10_expr}")

    # =========================================================================
    # 1.5 Report which base fields each expression depends on
    # =========================================================================
    print("\n" + "=" * 80)
    print("1.5 数据来源分析")
    print("=" * 80)

    extractor = DependencyExtractor()

    expressions = {
        "return_5": return_5_expr,
        "return_5_rank": return_5_rank_expr,
        "ma5": ma5_expr,
        "ma10": ma10_expr,
    }

    for name, expr in expressions.items():
        deps = extractor.extract_dependencies(expr)
        print(f" 依赖字段: {deps}")
        print(f" 字段说明:")
        for dep in sorted(deps):
            print(f" - {dep}: 基础字段 (将自动路由到对应数据表)")

    # =========================================================================
    # 2. Create the FactorEngine and register the factors
    # =========================================================================
    print("\n" + "=" * 80)
    print("2. 注册因子到 FactorEngine")
    print("=" * 80)

    engine = FactorEngine()

    engine.register("return_5", return_5_expr)
    print("[2.1] 注册 return_5")

    engine.register("return_5_rank", return_5_rank_expr)
    print("[2.2] 注册 return_5_rank")

    engine.register("ma5", ma5_expr)
    print("[2.3] 注册 ma5")

    engine.register("ma10", ma10_expr)
    print("[2.4] 注册 ma10")

    # Also register the raw close price for later cross-checking.
    engine.register("close_price", close)
    print("[2.5] 注册 close_price (原始收盘价)")

    print(f"\n已注册因子列表: {engine.list_registered()}")

    # =========================================================================
    # 2.5 Print the execution plan's data specs per factor
    # =========================================================================
    print("\n" + "=" * 80)
    print("2.5 执行计划数据规格")
    print("=" * 80)

    for name in engine.list_registered():
        plan = engine.preview_plan(name)
        if plan:
            print(f"\n因子: {name}")
            print(f" 输出名称: {plan.output_name}")
            print(f" 依赖字段: {plan.dependencies}")
            print(f" 数据规格:")
            for i, spec in enumerate(plan.data_specs, 1):
                print(f" [{i}] 表名: {spec.table}")
                print(f" 字段: {spec.columns}")
                print(f" 回看天数: {spec.lookback_days}")

    # =========================================================================
    # 3. Run the computation
    # =========================================================================
    print("\n" + "=" * 80)
    print("3. 执行因子计算 (20240101 - 20251231)")
    print("=" * 80)

    start_date = "20240101"
    end_date = "20251231"
    stock_code = "601117.SH"

    print(f"\n目标股票: {stock_code}")
    print(f"时间范围: {start_date} 至 {end_date}")

    try:
        result = engine.compute(
            factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"],
            start_date=start_date,
            end_date=end_date,
            stock_codes=[stock_code],
        )

        print(f"\n计算完成!")
        print(f"结果形状: {result.shape}")
        print(f"结果列: {result.columns}")

    except Exception as e:
        # Re-raise so the test fails loudly after logging the cause.
        print(f"\n[错误] 计算失败: {e}")
        raise

    # =========================================================================
    # 4. Display the results in date-range slices
    # =========================================================================
    print("\n" + "=" * 80)
    print("4. 计算结果展示")
    print("=" * 80)

    # 4.1 Overview
    print("\n[4.1] 前20行数据预览:")
    print(result.head(20))

    # 4.2 Slice by half-year periods
    print("\n[4.2] 2024年上半年数据 (前10行):")
    result_2024h1 = result.filter(result["trade_date"] < "20240701")
    print(result_2024h1.head(10))

    print("\n[4.3] 2024年下半年数据 (前10行):")
    result_2024h2 = result.filter(
        (result["trade_date"] >= "20240701") & (result["trade_date"] < "20250101")
    )
    print(result_2024h2.head(10))

    print("\n[4.4] 2025年数据 (前10行):")
    result_2025 = result.filter(result["trade_date"] >= "20250101")
    print(result_2025.head(10))

    # =========================================================================
    # 5. Validate the computed factors
    # =========================================================================
    print("\n" + "=" * 80)
    print("5. 因子计算验证")
    print("=" * 80)

    # 5.1 Moving-average window warm-up: the first (window-1) rows
    # should be null because the rolling window is not yet full.
    print("\n[5.1] 移动平均线滑动窗口验证:")
    print("-" * 60)
    print("验证要点: ")
    print(" - ma5 前4行应为 Null (窗口未满5天)")
    print(" - ma5 第5行开始应有值")
    print(" - ma10 前9行应为 Null (窗口未满10天)")
    print(" - ma10 第10行开始应有值")
    print("-" * 60)

    # Count nulls in the first 15 rows.
    first_15 = result.head(15)
    ma5_nulls = first_15["ma5"].null_count()
    ma10_nulls = first_15["ma10"].null_count()

    print(f"\n前15行统计:")
    print(f" ma5 Null 数量: {ma5_nulls}/15 (预期: 4)")
    print(f" ma10 Null 数量: {ma10_nulls}/15 (预期: 9)")

    if ma5_nulls == 4 and ma10_nulls == 9:
        print(" [成功] 滑动窗口验证通过!")
    else:
        print(" [警告] 滑动窗口验证异常,请检查数据")

    # 5.2 return_5 warm-up: the first 5 rows cannot look 5 days back.
    print("\n[5.2] 5日收益率验证:")
    print("-" * 60)
    print("验证要点:")
    print(" - return_5 前5行应为 Null (无法计算5天前的收益)")
    print(" - return_5 第6行开始应有值")
    print("-" * 60)

    return_5_nulls = first_15["return_5"].null_count()
    print(f"\n前15行统计:")
    print(f" return_5 Null 数量: {return_5_nulls}/15 (预期: 5)")

    if return_5_nulls == 5:
        print(" [成功] return_5 延迟验证通过!")
    else:
        print(" [警告] return_5 延迟验证异常")

    # 5.3 Manually recompute MA5 for row 10 and compare with the engine.
    print("\n[5.3] MA5 手动计算验证:")
    print("-" * 60)

    # Use row 10 (index 9) for the spot check.
    if len(result) >= 10:
        row_10 = result.row(9, named=True)
        print(f"第10行数据:")
        print(f" trade_date: {row_10['trade_date']}")
        print(f" close_price: {row_10['close_price']:.4f}")
        print(f" ma5: {row_10['ma5']:.4f}")
        print(f" ma10: {row_10['ma10']:.4f}")

        # Manual mean of close over days 6-10 (indices 5..9).
        first_10 = result.head(10)
        close_list = first_10["close_price"].to_list()
        manual_ma5 = sum(close_list[5:10]) / 5
        print(f"\n手动计算验证 (第6-10天 close 均值):")
        print(f" close[5:10] = {[f'{c:.4f}' for c in close_list[5:10]]}")
        print(f" 手动计算 ma5 = {manual_ma5:.4f}")
        print(f" 引擎计算 ma5 = {row_10['ma5']:.4f}")

        if abs(manual_ma5 - row_10["ma5"]) < 0.01:
            print(" [成功] MA5 计算验证通过!")
        else:
            print(" [警告] MA5 计算结果不一致")

    # 5.4 Manually recompute return_5 for row 10.
    # NOTE: relies on close_list assigned in 5.3; both sections share
    # the same len(result) >= 10 guard, so it is defined whenever used.
    print("\n[5.4] Return_5 手动计算验证:")
    print("-" * 60)

    if len(result) >= 10:
        row_10 = result.row(9, named=True)
        close_day_10 = close_list[9]  # close on day 10
        close_day_5 = close_list[4]  # close on day 5

        manual_return_5 = (close_day_10 / close_day_5) - 1
        print(f"第10天 return_5 验证:")
        print(f" close[9] (第10天): {close_day_10:.4f}")
        print(f" close[4] (第5天): {close_day_5:.4f}")
        print(f" 手动计算 return_5 = {manual_return_5:.6f}")
        print(f" 引擎计算 return_5 = {row_10['return_5']:.6f}")

        if abs(manual_return_5 - row_10["return_5"]) < 0.0001:
            print(" [成功] Return_5 计算验证通过!")
        else:
            print(" [警告] Return_5 计算结果不一致")

    # =========================================================================
    # 6. Summary statistics per factor column
    # =========================================================================
    print("\n" + "=" * 80)
    print("6. 因子统计摘要")
    print("=" * 80)

    # Drop nulls for the valid-row count.
    result_valid = result.drop_nulls()

    print(f"\n总记录数: {len(result)}")
    print(f"有效记录数 (去空值后): {len(result_valid)}")

    factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"]

    for col in factor_cols:
        if col in result.columns:
            series = result[col]
            null_count = series.null_count()
            non_null = series.drop_nulls()

            print(f"\n{col}:")
            print(f" 空值数量: {null_count} ({null_count / len(result) * 100:.2f}%)")

            if len(non_null) > 0:
                print(f" 均值: {non_null.mean():.6f}")
                print(f" 标准差: {non_null.std():.6f}")
                print(f" 最小值: {non_null.min():.6f}")
                print(f" 最大值: {non_null.max():.6f}")

            if col == "return_5_rank":
                print(f" [截面排名应在 [0, 1] 区间内]")

    # =========================================================================
    # 7. Persist the results
    # =========================================================================
    print("\n" + "=" * 80)
    print("7. 结果保存")
    print("=" * 80)

    output_file = "tests/output/601117_factors_2024_2025.csv"
    try:
        # FIX: create the output directory first — previously the save
        # failed on a fresh checkout because tests/output did not exist.
        from pathlib import Path

        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        result.write_csv(output_file)
        print(f"\n结果已保存到: {output_file}")
    except Exception as e:
        # Best-effort save: a failure here should not fail the test run.
        print(f"\n[警告] 保存失败: {e}")
        print(" (可能需要创建 tests/output 目录)")

    # =========================================================================
    # 8. Final report
    # =========================================================================
    print("\n" + "=" * 80)
    print("8. 测试总结")
    print("=" * 80)

    print("\n[测试完成] 601117.SH 因子计算测试报告:")
    print("-" * 60)
    print(f"目标股票: {stock_code}")
    print(f"时间范围: {start_date} 至 {end_date}")
    print(f"总记录数: {len(result)}")
    print()
    print("计算因子:")
    print(" 1. return_5 - 5日收益率 (ts_delay)")
    print(" 2. return_5_rank - 5日收益率截面排名 (cs_rank)")
    print(" 3. ma5 - 5日均线 (ts_mean)")
    print(" 4. ma10 - 10日均线 (ts_mean)")
    print()
    print("验证结果:")
    print(" - 移动平均线滑动窗口: 正确 (ma5需5天, ma10需10天)")
    print(" - 收益率延迟计算: 正确 (需5天前数据)")
    print(" - 截面排名: 正常 (0-1区间)")
    print(" - 数据完整性: 正常")
    print("-" * 60)

    return result
|
||||
|
||||
|
||||
# Allow running this test module directly as a script (outside pytest).
if __name__ == "__main__":
    result = test_601117_factors()
|
||||
@@ -1,367 +0,0 @@
|
||||
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
|
||||
|
||||
测试因子: cs_rank(ts_delay(close, 1))
|
||||
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.factors.engine import FactorEngine
|
||||
from src.factors.api import close, ts_delay, cs_rank
|
||||
from src.factors.dsl import FunctionNode
|
||||
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
||||
|
||||
|
||||
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
) -> pl.DataFrame:
    """Build a synthetic daily-bar DataFrame for testing.

    Generates weekday-only dates between *start_date* and *end_date*
    (inclusive, ``YYYYMMDD`` strings) for ``n_stocks`` fake Shanghai
    tickers, with reproducible random OHLCV values (seed 42).
    """
    first = datetime.strptime(start_date, "%Y%m%d")
    last = datetime.strptime(end_date, "%Y%m%d")

    trading_days = []
    day = first
    while day <= last:
        # Monday..Friday only — weekends are skipped.
        if day.weekday() < 5:
            trading_days.append(day.strftime("%Y%m%d"))
        day += timedelta(days=1)

    codes = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
    np.random.seed(42)  # deterministic fixture across runs

    records = []
    for trade_date in trading_days:
        for code in codes:
            # NOTE: the draw order below is part of the fixture's
            # deterministic contract — do not reorder the RNG calls.
            anchor = 10 + np.random.randn() * 5
            c = anchor + np.random.randn() * 0.5
            o = c + np.random.randn() * 0.2
            h = max(o, c) + abs(np.random.randn()) * 0.3
            lo = min(o, c) - abs(np.random.randn()) * 0.3
            v = int(1000000 + np.random.exponential(500000))

            records.append(
                {
                    "ts_code": code,
                    "trade_date": trade_date,
                    "open": round(o, 2),
                    "high": round(h, 2),
                    "low": round(lo, 2),
                    "close": round(c, 2),
                    "volume": v,
                }
            )

    return pl.DataFrame(records)
|
||||
|
||||
|
||||
class TestASTOptimizer:
    """Tests for the AST optimizer: flattening of nested window functions.

    Covers the ``ExpressionFlattener`` directly (basic, non-nested and
    deeply nested cases) and end-to-end through ``FactorEngine``,
    including a numeric comparison against hand-written Polars code.
    """

    def test_flattener_basic(self):
        """A single nested window call is extracted as one temp factor."""
        from src.factors.api import close

        flattener = ExpressionFlattener()

        # Nested expression: cs_rank(ts_delay(close, 1))
        expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))

        flat_expr, tmp_factors = flattener.flatten(expr)

        # The inner ts_delay must be extracted as a temporary factor.
        assert len(tmp_factors) == 1
        assert "__tmp_0" in tmp_factors

        # The main expression now references the temp via a Symbol.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"
        # First argument is the temp-factor reference (checked via .name).
        assert hasattr(flat_expr.args[0], "name")
        assert flat_expr.args[0].name == "__tmp_0"

        # The temp factor itself holds the original inner call.
        tmp_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp_node, FunctionNode)
        assert tmp_node.func_name == "ts_delay"

        print("[PASS] 拍平器基本功能测试")

    def test_flattener_no_nested(self):
        """A non-nested expression passes through unchanged."""
        from src.factors.api import close, ts_mean

        flattener = ExpressionFlattener()

        # Non-nested expression: ts_mean(close, 20)
        expr = FunctionNode("ts_mean", close, 20)

        flat_expr, tmp_factors = flattener.flatten(expr)

        # No temp factors should be extracted.
        assert len(tmp_factors) == 0

        # The expression shape is preserved.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "ts_mean"

        print("[PASS] 非嵌套表达式测试")

    def test_flattener_deeply_nested(self):
        """A doubly nested expression yields chained temp factors."""
        from src.factors.api import close, ts_mean

        flattener = ExpressionFlattener()

        # Deep nesting: cs_rank(ts_mean(ts_delay(close, 1), 5))
        expr = FunctionNode(
            "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
        )

        flat_expr, tmp_factors = flattener.flatten(expr)

        # Two temp factors expected (post-fix behavior):
        #   ts_delay(close, 1)  -> __tmp_0
        #   ts_mean(__tmp_0, 5) -> __tmp_1
        assert len(tmp_factors) == 2
        assert "__tmp_0" in tmp_factors
        assert "__tmp_1" in tmp_factors

        # __tmp_0 holds ts_delay(close, 1).
        tmp0_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp0_node, FunctionNode)
        assert tmp0_node.func_name == "ts_delay"

        # __tmp_1 holds ts_mean(__tmp_0, 5).
        tmp1_node = tmp_factors["__tmp_1"]
        assert isinstance(tmp1_node, FunctionNode)
        assert tmp1_node.func_name == "ts_mean"
        from src.factors.dsl import Symbol

        assert isinstance(tmp1_node.args[0], Symbol)
        assert tmp1_node.args[0].name == "__tmp_0"

        # The outermost expression references __tmp_1.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"
        assert isinstance(flat_expr.args[0], Symbol)
        assert flat_expr.args[0].name == "__tmp_1"

        print("[PASS] 多层嵌套表达式拍平测试")

    def test_nested_window_function_engine(self):
        """End-to-end: engine handles nested cs_rank(ts_delay(close, 1))."""
        print("\n" + "=" * 60)
        print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
        print("=" * 60)

        # 1. Prepare synthetic data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)} 行")

        # 2. Initialize the engine with the in-memory data source.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("引擎初始化完成")

        # 3. Register the nested factor via a string formula.
        print("\n注册因子: cs_rank(ts_delay(close, 1))")
        engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")

        # 4. The flattener should have registered temp factors too.
        registered_factors = engine.list_registered()
        print(f"已注册因子: {registered_factors}")

        tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
        assert len(tmp_factors) >= 1, "应该有临时因子被创建"
        print(f"临时因子: {tmp_factors}")

        # 5. Compute.
        print("\n执行计算...")
        result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"计算完成: {len(result)} 行")

        # 6. Verify the output column and rank-value range.
        assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"

        # Ranks must lie in [0, 1]; nulls from the lag period are allowed.
        non_null_values = result["delayed_rank"].drop_nulls()
        if len(non_null_values) > 0:
            assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
            assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"

        # Report null count (expected only from the leading lag rows).
        null_count = result["delayed_rank"].is_null().sum()
        print(f"空值数量: {null_count}")

        # Show a sample of the result.
        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
            10
        )
        print(sample.to_pandas().to_string(index=False))

        print("\n" + "=" * 60)
        print("嵌套窗口函数测试通过!")
        print("=" * 60)

    def test_multiple_nested_factors(self):
        """Registering several nested factors at once works."""
        print("\n" + "=" * 60)
        print("测试多个嵌套因子")
        print("=" * 60)

        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # Register two nested factors via string formulas.
        print("\n注册因子1: cs_rank(ts_delay(close, 1))")
        engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")

        print("注册因子2: ts_mean(cs_rank(close), 5)")
        engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")

        # Inspect the registry (includes any temp factors).
        factors = engine.list_registered()
        print(f"\n已注册因子: {factors}")

        # Compute both factors in one call.
        result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")

        assert "rank1" in result.columns
        assert "rank_mean" in result.columns

        print(f"\n结果行数: {len(result)}")
        print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
        print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")

        print("\n" + "=" * 60)
        print("多个嵌套因子测试通过!")
        print("=" * 60)

    def test_nested_vs_native_polars(self):
        """Numeric parity: engine result vs a hand-written Polars pipeline."""
        print("\n" + "=" * 60)
        print("对比测试:cs_rank(ts_delay(close, 1)) vs 原生 Polars")
        print("=" * 60)

        # 1. Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)} 行")

        # 2. Compute the nested factor through the FactorEngine.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
        engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
        engine_result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"FactorEngine 结果: {len(engine_result)} 行")

        # 3. Recompute manually with native Polars, step by step.
        print("\n使用原生 Polars 手动计算...")
        # Step 1: ts_delay(close, 1) — per-stock shift by one row.
        native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
            [pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
        )
        # Step 2: cs_rank — rank / count within each trade_date.
        native_result = native_result.with_columns(
            [
                (pl.col("delayed_close").rank() / pl.col("delayed_close").count())
                .over("trade_date")
                .alias("native_delayed_rank")
            ]
        )
        print(f"原生 Polars 结果: {len(native_result)} 行")

        # 4. Join the two results on (stock, date) for comparison.
        comparison = engine_result.join(
            native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
            on=["ts_code", "trade_date"],
            how="inner",
        )

        # 5. Assert numeric agreement up to floating-point noise.
        diff = comparison.with_columns(
            [
                (pl.col("delayed_rank") - pl.col("native_delayed_rank"))
                .abs()
                .alias("diff")
            ]
        )

        max_diff = diff["diff"].max()
        print(f"\n最大差异: {max_diff}")

        # Exclude nulls from the leading lag period before comparing.
        non_null_diff = diff.filter(pl.col("diff").is_not_null())
        assert non_null_diff["diff"].max() < 1e-10, (
            f"数值差异过大: {non_null_diff['diff'].max()}"
        )

        print("\n" + "=" * 60)
        print("数值一致性验证通过!")
        print("=" * 60)

    def test_factor_reference_factor(self):
        """A factor may reference another registered factor: fac2 = cs_rank(fac1)."""
        print("\n" + "=" * 60)
        print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
        print("=" * 60)

        # Prepare data and engine.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # 1. Register the base factor fac1.
        print("\n注册基础因子 fac1 = ts_mean(close, 5)")
        from src.factors.api import ts_mean

        engine.register("fac1", ts_mean(close, 5))

        # 2. Register fac2, referencing fac1 by name (string reference).
        print("注册引用因子 fac2 = cs_rank(fac1)")
        engine.register("fac2", cs_rank("fac1"))  # string reference to another factor

        # 3. Both factors appear in the registry.
        registered = engine.list_registered()
        print(f"\n已注册因子: {registered}")
        assert "fac1" in registered
        assert "fac2" in registered

        # 4. Compute both in one call.
        print("\n执行计算...")
        result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
        print(f"计算完成: {len(result)} 行")

        # 5. Verify output columns and that fac2 is a valid [0, 1] rank.
        assert "fac1" in result.columns, "结果中应有 fac1 列"
        assert "fac2" in result.columns, "结果中应有 fac2 列"

        assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间"
        assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间"

        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
            10
        )
        print(sample.to_pandas().to_string(index=False))

        print("\n" + "=" * 60)
        print("因子引用功能测试通过!")
        print("=" * 60)
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every scenario in declaration order when invoked as a script.
    suite = TestASTOptimizer()
    for case in (
        suite.test_flattener_basic,
        suite.test_flattener_no_nested,
        suite.test_flattener_deeply_nested,
        suite.test_nested_window_function_engine,
        suite.test_multiple_nested_factors,
        suite.test_nested_vs_native_polars,
        suite.test_factor_reference_factor,
    ):
        case()
    print("\n所有测试通过!")
|
||||
@@ -1,144 +0,0 @@
|
||||
"""测试 Bug 修复:
|
||||
1. 临时因子命名冲突修复验证
|
||||
2. 逻辑运算符支持验证
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "D:/PyProject/ProStock")
|
||||
|
||||
from src.factors.dsl import Symbol, BinaryOpNode
|
||||
from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression
|
||||
|
||||
|
||||
def test_temp_name_uniqueness():
    """Verify temporary factor names are globally unique.

    Two independent ``ExpressionFlattener`` instances must not hand out
    clashing ``__tmp_*`` names when flattening different factors.
    Returns True on success, False on a detected name collision.
    """
    print("测试 1: 临时因子命名冲突修复")
    print("-" * 50)

    from src.factors.dsl import FunctionNode

    close = Symbol("close")
    open_price = Symbol("open")

    # Factor A: cs_rank(ts_delay(close, 1)), flattened by its own instance.
    expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
    _flat_a, temps_a = ExpressionFlattener().flatten(expr_a)

    # Factor B: cs_mean(ts_delay(open, 2)), flattened by a second instance.
    expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2))
    _flat_b, temps_b = ExpressionFlattener().flatten(expr_b)

    names_a = set(temps_a)
    names_b = set(temps_b)

    print(f"因子 A 临时名称: {names_a}")
    print(f"因子 B 临时名称: {names_b}")

    # Any overlap means the global counter is not shared/unique.
    overlap = names_a & names_b
    if overlap:
        print(f"[失败] 发现命名冲突: {overlap}")
        return False

    print("[通过] 临时因子名称全局唯一,无冲突")
    return True
|
||||
|
||||
|
||||
def test_logical_operators():
    """Verify `&` / `|` support across DSL, parser, and Polars translator.

    Returns True when every stage succeeds, False at the first failure
    (parse or translation error).
    """
    print("\n测试 2: 逻辑运算符支持")
    print("-" * 50)

    # --- DSL layer -------------------------------------------------------
    close = Symbol("close")
    open_price = Symbol("open")

    # `&` operator (parentheses required due to Python operator precedence).
    and_expr = (close > open_price) & (close > 0)
    print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}")
    assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode"
    assert and_expr.op == "&", "运算符应为 &"
    print("[通过] DSL 层支持 & 运算符")

    # `|` operator (same precedence caveat).
    or_expr = (close < open_price) | (close < 0)
    print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}")
    assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode"
    assert or_expr.op == "|", "运算符应为 |"
    print("[通过] DSL 层支持 | 运算符")

    # --- String-formula parser ------------------------------------------
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    parser = FormulaParser(FunctionRegistry())

    # Parse an expression containing `&`.
    try:
        parsed_and = parser.parse("(close > open) & (volume > 0)")
        print(f"解析器支持 & 运算符: {parsed_and}")
        print("[通过] Parser 支持 & 运算符")
    except Exception as e:
        print(f"[失败] Parser 解析 & 失败: {e}")
        return False

    # Parse an expression containing `|`.
    try:
        parsed_or = parser.parse("(close < open) | (volume < 0)")
        print(f"解析器支持 | 运算符: {parsed_or}")
        print("[通过] Parser 支持 | 运算符")
    except Exception as e:
        print(f"[失败] Parser 解析 | 失败: {e}")
        return False

    # --- Translation to Polars expressions ------------------------------
    from src.factors.translator import PolarsTranslator
    import polars as pl

    translator = PolarsTranslator()

    try:
        polars_and = translator.translate(parsed_and)
        print(f"Polars 表达式 (&): {polars_and}")
        print("[通过] Translator 支持 & 运算符")
    except Exception as e:
        print(f"[失败] Translator 翻译 & 失败: {e}")
        return False

    try:
        polars_or = translator.translate(parsed_or)
        print(f"Polars 表达式 (|): {polars_or}")
        print("[通过] Translator 支持 | 运算符")
    except Exception as e:
        print(f"[失败] Translator 翻译 | 失败: {e}")
        return False

    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("=" * 60)
    print("Bug 修复验证测试")
    print("=" * 60)

    # Run both regression checks; each returns True on success.
    outcomes = [
        ("临时因子命名冲突修复", test_temp_name_uniqueness()),
        ("逻辑运算符支持", test_logical_operators()),
    ]

    print("\n" + "=" * 60)
    print("测试结果汇总")
    print("=" * 60)
    for label, ok in outcomes:
        print(f"{label}: {'[通过]' if ok else '[失败]'}")

    # Exit code 0 only when every check passed.
    if all(ok for _, ok in outcomes):
        print("\n所有测试通过!")
        sys.exit(0)
    else:
        print("\n存在失败的测试!")
        sys.exit(1)
|
||||
@@ -1,377 +0,0 @@
|
||||
"""Tests for DuckDB database manager and incremental sync."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from src.data.db_manager import (
|
||||
TableManager,
|
||||
IncrementalSync,
|
||||
SyncManager,
|
||||
ensure_table,
|
||||
get_table_info,
|
||||
sync_table,
|
||||
)
|
||||
|
||||
|
||||
class TestTableManager:
    """Test table creation and management.

    All tests use a fully mocked storage object, so assertions inspect
    the SQL passed to the mocked DuckDB connection rather than a real
    database.
    """

    @pytest.fixture
    def mock_storage(self):
        """Create a mock storage instance.

        Provides a mocked `_connection` (captures executed SQL) and an
        `exists` stub defaulting to False (table absent).
        """
        storage = Mock()
        storage._connection = Mock()
        storage.exists = Mock(return_value=False)
        return storage

    @pytest.fixture
    def sample_data(self):
        """Create sample DataFrame with ts_code and trade_date."""
        return pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
                "trade_date": ["20240101", "20240102", "20240101"],
                "open": [10.0, 10.5, 20.0],
                "close": [10.5, 11.0, 20.5],
                "volume": [1000, 2000, 3000],
            }
        )

    def test_create_table_from_dataframe(self, mock_storage, sample_data):
        """Test table creation from DataFrame."""
        manager = TableManager(mock_storage)

        result = manager.create_table_from_dataframe("daily", sample_data)

        assert result is True
        # Should execute CREATE TABLE (at least one SQL statement).
        assert mock_storage._connection.execute.call_count >= 1

        # Find the CREATE TABLE SQL among all executed statements;
        # it may be passed positionally or as a `sql=` keyword.
        calls = mock_storage._connection.execute.call_args_list
        create_table_call = None
        for call in calls:
            sql = call[0][0] if call[0] else call[1].get("sql", "")
            if "CREATE TABLE" in str(sql):
                create_table_call = sql
                break

        # The schema must include both key columns.
        assert create_table_call is not None
        assert "ts_code" in str(create_table_call)
        assert "trade_date" in str(create_table_call)

    def test_create_table_with_index(self, mock_storage, sample_data):
        """Test that composite index is created for trade_date and ts_code."""
        manager = TableManager(mock_storage)

        manager.create_table_from_dataframe("daily", sample_data, create_index=True)

        # Check that index creation was called (any CREATE INDEX statement).
        calls = mock_storage._connection.execute.call_args_list
        index_calls = [call for call in calls if "CREATE INDEX" in str(call)]
        assert len(index_calls) > 0

    def test_create_table_empty_dataframe(self, mock_storage):
        """Test that empty DataFrame is rejected."""
        manager = TableManager(mock_storage)
        empty_df = pd.DataFrame()

        result = manager.create_table_from_dataframe("daily", empty_df)

        # No table should be created and no SQL issued for empty input.
        assert result is False
        mock_storage._connection.execute.assert_not_called()

    def test_ensure_table_exists_creates_table(self, mock_storage, sample_data):
        """Test ensure_table_exists creates table if not exists."""
        mock_storage.exists.return_value = False
        manager = TableManager(mock_storage)

        result = manager.ensure_table_exists("daily", sample_data)

        assert result is True
        mock_storage._connection.execute.assert_called()

    def test_ensure_table_exists_already_exists(self, mock_storage):
        """Test ensure_table_exists returns True if table already exists."""
        mock_storage.exists.return_value = True
        manager = TableManager(mock_storage)

        result = manager.ensure_table_exists("daily", None)

        # Existing table: succeed without issuing any SQL.
        assert result is True
        mock_storage._connection.execute.assert_not_called()
|
||||
|
||||
class TestIncrementalSync:
    """Exercise strategy selection and data writing of IncrementalSync."""

    @staticmethod
    def _stats(row_count, max_date):
        # Shape of IncrementalSync.get_table_stats for an existing table.
        return {"exists": True, "row_count": row_count, "max_date": max_date}

    @pytest.fixture
    def mock_storage(self):
        """Storage double: fake connection, table absent, no stocks stored."""
        storage_double = Mock()
        storage_double._connection = Mock()
        storage_double.exists = Mock(return_value=False)
        storage_double.get_distinct_stocks = Mock(return_value=[])
        return storage_double

    def test_sync_strategy_new_table(self, mock_storage):
        """A missing table is synced by date over the full requested range."""
        mock_storage.exists.return_value = False
        strategy, start, end, stocks = IncrementalSync(mock_storage).get_sync_strategy(
            "daily", "20240101", "20240131"
        )
        assert strategy == "by_date"
        assert start == "20240101"
        assert end == "20240131"
        assert stocks is None

    def test_sync_strategy_empty_table(self, mock_storage):
        """An existing but empty table behaves like a fresh by-date sync."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        sync.get_table_stats = Mock(return_value=self._stats(0, None))
        strategy, start, end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )
        assert strategy == "by_date"
        assert start == "20240101"
        assert end == "20240131"

    def test_sync_strategy_up_to_date(self, mock_storage):
        """Nothing is scheduled when max_date already covers the range."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        sync.get_table_stats = Mock(return_value=self._stats(100, "20240131"))
        strategy, start, end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )
        assert strategy == "none"
        assert start is None
        assert end is None

    def test_sync_strategy_incremental_by_date(self, mock_storage):
        """The window starts the day after the last stored trade date."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        # Table already holds data through Jan 15.
        sync.get_table_stats = Mock(return_value=self._stats(100, "20240115"))
        strategy, start, end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )
        assert strategy == "by_date"
        assert start == "20240116"  # day after the stored max_date
        assert end == "20240131"

    def test_sync_strategy_by_stock(self, mock_storage):
        """Only stocks missing from storage are scheduled for a by-stock sync."""
        mock_storage.exists.return_value = True
        mock_storage.get_distinct_stocks.return_value = ["000001.SZ"]
        sync = IncrementalSync(mock_storage)
        sync.get_table_stats = Mock(return_value=self._stats(100, "20240131"))
        # Two stocks requested, only one already present.
        strategy, start, end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131", stock_codes=["000001.SZ", "600000.SH"]
        )
        assert strategy == "by_stock"
        assert "600000.SH" in stocks
        assert "000001.SZ" not in stocks

    def test_sync_data_by_date(self, mock_storage):
        """A non-empty frame is forwarded to storage and reported as success."""
        mock_storage.exists.return_value = True
        mock_storage.save = Mock(return_value={"status": "success", "rows": 1})
        frame = pd.DataFrame(
            {"ts_code": ["000001.SZ"], "trade_date": ["20240101"], "close": [10.0]}
        )
        outcome = IncrementalSync(mock_storage).sync_data(
            "daily", frame, strategy="by_date"
        )
        assert outcome["status"] == "success"

    def test_sync_data_empty_dataframe(self, mock_storage):
        """An empty frame is skipped rather than written."""
        outcome = IncrementalSync(mock_storage).sync_data("daily", pd.DataFrame())
        assert outcome["status"] == "skipped"
|
||||
|
||||
|
||||
class TestSyncManager:
    """Exercise the high-level SyncManager orchestration."""

    @pytest.fixture
    def mock_storage(self):
        """Storage double whose save() always reports success."""
        storage_double = Mock()
        storage_double._connection = Mock()
        storage_double.exists = Mock(return_value=False)
        storage_double.save = Mock(return_value={"status": "success", "rows": 10})
        storage_double.get_distinct_stocks = Mock(return_value=[])
        return storage_double

    def test_sync_no_sync_needed(self, mock_storage):
        """When the strategy is 'none', nothing is fetched and sync is skipped."""
        mock_storage.exists.return_value = True
        manager = SyncManager(mock_storage)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("none", None, None, None)
        )
        fetch_func = Mock()
        outcome = manager.sync("daily", fetch_func, "20240101", "20240131")
        assert outcome["status"] == "skipped"
        fetch_func.assert_not_called()

    def test_sync_fetches_data(self, mock_storage):
        """A by-date strategy fetches exactly once and syncs successfully."""
        mock_storage.exists.return_value = False
        manager = SyncManager(mock_storage)
        manager.table_manager.ensure_table_exists = Mock(return_value=True)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("by_date", "20240101", "20240131", None)
        )
        manager.incremental_sync.sync_data = Mock(
            return_value={"status": "success", "rows_inserted": 10}
        )
        # Fetcher yields one row of daily data.
        fetch_func = Mock(
            return_value=pd.DataFrame(
                {"ts_code": ["000001.SZ"], "trade_date": ["20240101"]}
            )
        )
        outcome = manager.sync("daily", fetch_func, "20240101", "20240131")
        fetch_func.assert_called_once()
        assert outcome["status"] == "success"

    def test_sync_handles_fetch_error(self, mock_storage):
        """A fetch failure becomes an error result instead of propagating."""
        manager = SyncManager(mock_storage)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("by_date", "20240101", "20240131", None)
        )
        fetch_func = Mock(side_effect=Exception("API Error"))
        outcome = manager.sync("daily", fetch_func, "20240101", "20240131")
        assert outcome["status"] == "error"
        assert "API Error" in outcome["error"]
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Exercise the module-level convenience wrappers."""

    @patch("src.data.db_manager.TableManager")
    def test_ensure_table(self, mock_manager_class):
        """ensure_table delegates to TableManager.ensure_table_exists."""
        manager_double = Mock()
        manager_double.ensure_table_exists = Mock(return_value=True)
        mock_manager_class.return_value = manager_double

        frame = pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]})
        assert ensure_table("daily", frame) is True
        manager_double.ensure_table_exists.assert_called_once_with("daily", frame)

    @patch("src.data.db_manager.IncrementalSync")
    def test_get_table_info(self, mock_sync_class):
        """get_table_info returns the stats produced by IncrementalSync."""
        sync_double = Mock()
        sync_double.get_table_stats = Mock(
            return_value={"exists": True, "row_count": 100}
        )
        mock_sync_class.return_value = sync_double

        info = get_table_info("daily")
        assert info["exists"] is True
        assert info["row_count"] == 100

    @patch("src.data.db_manager.SyncManager")
    def test_sync_table(self, mock_manager_class):
        """sync_table delegates to SyncManager.sync and returns its result."""
        manager_double = Mock()
        manager_double.sync = Mock(return_value={"status": "success", "rows": 10})
        mock_manager_class.return_value = manager_double

        fetch_func = Mock()
        outcome = sync_table("daily", fetch_func, "20240101", "20240131")
        assert outcome["status"] == "success"
        manager_double.sync.assert_called_once()
|
||||
|
||||
|
||||
# Allow running this test module directly; delegates to pytest's CLI runner.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,160 +0,0 @@
|
||||
"""FactorEngine 端到端测试。
|
||||
|
||||
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.factors.engine import FactorEngine, DataSpec
|
||||
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
|
||||
|
||||
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
    seed: int = 42,
) -> pl.DataFrame:
    """Generate a deterministic mock daily-bar dataset.

    Business days (Mon-Fri) between *start_date* and *end_date* inclusive are
    crossed with *n_stocks* synthetic SH tickers; OHLCV values come from a
    seeded NumPy RNG so repeated calls with the same arguments produce
    identical frames.

    Args:
        start_date: First calendar date, ``YYYYMMDD``.
        end_date: Last calendar date (inclusive), ``YYYYMMDD``.
        n_stocks: Number of synthetic stocks, coded 600000.SH upward.
        seed: RNG seed; keep the default (42) for reproducible fixtures.
            (Previously hard-coded; now a parameter with the same default.)

    Returns:
        A polars DataFrame with one row per (trade_date, ts_code) pair and
        columns open/high/low/close/volume/amount/pre_close.
    """
    start = datetime.strptime(start_date, "%Y%m%d")
    end = datetime.strptime(end_date, "%Y%m%d")

    # Keep only weekdays; this mock exchange has no holiday calendar.
    dates = []
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday..Friday
            dates.append(current.strftime("%Y%m%d"))
        current += timedelta(days=1)

    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
    np.random.seed(seed)

    rows = []
    for date in dates:
        for stock in stocks:
            base_price = 10 + np.random.randn() * 5
            close_val = base_price + np.random.randn() * 0.5
            open_val = close_val + np.random.randn() * 0.2
            # High/low bracket open/close so OHLC stays internally consistent.
            high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
            low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
            vol = int(1000000 + np.random.exponential(500000))
            amt = close_val * vol

            rows.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "open": round(open_val, 2),
                    "high": round(high_val, 2),
                    "low": round(low_val, 2),
                    "close": round(close_val, 2),
                    "volume": vol,
                    "amount": round(amt, 2),
                    "pre_close": round(close_val - np.random.randn() * 0.3, 2),
                }
            )

    return pl.DataFrame(rows)
|
||||
|
||||
|
||||
class TestFactorEngineEndToEnd:
    """End-to-end checks for FactorEngine over in-memory mock data."""

    @pytest.fixture
    def mock_data(self):
        """January 2024 mock bars for five stocks."""
        return create_mock_data("20240101", "20240131", n_stocks=5)

    @pytest.fixture
    def engine(self, mock_data):
        """Engine wired to the mock frame as its 'pro_bar' source."""
        return FactorEngine(data_source={"pro_bar": mock_data})

    def test_simple_symbol_expression(self, engine):
        """A bare symbol maps straight through to a result column."""
        engine.register("close_price", close)
        frame = engine.compute("close_price", "20240115", "20240120")
        assert "close_price" in frame.columns
        assert len(frame) > 0
        print("[PASS] 简单符号表达式测试")

    def test_arithmetic_expression(self, engine):
        """Arithmetic between symbols yields the named factor column."""
        engine.register("returns", (close - open_sym) / open_sym)
        frame = engine.compute("returns", "20240115", "20240120")
        assert "returns" in frame.columns
        print("[PASS] 算术表达式测试")

    def test_cs_rank_factor(self, engine):
        """Cross-sectional rank values stay inside the [0, 1] interval."""
        engine.register("price_rank", cs_rank(close))
        frame = engine.compute("price_rank", "20240115", "20240120")
        assert "price_rank" in frame.columns
        assert frame["price_rank"].min() >= 0
        assert frame["price_rank"].max() <= 1
        print("[PASS] 截面排名因子测试")
|
||||
|
||||
|
||||
class TestFullWorkflow:
    """Scripted demo walking the whole register/compute pipeline."""

    def test_full_workflow_demo(self):
        """Run the full factor workflow on generated data and verify it."""
        banner = "=" * 60
        print("\n" + banner)
        print("FactorEngine Full Workflow Demo")
        print(banner)

        # 1. Build the mock dataset.
        print("\nStep 1: Prepare mock data...")
        bars = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f" Generated {len(bars)} rows")
        print(f" Stocks: {bars['ts_code'].n_unique()}")

        # 2. Construct the engine over the mock source.
        print("\nStep 2: Initialize FactorEngine...")
        factor_engine = FactorEngine(data_source={"pro_bar": bars})
        print(" Engine initialized")

        # 3. Register two lookback-free factors (avoids warm-up windows).
        print("\nStep 3: Register factors...")
        factor_engine.register("returns", (close - open_sym) / open_sym)
        factor_engine.register("price_rank", cs_rank(close))
        print(" Registered: returns, price_rank")

        # 4. Compute both factors over the target window.
        print("\nStep 4: Compute factors...")
        output = factor_engine.compute(
            ["returns", "price_rank"],
            "20240115",
            "20240120",
        )
        print(f" Computed {len(output)} rows")

        # 5. Sanity-check the result frame.
        print("\nStep 5: Verify results...")
        assert "returns" in output.columns
        assert "price_rank" in output.columns
        assert output["price_rank"].min() >= 0
        assert output["price_rank"].max() <= 1
        print(" All assertions passed")

        # 6. Show a small sample of the output.
        print("\nStep 6: Sample output...")
        sample = output.select(
            ["ts_code", "trade_date", "close", "returns", "price_rank"]
        ).head(3)
        print(sample.to_pandas().to_string(index=False))

        print("\n" + banner)
        print("Workflow completed successfully!")
        print(banner)
|
||||
|
||||
|
||||
# Run the demo once directly, then hand the module to pytest for collection.
if __name__ == "__main__":
    test = TestFullWorkflow()
    test.test_full_workflow_demo()
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -1,106 +0,0 @@
|
||||
"""FactorEngine 与 Metadata 集成测试。
|
||||
|
||||
测试 add_factor_by_name 方法的功能。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.factors.metadata import FactorManager
|
||||
|
||||
|
||||
class TestFactorEngineMetadataIntegration:
    """Integration tests for FactorEngine's metadata-backed registration."""

    @pytest.fixture
    def metadata_file(self):
        """Path of the factors.jsonl shipped under data/."""
        return "data/factors.jsonl"

    def test_init_without_metadata(self):
        """Without a metadata path the engine holds no FactorManager."""
        engine = FactorEngine()
        assert engine._metadata is None

    def test_init_with_metadata(self, metadata_file):
        """Passing metadata_path attaches a FactorManager instance."""
        engine = FactorEngine(metadata_path=metadata_file)
        assert engine._metadata is not None
        assert isinstance(engine._metadata, FactorManager)

    def test_add_factor_by_name_success(self, metadata_file):
        """A known factor registers, and the call returns self for chaining."""
        engine = FactorEngine(metadata_path=metadata_file)
        returned = engine.add_factor_by_name("return_5")
        assert returned is engine
        assert "return_5" in engine.list_registered()

    def test_add_factor_by_name_with_alias(self, metadata_file):
        """Registering under an alias uses the alias, not the metadata name."""
        engine = FactorEngine(metadata_path=metadata_file)
        engine.add_factor_by_name("my_ma", "ma_5")
        registered = engine.list_registered()
        assert "my_ma" in registered
        assert "ma_5" not in registered

    def test_add_factor_by_name_not_found(self, metadata_file):
        """An unknown factor name raises ValueError naming the factor."""
        engine = FactorEngine(metadata_path=metadata_file)
        with pytest.raises(ValueError) as exc_info:
            engine.add_factor_by_name("nonexistent_factor")
        message = str(exc_info.value)
        assert "未找到因子" in message
        assert "nonexistent_factor" in message

    def test_add_factor_by_name_without_metadata(self):
        """Calling without a configured metadata path raises RuntimeError."""
        engine = FactorEngine()  # no metadata_path supplied
        with pytest.raises(RuntimeError) as exc_info:
            engine.add_factor_by_name("return_5")
        assert "未配置 metadata 路径" in str(exc_info.value)

    def test_chain_calls(self, metadata_file):
        """Multiple add_factor_by_name calls chain fluently."""
        engine = FactorEngine(metadata_path=metadata_file)
        (
            engine.add_factor_by_name("return_5")
            .add_factor_by_name("ma_5")
            .add_factor_by_name("custom_ma20", "ma_20")
        )
        registered = engine.list_registered()
        assert "return_5" in registered
        assert "ma_5" in registered
        assert "custom_ma20" in registered

    def test_add_factor_by_name_preserves_existing_add_factor(self, metadata_file):
        """add_factor_by_name coexists with the string-based add_factor API."""
        engine = FactorEngine(metadata_path=metadata_file)
        engine.add_factor("manual_factor", "ts_mean(close, 10)")
        engine.add_factor_by_name("return_5")
        assert "manual_factor" in engine.list_registered()
        assert "return_5" in engine.list_registered()
|
||||
|
||||
|
||||
# Allow running this test module directly; delegates to pytest's CLI runner.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,451 +0,0 @@
|
||||
"""因子框架集成测试脚本
|
||||
|
||||
测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑
|
||||
|
||||
测试范围:
|
||||
1. 时序因子 ts_mean - 验证滑动窗口和数据隔离
|
||||
2. 截面因子 cs_rank - 验证每日独立排名和结果分布
|
||||
3. 组合运算 - 验证多字段算术运算和算子嵌套
|
||||
|
||||
排除范围:PIT 因子(使用低频财务数据)
|
||||
"""
|
||||
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.data.catalog import DatabaseCatalog
|
||||
from src.factors.engine import FactorEngine
|
||||
from src.factors.api import close, open, ts_mean, cs_rank
|
||||
|
||||
|
||||
def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list:
    """Randomly select a representative stock sample from the daily table.

    The sample is drawn so it spans both exchanges and their boards:
    - .SH: Shanghai Stock Exchange (main board, STAR board 688)
    - .SZ: Shenzhen Stock Exchange (main board, ChiNext 300/301)

    Args:
        catalog: Database catalog instance (provides the DuckDB path).
        n: Number of stocks to select.

    Returns:
        Sorted list of at most *n* stock codes.
    """
    import duckdb

    # The catalog stores a duckdb:// URI; strip the scheme to get a file path.
    db_path = catalog.db_path.replace("duckdb://", "").lstrip("/")

    conn = duckdb.connect(db_path, read_only=True)

    try:
        # All stocks that traded in H1 2023.
        result = conn.execute("""
            SELECT DISTINCT ts_code
            FROM daily
            WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30'
        """).fetchall()

        all_stocks = [row[0] for row in result]

        # Split by exchange.
        sh_stocks = [s for s in all_stocks if s.endswith(".SH")]
        sz_stocks = [s for s in all_stocks if s.endswith(".SZ")]

        # Board buckets: SH main (600/601/603/605), STAR (688),
        # SZ main (000/001/002), ChiNext (300/301).
        sh_main = [
            s for s in sh_stocks if s.startswith("6") and not s.startswith("688")
        ]
        sh_kcb = [s for s in sh_stocks if s.startswith("688")]
        sz_main = [s for s in sz_stocks if s.startswith("0")]
        sz_cyb = [s for s in sz_stocks if s.startswith(("300", "301"))]

        # Up to two picks per bucket keeps the sample diverse.
        sample = []
        for bucket in (sh_main, sh_kcb, sz_main, sz_cyb):
            if bucket:
                sample.extend(random.sample(bucket, min(2, len(bucket))))

        # Top up with random extras if the buckets did not yield n stocks.
        # A single sample() over the remainder replaces the previous
        # one-at-a-time choice loop, which recomputed `remaining` per pick
        # (quadratic in the stock universe).
        if len(sample) < n:
            chosen = set(sample)
            remaining = [s for s in all_stocks if s not in chosen]
            if remaining:
                sample.extend(
                    random.sample(remaining, min(n - len(sample), len(remaining)))
                )

        return sorted(sample[:n])

    finally:
        conn.close()
|
||||
|
||||
|
||||
def run_factor_integration_test():
    """Run the factor-framework integration test against real DuckDB data.

    Walks the full pipeline: catalog discovery, sample-stock selection,
    factor registration (ts_mean / cs_rank / composite), computation, and
    several manual verification sections (sliding window, cross-sectional
    rank, data integrity). Returns the computed result frame.
    """

    print("=" * 80)
    print("因子框架集成测试 - DuckDB 真实数据验证")
    print("=" * 80)

    # =========================================================================
    # 1. Test environment setup
    # =========================================================================
    print("\n" + "=" * 80)
    print("1. 测试环境准备")
    print("=" * 80)

    # Database configuration
    db_path = "data/prostock.db"
    db_uri = f"duckdb:///{db_path}"

    print(f"\n数据库路径: {db_path}")
    print(f"数据库URI: {db_uri}")

    # Test date range (H1 2023)
    start_date = "20230101"
    end_date = "20230630"
    print(f"\n测试时间范围: {start_date} 至 {end_date}")

    # Build the DatabaseCatalog and discover table structure
    print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...")
    catalog = DatabaseCatalog(db_path)
    print(f"发现表数量: {len(catalog.tables)}")
    for table_name, metadata in catalog.tables.items():
        print(
            f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})"
        )

    # Select the sample stocks
    print("\n[1.2] 选择样本股票...")
    sample_stocks = select_sample_stocks(catalog, n=8)
    print(f"选中 {len(sample_stocks)} 只代表性股票:")
    for stock in sample_stocks:
        exchange = "上交所" if stock.endswith(".SH") else "深交所"
        board = ""
        if stock.startswith("688"):
            board = "科创板"
        elif (
            stock.startswith("600")
            or stock.startswith("601")
            or stock.startswith("603")
        ):
            board = "主板"
        elif stock.startswith("300") or stock.startswith("301"):
            board = "创业板"
        elif (
            stock.startswith("000")
            or stock.startswith("001")
            or stock.startswith("002")
        ):
            board = "主板"
        print(f" - {stock} ({exchange} {board})")

    # =========================================================================
    # 2. Factor definitions
    # =========================================================================
    print("\n" + "=" * 80)
    print("2. 因子定义")
    print("=" * 80)

    # Create the FactorEngine
    print("\n[2.1] 创建 FactorEngine...")
    engine = FactorEngine(catalog)

    # Factor A: time-series moving average ts_mean(close, 10)
    print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)")
    print(" 验证重点: 10日滑动窗口是否正确;是否存在'数据串户'")
    factor_a = ts_mean(close, 10)
    engine.add_factor("factor_a_ts_mean_10", factor_a)
    print(f" AST: {factor_a}")

    # Factor B: cross-sectional rank cs_rank(close)
    print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)")
    print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间")
    factor_b = cs_rank(close)
    engine.add_factor("factor_b_cs_rank", factor_b)
    print(f" AST: {factor_b}")

    # Factor C: composite expression ts_mean(close, 5) / open
    print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open")
    print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性")
    factor_c = ts_mean(close, 5) / open
    engine.add_factor("factor_c_composite", factor_c)
    print(f" AST: {factor_c}")

    # Also register the raw fields for manual verification below
    engine.add_factor("close_price", close)
    engine.add_factor("open_price", open)

    print(f"\n已注册因子列表: {engine.list_factors()}")

    # =========================================================================
    # 3. Execute the computation
    # =========================================================================
    print("\n" + "=" * 80)
    print("3. 计算执行")
    print("=" * 80)

    print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...")
    result_df = engine.compute(
        start_date=start_date,
        end_date=end_date,
        db_uri=db_uri,
    )

    print(f"\n计算完成!")
    print(f"结果形状: {result_df.shape}")
    print(f"结果列: {result_df.columns}")

    # =========================================================================
    # 4. Debug info: preview the first 5 rows of the Context LazyFrame
    # =========================================================================
    print("\n" + "=" * 80)
    print("4. 调试信息:DataLoader 拼接后的数据预览")
    print("=" * 80)

    print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
    from src.data.catalog import build_context_lazyframe

    context_lf = build_context_lazyframe(
        required_fields=["close", "open"],
        start_date=start_date,
        end_date=end_date,
        db_uri=db_uri,
        catalog=catalog,
    )

    print("\nContext LazyFrame 前 5 行:")
    print(context_lf.fetch(5))

    # =========================================================================
    # 5. Time-series slice check (sliding window)
    # =========================================================================
    print("\n" + "=" * 80)
    print("5. 时序切片检查")
    print("=" * 80)

    # Pick one stock for the time-series verification
    target_stock = sample_stocks[0] if sample_stocks else "000001.SZ"
    print(f"\n[5.1] 筛选股票: {target_stock}")

    stock_df = result_df.filter(pl.col("ts_code") == target_stock)
    print(f"该股票数据行数: {len(stock_df)}")

    print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):")
    print("-" * 80)
    print("人工核查点:")
    print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null(滑动窗口未满)")
    print(" - 第 10 行开始应该有值")
    print("-" * 80)

    display_cols = [
        "ts_code",
        "trade_date",
        "close_price",
        "open_price",
        "factor_a_ts_mean_10",
    ]
    available_cols = [c for c in display_cols if c in stock_df.columns]
    print(stock_df.select(available_cols).head(15))

    # Verify the sliding window: first 9 rows Null, value from row 10 on
    print("\n[5.3] 滑动窗口验证:")
    stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list()
    null_count_first_9 = sum(1 for x in stock_list[:9] if x is None)
    non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None)

    print(f" 前 9 行 Null 值数量: {null_count_first_9}/9")
    print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6")

    if null_count_first_9 == 9 and non_null_from_10 == 6:
        print(" ✅ 滑动窗口验证通过!")
    else:
        print(" ⚠️ 滑动窗口验证异常,请检查数据")

    # =========================================================================
    # 6. Cross-sectional slice check (per-day rank)
    # =========================================================================
    print("\n" + "=" * 80)
    print("6. 截面切片检查")
    print("=" * 80)

    # Pick one trading day for the cross-sectional verification
    target_date = "20230301"
    print(f"\n[6.1] 筛选交易日: {target_date}")

    date_df = result_df.filter(pl.col("trade_date") == target_date)
    print(f"该交易日股票数量: {len(date_df)}")

    print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:")
    print("-" * 80)
    print("人工核查点:")
    print(" - close 最高的股票其 cs_rank 应该接近 1.0")
    print(" - close 最低的股票其 cs_rank 应该接近 0.0")
    print(" - cs_rank 值应该严格分布在 [0, 1] 区间")
    print("-" * 80)

    # Display sorted by close, descending
    display_df = date_df.select(
        ["ts_code", "trade_date", "close_price", "factor_b_cs_rank"]
    )
    display_df = display_df.sort("close_price", descending=True)
    print(display_df)

    # Verify the cross-sectional rank distribution
    print("\n[6.3] 截面排名验证:")
    rank_values = date_df.select("factor_b_cs_rank").to_series().to_list()
    rank_values = [x for x in rank_values if x is not None]

    if rank_values:
        min_rank = min(rank_values)
        max_rank = max(rank_values)
        print(f" cs_rank 最小值: {min_rank:.6f}")
        print(f" cs_rank 最大值: {max_rank:.6f}")
        print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]")

        # Verify that the highest-close stock ranks close to 1.0
        highest_close_row = date_df.sort("close_price", descending=True).head(1)
        if len(highest_close_row) > 0:
            highest_rank = highest_close_row.select("factor_b_cs_rank").item()
            print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}")

            if abs(highest_rank - 1.0) < 0.01:
                print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)")
            else:
                print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})")

    # =========================================================================
    # 7. Data completeness statistics
    # =========================================================================
    print("\n" + "=" * 80)
    print("7. 数据完整性统计")
    print("=" * 80)

    factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"]

    print("\n[7.1] 各因子的空值数量和描述性统计:")
    print("-" * 80)

    for col in factor_cols:
        if col in result_df.columns:
            series = result_df.select(col).to_series()
            null_count = series.null_count()
            total_count = len(series)

            print(f"\n因子: {col}")
            print(f" 总记录数: {total_count}")
            print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)")

            # Descriptive statistics (nulls excluded)
            non_null_series = series.drop_nulls()
            if len(non_null_series) > 0:
                print(f" 描述性统计:")
                print(f" Mean: {non_null_series.mean():.6f}")
                print(f" Std: {non_null_series.std():.6f}")
                print(f" Min: {non_null_series.min():.6f}")
                print(f" Max: {non_null_series.max():.6f}")

    # =========================================================================
    # 8. Combined verification
    # =========================================================================
    print("\n" + "=" * 80)
    print("8. 综合验证")
    print("=" * 80)

    print("\n[8.1] 数据串户检查:")
    # Check that per-stock data is correctly isolated (no cross-contamination)
    print(" 验证方法: 检查不同股票的 trade_date 序列是否独立")

    stock_dates = {}
    for stock in sample_stocks[:3]:  # inspect the first 3 stocks
        stock_data = (
            result_df.filter(pl.col("ts_code") == stock)
            .select("trade_date")
            .to_series()
            .to_list()
        )
        stock_dates[stock] = stock_data[:5]  # first 5 dates
        print(f" {stock} 前5个交易日期: {stock_data[:5]}")

    # Date sequences should match (same trading calendar over the range)
    dates_match = all(
        dates == list(stock_dates.values())[0] for dates in stock_dates.values()
    )
    if dates_match:
        print(" ✅ 日期序列一致,数据对齐正确")
    else:
        print(" ⚠️ 日期序列不一致,请检查数据对齐")

    print("\n[8.2] 因子 C 组合运算验证:")
    # Spot-check one row to validate the composite expression
    sample_row = result_df.filter(
        (pl.col("ts_code") == target_stock)
        & (pl.col("factor_a_ts_mean_10").is_not_null())
    ).head(1)

    if len(sample_row) > 0:
        close_val = sample_row.select("close_price").item()
        open_val = sample_row.select("open_price").item()
        factor_c_val = sample_row.select("factor_c_composite").item()

        # Manual check of ts_mean(close, 5) / open
        # NOTE: this only validates the expression structure, not exact values
        print(f" 样本数据:")
        print(f" close: {close_val:.4f}")
        print(f" open: {open_val:.4f}")
        print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}")

        # Sanity: factor_c should sit near some average of close/open
        ratio = close_val / open_val if open_val != 0 else 0
        print(f" close/open 比值: {ratio:.6f}")
        print(f" ✅ 组合运算结果已生成")

    # =========================================================================
    # 9. Test summary
    # =========================================================================
    print("\n" + "=" * 80)
    print("9. 测试总结")
    print("=" * 80)

    print("\n测试完成! 以下是关键验证点总结:")
    print("-" * 80)
    print("✅ 因子 A (ts_mean):")
    print(" - 10日滑动窗口计算正确")
    print(" - 前9行为Null,第10行开始有值")
    print(" - 不同股票数据隔离(over(ts_code))")
    print()
    print("✅ 因子 B (cs_rank):")
    print(" - 每日独立排名(over(trade_date))")
    print(" - 结果分布在 [0, 1] 区间")
    print(" - 最高close股票rank接近1.0")
    print()
    print("✅ 因子 C (组合运算):")
    print(" - 多字段算术运算正常")
    print(" - 时序算子嵌套稳定")
    print()
    print("✅ 数据完整性:")
    print(f" - 总记录数: {len(result_df)}")
    print(f" - 样本股票数: {len(sample_stocks)}")
    print(f" - 时间范围: {start_date} 至 {end_date}")
    print("-" * 80)

    return result_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置随机种子以确保可重复性
|
||||
random.seed(42)
|
||||
|
||||
# 运行测试
|
||||
result = run_factor_integration_test()
|
||||
@@ -1,351 +0,0 @@
|
||||
"""财务数据与行情数据拼接测试。
|
||||
|
||||
测试场景:
|
||||
1. 普通财务数据:正常公告,之后无修改
|
||||
2. 隔日修改:公告后几天发布修正版
|
||||
3. 当日修改:同一天发布多版,取 update_flag=1 的
|
||||
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
from datetime import date
|
||||
from src.data.financial_loader import FinancialLoader
|
||||
|
||||
|
||||
def create_mock_price_data() -> pl.DataFrame:
    """Build a mock daily-price frame for a single stock.

    Covers ten consecutive January trading days plus two dates after
    2024-04-30, which exercise the same-day multi-report-period scenario.
    """
    trade_dates = [
        "20240101",
        "20240102",
        "20240103",
        "20240104",
        "20240105",
        "20240108",
        "20240109",
        "20240110",
        "20240111",
        "20240112",
        # Dates after 2024-04-30 for the same-day multi-report-period case.
        "20240501",
        "20240502",
    ]
    closes = [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9, 11.0, 11.1]
    return pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * len(trade_dates),
            "trade_date": trade_dates,
            "close": closes,
        }
    )
|
||||
|
||||
|
||||
def create_mock_financial_data() -> pl.DataFrame:
    """Create mock financial data covering several announcement scenarios.

    Scenarios:
    1. 2024-01-02 publishes the 2023Q3 report (end_date=20230930)
    2. 2024-01-02 publishes a corrected 2023Q3 report (update_flag=1)
    3. 2024-04-30 publishes both the 2023 annual report (end_date=20231231)
       and the 2024Q1 quarterly report (end_date=20240331)
    4. 2024-04-30 publishes a corrected 2023 annual report

    Expected cleaning results:
    - 2024-01-02 keeps the corrected 2023Q3 version
    - 2024-04-30 keeps the 2024Q1 report (latest end_date)

    Note: f_ann_date must be of Date type (consistent with the database).
    """
    return pl.DataFrame(
        {
            "ts_code": [
                "000001.SZ",
                "000001.SZ",
                "000001.SZ",
                "000001.SZ",
                "000001.SZ",
            ],
            "f_ann_date": [
                date(2024, 1, 2),
                date(2024, 1, 2),  # multiple versions on the same day
                date(2024, 4, 30),
                date(2024, 4, 30),
                date(2024, 4, 30),  # different report periods on the same day
            ],
            "end_date": [
                "20230930",
                "20230930",  # 2023Q3
                "20231231",
                "20240331",
                "20231231",  # annual and quarterly reports published the same day
            ],
            "report_type": [1, 1, 1, 1, 1],  # integer type (matches the database)
            "update_flag": [0, 1, 0, 0, 1],  # the annual report also has a correction
            "net_profit": [
                1000000.0,
                1100000.0,  # 2023Q3
                5000000.0,
                1500000.0,
                5500000.0,  # corrected annual: 5.5M; quarterly: 1.5M
            ],
            "revenue": [
                5000000.0,
                5200000.0,  # 2023Q3
                20000000.0,
                8000000.0,
                22000000.0,
            ],
        }
    )
|
||||
|
||||
|
||||
def test_financial_data_cleaning():
    """Test the financial-data cleaning logic.

    Verifies that when several report periods are announced on the same
    day, the row with the latest ``end_date`` is kept, and that same-day
    re-publications keep the corrected (``update_flag=1``) version.

    Returns:
        The cleaned financial DataFrame, for reuse by callers.
    """
    print("=== 测试 1: 财务数据清洗 ===")

    df_finance = create_mock_financial_data()
    print("原始财务数据:")
    print(df_finance)

    # NOTE(review): an unused ``loader = FinancialLoader()`` was removed here;
    # this test replays the cleaning pipeline by hand and never used it.

    # Keep consolidated statements only (report_type == 1).
    df = df_finance.filter(pl.col("report_type") == 1)

    # Helper columns: integer report period and normalized update flag.
    df = df.with_columns(
        [
            pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
            pl.col("update_flag")
            .fill_null("0")
            .cast(pl.Int32, strict=False)
            .fill_null(0)
            .alias("update_flag_int"),
        ]
    )

    # Deterministic ordering so unique(keep="last") is well-defined.
    df = df.sort(["ts_code", "f_ann_date", "end_date_int", "update_flag_int"])

    # Running maximum of the report period seen so far, per stock.
    df = df.with_columns(
        pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
    )

    # Drop stale rows whose report period is older than one already seen.
    df = df.filter(pl.col("end_date_int") == pl.col("max_end_date_seen"))

    # One row per announcement day; "last" wins = latest end_date / flag.
    df = df.unique(subset=["ts_code", "f_ann_date"], keep="last")

    # Remove helper columns and restore a stable order.
    df = df.drop(["end_date_int", "update_flag_int", "max_end_date_seen"])
    df = df.sort(["ts_code", "f_ann_date"])

    print("\n清洗后的财务数据:")
    print(df)

    # Expect exactly two rows: 2024-01-02 and 2024-04-30.
    assert len(df) == 2, f"清洗后应该有2条记录,实际有 {len(df)} 条"

    # 2024-01-02 must keep the corrected 2023Q3 report.
    row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
    assert len(row_jan02) == 1
    assert row_jan02["end_date"][0] == "20230930"
    assert row_jan02["update_flag"][0] == 1
    print("[验证 1] 2024-01-02 正确保留了 2023Q3 更正版")

    # 2024-04-30 must keep 2024Q1 (end_date=20240331), not the annual report.
    row_apr30 = df.filter(pl.col("f_ann_date") == date(2024, 4, 30))
    assert len(row_apr30) == 1
    assert row_apr30["end_date"][0] == "20240331", (
        f"2024-04-30 应该保留 end_date 最新的 20240331,"
        f"实际为 {row_apr30['end_date'][0]}"
    )
    assert row_apr30["net_profit"][0] == 1500000.0
    print("[验证 2] 2024-04-30 正确保留了 2024Q1 季报(end_date 最新)")

    print("\n[通过] 财务数据清洗测试通过!")
    return df
|
||||
|
||||
|
||||
def test_financial_price_merge():
    """Test merging financial data with price data (no look-ahead bias).

    Cleans the mock financial data the same way the loader does, merges it
    onto the daily price frame, and asserts that every trade date only sees
    reports announced on or before that date.

    Returns:
        The merged DataFrame, for reuse by callers.
    """
    print("\n=== 测试 2: 财务数据与行情数据拼接 ===")

    df_price = create_mock_price_data()
    df_finance_raw = create_mock_financial_data()

    loader = FinancialLoader()

    # Step 1: clean the financial data (replaying the new cleaning logic).
    # Note: f_ann_date is already of Date type, no conversion needed.
    df_finance = df_finance_raw.filter(pl.col("report_type") == 1)

    # Add helper columns.
    df_finance = df_finance.with_columns(
        [
            pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
            pl.col("update_flag")
            .fill_null("0")
            .cast(pl.Int32, strict=False)
            .fill_null(0)
            .alias("update_flag_int"),
        ]
    )

    # Deterministic sort.
    df_finance = df_finance.sort(
        ["ts_code", "f_ann_date", "end_date_int", "update_flag_int"]
    )

    # Running maximum report period per stock.
    df_finance = df_finance.with_columns(
        pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
    )

    # Drop stale report periods.
    df_finance = df_finance.filter(
        pl.col("end_date_int") == pl.col("max_end_date_seen")
    )

    # Deduplicate, keeping the last row (largest end_date).
    df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="last")

    # Drop helper columns.
    df_finance = df_finance.drop(
        ["end_date_int", "update_flag_int", "max_end_date_seen"]
    )
    df_finance = df_finance.sort(["ts_code", "f_ann_date"])

    print("清洗后的财务数据:")
    print(df_finance)

    # Step 2: convert price dates to the Date type.
    df_price = df_price.with_columns(
        [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
    )
    df_price = df_price.sort(["ts_code", "trade_date"])

    # Step 3: merge.
    financial_cols = ["net_profit", "revenue"]
    merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)

    # Step 4: convert dates back to string form.
    merged = merged.with_columns(
        [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
    )

    print("\n拼接结果:")
    print(merged)

    # No look-ahead bias:
    # 20240101 must not see 2023Q3 data (announced only on 20240102).
    jan01 = merged.filter(pl.col("trade_date") == "20240101")
    assert jan01["net_profit"].is_null().all(), (
        "2024-01-01 不应有 2023Q3 数据(尚未公告)"
    )
    print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")

    # From 20240102 onward, net_profit=1100000 (the update_flag=1 version).
    jan02 = merged.filter(pl.col("trade_date") == "20240102")
    assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
    print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1)")

    # 20240104 keeps carrying the 2023Q3 data forward.
    jan04 = merged.filter(pl.col("trade_date") == "20240104")
    assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
    print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")

    # 20240110 still carries 2023Q3 (the 2024-04-30 report is not out yet).
    jan10 = merged.filter(pl.col("trade_date") == "20240110")
    assert jan10["net_profit"][0] == 1100000.0, "2024-01-10 应延续使用 2023Q3 数据"
    print("[验证 4] 2024-01-10 net_profit=1100000 - 正确(延续使用 2023Q3)")

    # 20240112 still carries 2023Q3.
    jan12 = merged.filter(pl.col("trade_date") == "20240112")
    assert jan12["net_profit"][0] == 1100000.0, "2024-01-12 应继续使用 2023Q3 数据"
    print("[验证 5] 2024-01-12 net_profit=1100000 - 正确(延续使用 2023Q3)")

    # 20240501 switches to 2024Q1 (announced 2024-04-30, latest end_date wins).
    may01 = merged.filter(pl.col("trade_date") == "20240501")
    assert may01["net_profit"][0] == 1500000.0, "2024-05-01 应切换到 2024Q1 数据"
    print(
        "[验证 6] 2024-05-01 net_profit=1500000 - 正确(切换到 2024Q1,end_date 最新)"
    )

    print("\n[通过] 所有验证通过,无未来函数!")
    return merged
|
||||
|
||||
|
||||
def test_empty_financial_data():
    """Merging an empty financial frame must yield all-null financial columns."""
    print("\n=== 测试 3: 空财务数据场景 ===")

    prices = create_mock_price_data()
    empty_finance = pl.DataFrame()
    loader = FinancialLoader()

    # Normalize the price frame: parse dates, then sort per stock/day.
    prices = prices.with_columns(
        [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
    ).sort(["ts_code", "trade_date"])

    # Merge against the empty financial frame.
    merged = loader.merge_financial_with_price(prices, empty_finance, ["net_profit"])

    # Convert dates back to the string representation.
    merged = merged.with_columns(
        [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
    )

    # The financial column must be entirely null.
    assert merged["net_profit"].is_null().all(), (
        "财务数据为空时,net_profit 应全为 null"
    )

    print("空财务数据拼接结果:")
    print(merged)
    print("\n[通过] 空财务数据场景测试通过!")
|
||||
|
||||
|
||||
def run_all_tests():
    """Run every test in this module in order, re-raising on failure."""
    print("开始运行财务数据拼接功能测试...\n")
    print("=" * 60)

    # Ordered suite: cleaning, then merging, then the empty-data edge case.
    suite = (
        test_financial_data_cleaning,
        test_financial_price_merge,
        test_empty_financial_data,
    )
    try:
        for case in suite:
            case()

        print("\n" + "=" * 60)
        print("所有测试通过!")
        print("=" * 60)
    except AssertionError as e:
        print(f"\n[失败] 测试断言失败: {e}")
        raise
    except Exception as e:
        print(f"\n[错误] 测试执行出错: {e}")
        raise
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    run_all_tests()
|
||||
@@ -1,130 +0,0 @@
|
||||
"""测试新增的时间序列函数和智能分发逻辑。"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
from src.factors.translator import PolarsTranslator
|
||||
|
||||
|
||||
def test_ts_sma_translate():
    """ts_sma(close, 10, 5) must translate into a Polars expression."""
    node = FunctionNode("ts_sma", Symbol("close"), 10, 5)
    compiled = PolarsTranslator().translate(node)
    assert isinstance(compiled, pl.Expr)
|
||||
|
||||
|
||||
def test_ts_wma_translate():
    """ts_wma(close, 20) must translate into a Polars expression."""
    node = FunctionNode("ts_wma", Symbol("close"), 20)
    compiled = PolarsTranslator().translate(node)
    assert isinstance(compiled, pl.Expr)
|
||||
|
||||
|
||||
def test_ts_sumac_translate():
    """ts_sumac(close) must translate into a Polars expression."""
    node = FunctionNode("ts_sumac", Symbol("close"))
    compiled = PolarsTranslator().translate(node)
    assert isinstance(compiled, pl.Expr)
|
||||
|
||||
|
||||
def test_max_intelligent_dispatch():
    """max_ dispatch: positive int -> ts_max, anything else -> element-wise max."""
    from src.factors.api import max_, close

    # A positive integer second argument means a rolling window.
    assert max_(close, 20).func_name == "ts_max"

    # Zero, negative, and float arguments fall back to element-wise max.
    for scalar in (0, -1, 10.5):
        assert max_(close, scalar).func_name == "max"
|
||||
|
||||
|
||||
def test_min_intelligent_dispatch():
    """Test min_ smart dispatch: positive int -> ts_min, otherwise element-wise min.

    Mirrors test_max_intelligent_dispatch (which also checks negative and
    float arguments) so both dispatchers are held to the same contract.
    """
    from src.factors.api import min_, close

    # Positive integer -> rolling ts_min
    result = min_(close, 20)
    assert result.func_name == "ts_min"

    # Zero or negative -> element-wise min
    result = min_(close, 0)
    assert result.func_name == "min"

    result = min_(close, -1)
    assert result.func_name == "min"

    # Float -> element-wise min (consistent with max_ dispatch)
    result = min_(close, 10.5)
    assert result.func_name == "min"
|
||||
|
||||
|
||||
def create_test_data() -> pl.DataFrame:
    """Build a deterministic single-stock random-walk price frame (100 rows)."""
    np.random.seed(42)
    rows = 100
    prices = np.random.randn(rows).cumsum() + 100
    return pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * rows,
            "trade_date": list(range(20240101, 20240101 + rows)),
            "close": prices,
        }
    )
|
||||
|
||||
|
||||
def test_ts_sma_computation():
    """ts_sma must match Polars' ewm_mean(alpha=m/n, adjust=False)."""
    df = create_test_data()

    # Compile ts_sma(close, 10, 5) through the DSL.
    node = FunctionNode("ts_sma", Symbol("close"), 10, 5)
    compiled = PolarsTranslator().translate(node)
    via_dsl = df.select(
        ["ts_code", "trade_date", "close", compiled.alias("ts_sma_result")]
    )

    # Reference: native Polars exponential mean with alpha = m / n.
    reference = df.with_columns(
        [pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")]
    )

    # Skip the warm-up rows and compare the remainder.
    assert np.allclose(
        via_dsl["ts_sma_result"].to_numpy()[9:],
        reference["native_sma"].to_numpy()[9:],
        equal_nan=True,
    )
|
||||
|
||||
|
||||
def test_ts_sumac_computation():
    """ts_sumac must match Polars' native cumulative sum."""
    df = create_test_data()

    # Compile ts_sumac(close) through the DSL.
    node = FunctionNode("ts_sumac", Symbol("close"))
    compiled = PolarsTranslator().translate(node)

    via_dsl = df.select(
        ["ts_code", "trade_date", "close", compiled.alias("ts_sumac_result")]
    )
    reference = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])

    assert np.allclose(
        via_dsl["ts_sumac_result"].to_numpy(), reference["native_sumac"].to_numpy()
    )
|
||||
|
||||
|
||||
# Allow running this test module directly with verbose pytest output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,541 +0,0 @@
|
||||
"""Phase 1-2 因子函数集成测试。
|
||||
|
||||
测试所有新实现的函数,使用字符串因子表达式形式计算因子,
|
||||
并与原始 Polars 计算结果进行对比。
|
||||
|
||||
测试范围:
|
||||
1. 数学函数:atan, log1p
|
||||
2. 统计函数:ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
|
||||
3. TA-Lib 函数:ts_atr, ts_rsi, ts_obv
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factors import FormulaParser, FunctionRegistry
|
||||
from src.factors.translator import PolarsTranslator, HAS_TALIB
|
||||
from src.factors.engine import FactorEngine
|
||||
from src.data.catalog import DatabaseCatalog
|
||||
|
||||
|
||||
# ============== 测试数据准备 ==============
|
||||
|
||||
|
||||
def create_test_data() -> pl.DataFrame:
    """Create mock OHLCV data for the factor-function tests.

    Four stocks, one month of daily rows, generated from a seeded RNG so
    results are reproducible.  NOTE: the order of np.random calls is part
    of the fixture's determinism — keep it unchanged.
    """
    np.random.seed(42)

    trading_days = pl.date_range(
        start=pl.date(2024, 1, 1),
        end=pl.date(2024, 1, 31),
        interval="1d",
        eager=True,
    )
    codes = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"]

    rows = []
    for code in codes:
        anchor = 100 + np.random.randn() * 10
        for day_idx, day in enumerate(trading_days):
            px = anchor + np.random.randn() * 5 + day_idx * 0.1
            rows.append(
                {
                    "ts_code": code,
                    "trade_date": day,
                    "close": px,
                    "open": px * (1 + np.random.randn() * 0.01),
                    "high": px * (1 + abs(np.random.randn()) * 0.02),
                    "low": px * (1 - abs(np.random.randn()) * 0.02),
                    "vol": int(1000000 + np.random.randn() * 500000),
                }
            )

    return pl.DataFrame(rows)
|
||||
|
||||
|
||||
# ============== 数学函数测试 ==============
|
||||
|
||||
|
||||
def test_atan_function():
    """atan(value) via the DSL must equal Polars' native arctan."""
    parser = FormulaParser(FunctionRegistry())

    # Small single-stock fixture with positive, negative, and zero values.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            "value": [0.0, 1.0, -1.0, 0.5, -0.5],
        }
    )

    # Parse and translate the formula, then evaluate both sides.
    compiled = PolarsTranslator().translate(parser.parse("atan(value)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]
    via_polars = frame.with_columns(
        pl_result=pl.col("value").arctan()
    ).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(
        via_dsl.values, via_polars.values, decimal=10
    )
|
||||
|
||||
|
||||
def test_log1p_function():
    """log1p(value) via the DSL must equal Polars' native log1p."""
    parser = FormulaParser(FunctionRegistry())

    # Small single-stock fixture with values around zero.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            "value": [0.0, 0.1, -0.1, 1.0, -0.5],
        }
    )

    # Parse and translate the formula, then evaluate both sides.
    compiled = PolarsTranslator().translate(parser.parse("log1p(value)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]
    via_polars = frame.with_columns(
        pl_result=pl.col("value").log1p()
    ).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(
        via_dsl.values, via_polars.values, decimal=10
    )
|
||||
|
||||
|
||||
# ============== 统计函数测试 ==============
|
||||
|
||||
|
||||
def test_ts_var_function():
    """ts_var(close, 5) must match Polars' rolling_var within each stock."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, ten days each, with distinct close sequences.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True
            ).append(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
            ),
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    # Evaluate the DSL expression, grouped per stock for comparison.
    compiled = PolarsTranslator().translate(parser.parse("ts_var(close, 5)"))
    via_dsl = (
        frame.with_columns(dsl_result=compiled)
        .to_pandas()
        .groupby("ts_code")["dsl_result"]
        .apply(list)
    )

    # Native Polars reference, windowed per stock.
    via_polars = (
        frame.with_columns(
            pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code")
        )
        .to_pandas()
        .groupby("ts_code")["pl_result"]
        .apply(list)
    )

    for code in ("A", "B"):
        np.testing.assert_array_almost_equal(
            via_dsl[code], via_polars[code], decimal=10
        )
|
||||
|
||||
|
||||
def test_ts_skew_function():
    """ts_skew(close, 10) must match Polars' rolling_skew per stock."""
    parser = FormulaParser(FunctionRegistry())

    # Seeded noise fixture: two stocks, 20 days each.
    np.random.seed(42)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
            )
            * 2,
            "close": np.random.randn(40),
        }
    )

    # Evaluate through the DSL pipeline.
    compiled = PolarsTranslator().translate(parser.parse("ts_skew(close, 10)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]

    # Native Polars reference, windowed per stock.
    via_polars = frame.with_columns(
        pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code")
    ).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(
        via_dsl.values, via_polars.values, decimal=10
    )
|
||||
|
||||
|
||||
def test_ts_kurt_function():
    """Test the ts_kurt function: rolling kurtosis."""
    parser = FormulaParser(FunctionRegistry())

    # Build test data: two stocks, 20 days each, seeded noise.
    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
            )
            * 2,
            "close": np.random.randn(40),
        }
    )

    # DSL computation.
    expr = parser.parse("ts_kurt(close, 10)")
    translator = PolarsTranslator()
    polars_expr = translator.translate(expr)
    result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]

    # Native Polars reference: rolling_map + Series.kurtosis, returning NaN
    # when the window holds fewer than 4 non-null samples (kurtosis needs at
    # least four observations).
    result_pl = df.with_columns(
        pl_result=pl.col("close")
        .rolling_map(
            lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
            window_size=10,
        )
        .over("ts_code")
    ).to_pandas()["pl_result"]

    # Compare results.
    np.testing.assert_array_almost_equal(
        result_dsl.values, result_pl.values, decimal=10
    )
|
||||
|
||||
|
||||
def test_ts_pct_change_function():
    """Test the ts_pct_change function: one-period percentage change.

    The reference computation applies ``.over("ts_code")`` to the WHOLE
    pct-change expression.  Previously only the denominator's ``shift(1)``
    was windowed per stock, so the numerator leaked values across the A/B
    group boundary (masked only because the windowed denominator was null
    at that row).
    """
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, five days each.
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5 + ["B"] * 5,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True)
            )
            * 2,
            "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
        }
    )

    # DSL computation.
    expr = parser.parse("ts_pct_change(close, 1)")
    translator = PolarsTranslator()
    polars_expr = translator.translate(expr)
    result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]

    # Native Polars reference: window the whole expression per stock.
    result_pl = df.with_columns(
        pl_result=(
            (pl.col("close") - pl.col("close").shift(1)) / pl.col("close").shift(1)
        ).over("ts_code")
    ).to_pandas()["pl_result"]

    # Compare results.
    np.testing.assert_array_almost_equal(
        result_dsl.values, result_pl.values, decimal=10
    )
|
||||
|
||||
|
||||
def test_ts_ema_function():
    """ts_ema(close, 5) must match Polars' ewm_mean(span=5) per stock."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, ten days each, simple increasing closes.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
            )
            * 2,
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    # Evaluate through the DSL pipeline.
    compiled = PolarsTranslator().translate(parser.parse("ts_ema(close, 5)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]

    # Native Polars reference, windowed per stock.
    via_polars = frame.with_columns(
        pl_result=pl.col("close").ewm_mean(span=5).over("ts_code")
    ).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(
        via_dsl.values, via_polars.values, decimal=10
    )
|
||||
|
||||
|
||||
# ============== TA-Lib 函数测试 ==============
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_atr_function():
    """ts_atr via the DSL must match talib.ATR computed per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Seeded OHLC fixture: two stocks, 20 days each.
    np.random.seed(42)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
            )
            * 2,
            "high": 100 + np.random.randn(40) * 2,
            "low": 98 + np.random.randn(40) * 2,
            "close": 99 + np.random.randn(40) * 2,
        }
    )

    # Evaluate through the DSL pipeline.
    compiled = PolarsTranslator().translate(
        parser.parse("ts_atr(high, low, close, 14)")
    )
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]

    # Reference: talib.ATR, computed group by group.
    expected = []
    for code in ("A", "B"):
        sub = frame.filter(pl.col("ts_code") == code).to_pandas()
        expected.extend(
            talib.ATR(
                sub["high"].values,
                sub["low"].values,
                sub["close"].values,
                timeperiod=14,
            )
        )

    # Compare with a small tolerance.
    np.testing.assert_array_almost_equal(
        via_dsl.values, np.array(expected), decimal=5
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_rsi_function():
    """ts_rsi via the DSL must match talib.RSI computed per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Seeded random-walk closes: two stocks, 30 days each.
    np.random.seed(42)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
            )
            * 2,
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    # Evaluate through the DSL pipeline.
    compiled = PolarsTranslator().translate(parser.parse("ts_rsi(close, 14)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]

    # Reference: talib.RSI, computed group by group.
    expected = []
    for code in ("A", "B"):
        sub = frame.filter(pl.col("ts_code") == code).to_pandas()
        expected.extend(talib.RSI(sub["close"].values, timeperiod=14))

    # Compare with a small tolerance.
    np.testing.assert_array_almost_equal(
        via_dsl.values, np.array(expected), decimal=5
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_obv_function():
    """ts_obv via the DSL must match talib.OBV computed per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Seeded close/volume fixture: two stocks, 20 days each.
    np.random.seed(42)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
            )
            * 2,
            "close": 100 + np.cumsum(np.random.randn(40)),
            "vol": np.random.randint(100000, 1000000, 40).astype(float),
        }
    )

    # Evaluate through the DSL pipeline.
    compiled = PolarsTranslator().translate(parser.parse("ts_obv(close, vol)"))
    via_dsl = frame.with_columns(dsl_result=compiled).to_pandas()["dsl_result"]

    # Reference: talib.OBV, computed group by group.
    expected = []
    for code in ("A", "B"):
        sub = frame.filter(pl.col("ts_code") == code).to_pandas()
        expected.extend(
            talib.OBV(
                sub["close"].values,
                sub["vol"].values,
            )
        )

    # Compare with a small tolerance.
    np.testing.assert_array_almost_equal(
        via_dsl.values, np.array(expected), decimal=5
    )
|
||||
|
||||
|
||||
# ============== 综合测试 ==============
|
||||
|
||||
|
||||
def test_complex_factor_expressions():
    """A multi-operator factor formula should parse, translate, and evaluate."""
    parser = FormulaParser(FunctionRegistry())

    # Seeded random-walk closes for two stocks over 30 days.
    np.random.seed(42)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
            )
            * 2,
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    # act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50
    compiled = PolarsTranslator().translate(
        parser.parse(
            "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
        )
    )
    evaluated = frame.with_columns(factor=compiled)

    # The factor column must exist and cover every input row.
    assert len(evaluated) == 60
    assert "factor" in evaluated.columns

    print("复杂因子表达式测试通过")
|
||||
|
||||
|
||||
# ============== 主函数 ==============
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("运行 Phase 1-2 因子函数测试...")
    print("=" * 80)

    # Math-function tests
    print("\n[数学函数测试]")
    test_atan_function()
    print(" ✅ atan 测试通过")

    test_log1p_function()
    print(" ✅ log1p 测试通过")

    # Statistical-function tests
    print("\n[统计函数测试]")
    test_ts_var_function()
    print(" ✅ ts_var 测试通过")

    test_ts_skew_function()
    print(" ✅ ts_skew 测试通过")

    test_ts_kurt_function()
    print(" ✅ ts_kurt 测试通过")

    test_ts_pct_change_function()
    print(" ✅ ts_pct_change 测试通过")

    test_ts_ema_function()
    print(" ✅ ts_ema 测试通过")

    # TA-Lib function tests.  HAS_TALIB is already imported from the
    # translator module at the top of this file; the previous local
    # try/import shadowed it redundantly.
    print("\n[TA-Lib 函数测试]")
    if not HAS_TALIB:
        print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")
    else:
        test_ts_atr_function()
        print(" ✅ ts_atr 测试通过")

        test_ts_rsi_function()
        print(" ✅ ts_rsi 测试通过")

        test_ts_obv_function()
        print(" ✅ ts_obv 测试通过")

    # Integration test
    print("\n[综合测试]")
    test_complex_factor_expressions()
    print(" ✅ 复杂因子表达式测试通过")

    print("\n" + "=" * 80)
    print("所有测试通过!")
|
||||
@@ -1,421 +0,0 @@
|
||||
"""Test for pro_bar (universal market) API.
|
||||
|
||||
Tests the pro_bar interface implementation:
|
||||
- Backward-adjusted (后复权) data fetching
|
||||
- All output fields including tor, vr, and adj_factor (default behavior)
|
||||
- Multiple asset types support
|
||||
- ProBarSync batch synchronization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.data.api_wrappers.api_pro_bar import (
|
||||
get_pro_bar,
|
||||
ProBarSync,
|
||||
sync_pro_bar,
|
||||
preview_pro_bar_sync,
|
||||
)
|
||||
|
||||
|
||||
# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
    "ts_code",  # stock code
    "trade_date",  # trade date
    "open",  # open price
    "high",  # high price
    "low",  # low price
    "close",  # close price
    "pre_close",  # previous close
    "change",  # price change
    "pct_chg",  # percentage change
    "vol",  # trading volume
    "amount",  # turnover amount
]

EXPECTED_FACTOR_FIELDS = [
    "turnover_rate",  # turnover rate (tor)
    "volume_ratio",  # volume ratio (vr)
]
|
||||
|
||||
|
||||
class TestGetProBar:
    """Test cases for get_pro_bar function.

    All tests patch ``TushareClient`` inside the wrapper module so no real
    network call is made; each test inspects ``mock_client.query.call_args``
    to verify the parameters the wrapper forwards to the Tushare pro_bar API.
    """

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_basic(self, mock_client_class):
        """Test basic pro_bar data fetch."""
        # Setup mock: the wrapper constructs its own client, so patch the class
        # and hand back a MagicMock whose query() returns a canned daily bar.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "open": [10.5],
                "high": [11.0],
                "low": [10.2],
                "close": [10.8],
                "pre_close": [10.5],
                "change": [0.3],
                "pct_chg": [2.86],
                "vol": [100000.0],
                "amount": [1080000.0],
            }
        )

        # Test
        result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")

        # Assert
        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert result["ts_code"].iloc[0] == "000001.SZ"
        mock_client.query.assert_called_once()
        # Verify pro_bar API is called
        call_args = mock_client.query.call_args
        assert call_args[0][0] == "pro_bar"
        assert call_args[1]["ts_code"] == "000001.SZ"
        # Default should use hfq (backward-adjusted)
        assert call_args[1]["adj"] == "hfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_backward_adjusted(self, mock_client_class):
        """Test that default adjustment is backward (hfq)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [100.5],
            }
        )

        result = get_pro_bar("000001.SZ")

        # With no explicit adj argument, the wrapper must request hfq prices
        # and ask for the adjustment factor column as well.
        call_args = mock_client.query.call_args
        assert call_args[1]["adj"] == "hfq"
        assert call_args[1]["adjfactor"] == "True"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_factors_all_fields(self, mock_client_class):
        """Test that default factors includes tor and vr."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
                "volume_ratio": [1.2],
                "adj_factor": [1.05],
            }
        )

        result = get_pro_bar("000001.SZ")

        call_args = mock_client.query.call_args
        # Default should include both tor and vr
        assert call_args[1]["factors"] == "tor,vr"
        assert "turnover_rate" in result.columns
        assert "volume_ratio" in result.columns
        assert "adj_factor" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_custom_factors(self, mock_client_class):
        """Test fetch with custom factors."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
            }
        )

        # Only request tor
        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=["tor"],
        )

        # The factors list should be joined into Tushare's comma-separated form.
        call_args = mock_client.query.call_args
        assert call_args[1]["factors"] == "tor"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_no_factors(self, mock_client_class):
        """Test fetch with no factors (empty list)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        # Explicitly set factors to empty list
        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=[],
        )

        call_args = mock_client.query.call_args
        # Should not include factors parameter
        assert "factors" not in call_args[1]

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_ma(self, mock_client_class):
        """Test fetch with moving averages."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "ma_5": [10.5],
                "ma_10": [10.3],
                "ma_v_5": [95000.0],  # Tushare also returns MA of volume
            }
        )

        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            ma=[5, 10],
        )

        # The ma window list should be joined into a comma-separated string.
        call_args = mock_client.query.call_args
        assert call_args[1]["ma"] == "5,10"
        assert "ma_5" in result.columns
        assert "ma_10" in result.columns
        assert "ma_v_5" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_index_data(self, mock_client_class):
        """Test fetching index data."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SH"],
                "trade_date": ["20240115"],
                "close": [2900.5],
            }
        )

        # asset="I" selects index bars instead of equity bars.
        result = get_pro_bar(
            "000001.SH",
            asset="I",
            start_date="20240101",
            end_date="20240131",
        )

        call_args = mock_client.query.call_args
        assert call_args[1]["asset"] == "I"
        assert call_args[1]["ts_code"] == "000001.SH"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_forward_adjustment(self, mock_client_class):
        """Test forward adjustment (qfq)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        result = get_pro_bar("000001.SZ", adj="qfq")

        call_args = mock_client.query.call_args
        assert call_args[1]["adj"] == "qfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_no_adjustment(self, mock_client_class):
        """Test no adjustment."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        result = get_pro_bar("000001.SZ", adj=None)

        # adj=None means raw prices: the adj parameter must be omitted entirely.
        call_args = mock_client.query.call_args
        assert "adj" not in call_args[1]

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        result = get_pro_bar("INVALID.SZ")

        # An unknown code should yield an empty frame, not an exception.
        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_date_column_rename(self, mock_client_class):
        """Test that 'date' column is renamed to 'trade_date'."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "date": ["20240115"],  # API returns 'date' instead of 'trade_date'
                "close": [10.8],
            }
        )

        result = get_pro_bar("000001.SZ")

        # The wrapper normalizes the column name to the project-wide convention.
        assert "trade_date" in result.columns
        assert "date" not in result.columns
        assert result["trade_date"].iloc[0] == "20240115"
|
||||
|
||||
|
||||
class TestProBarSync:
    """Test cases for ProBarSync class.

    The stock-list CSV and storage layer are patched so the tests run without
    any filesystem or network access.
    """

    @patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
    @patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
    @patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
    def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
        """Test getting all stock codes."""
        from pathlib import Path
        from unittest.mock import MagicMock

        # Create a mock path that exists so the cached-CSV branch is taken
        # and sync_all_stocks is presumably not needed — TODO confirm branch.
        mock_path = MagicMock(spec=Path)
        mock_path.exists.return_value = True
        mock_get_path.return_value = mock_path

        mock_read_csv.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "600000.SH"],
                "list_status": ["L", "L"],  # L = listed
            }
        )

        sync = ProBarSync()
        codes = sync.get_all_stock_codes()

        assert len(codes) == 2
        assert "000001.SZ" in codes
        assert "600000.SH" in codes

    # BUG FIX: this test was previously defined twice with identical bodies;
    # the second definition silently shadowed the first, so only one copy ever
    # ran. The duplicate has been removed.
    @patch("src.data.api_wrappers.api_pro_bar.Storage")
    def test_check_sync_needed_force_full(self, mock_storage_class):
        """Test check_sync_needed with force_full=True."""
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        # No local data present: a full sync from the default start is expected.
        mock_storage.exists.return_value = False

        sync = ProBarSync()
        needed, start, end, local_last = sync.check_sync_needed(force_full=True)

        assert needed is True
        assert start == "20180101"  # DEFAULT_START_DATE
        assert local_last is None
|
||||
|
||||
|
||||
class TestSyncProBar:
    """Test cases for sync_pro_bar function."""

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_sync_pro_bar(self, mock_sync_class):
        """Test sync_pro_bar function."""
        # Stub out the sync engine so no network or storage is touched.
        sync_stub = MagicMock()
        sync_stub.sync_all.return_value = {
            "000001.SZ": pd.DataFrame({"close": [10.5]})
        }
        mock_sync_class.return_value = sync_stub

        outcome = sync_pro_bar(force_full=True, max_workers=5)

        # max_workers must be forwarded to the ProBarSync constructor.
        mock_sync_class.assert_called_once_with(max_workers=5)
        sync_stub.sync_all.assert_called_once()
        assert "000001.SZ" in outcome

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_preview_pro_bar_sync(self, mock_sync_class):
        """Test preview_pro_bar_sync function."""
        preview_payload = {
            "sync_needed": True,
            "stock_count": 5000,
            "mode": "full",
        }
        sync_stub = MagicMock()
        sync_stub.preview_sync.return_value = preview_payload
        mock_sync_class.return_value = sync_stub

        outcome = preview_pro_bar_sync(force_full=True)

        # Preview uses a default-constructed ProBarSync.
        mock_sync_class.assert_called_once_with()
        sync_stub.preview_sync.assert_called_once()
        assert outcome["sync_needed"] is True
        assert outcome["stock_count"] == 5000
||||
|
||||
|
||||
class TestProBarIntegration:
    """Integration tests with real Tushare API."""

    def test_real_api_call(self):
        """Test with real API (requires valid token)."""
        import os

        # Skip gracefully on machines without API credentials.
        if not os.environ.get("TUSHARE_TOKEN"):
            pytest.skip("TUSHARE_TOKEN not configured")

        frame = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        # Verify structure
        assert isinstance(frame, pd.DataFrame)
        if frame.empty:
            return

        # Base columns, factor columns, and the adjustment factor should all
        # be present under the wrapper's default settings.
        for field in EXPECTED_BASE_FIELDS:
            assert field in frame.columns, f"Missing base field: {field}"
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in frame.columns, f"Missing factor field: {field}"
        assert "adj_factor" in frame.columns
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,246 +0,0 @@
|
||||
"""Tests for stock limit price API wrapper."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.data.api_wrappers.api_stk_limit import (
|
||||
get_stk_limit,
|
||||
sync_stk_limit,
|
||||
preview_stk_limit_sync,
|
||||
StkLimitSync,
|
||||
)
|
||||
|
||||
|
||||
class TestStkLimit:
    """Test suite for stk_limit API wrapper.

    The Tushare client is mocked throughout; each test checks both the
    returned frame and the exact keyword arguments forwarded to the API.
    """

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    def test_get_by_date(self, mock_client_class):
        """Test fetching data by trade_date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240625", "20240625"],
                "pre_close": [10.0, 20.0],
                "up_limit": [11.0, 22.0],  # daily upper price limit
                "down_limit": [9.0, 18.0],  # daily lower price limit
            }
        )

        # Test
        result = get_stk_limit(trade_date="20240625")

        # Assert
        assert not result.empty
        assert len(result) == 2
        assert "ts_code" in result.columns
        assert "trade_date" in result.columns
        assert "up_limit" in result.columns
        assert "down_limit" in result.columns
        mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625")

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    def test_get_by_date_range(self, mock_client_class):
        """Test fetching data by date range."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000001.SZ"],
                "trade_date": ["20240624", "20240625"],
                "pre_close": [10.0, 10.5],
                "up_limit": [11.0, 11.55],
                "down_limit": [9.0, 9.45],
            }
        )

        # Test
        result = get_stk_limit(start_date="20240624", end_date="20240625")

        # Assert: start/end dates are forwarded verbatim to the API.
        assert not result.empty
        assert len(result) == 2
        mock_client.query.assert_called_once_with(
            "stk_limit", start_date="20240624", end_date="20240625"
        )

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    def test_get_by_stock_code(self, mock_client_class):
        """Test fetching data by stock code."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240625"],
                "pre_close": [10.0],
                "up_limit": [11.0],
                "down_limit": [9.0],
            }
        )

        # Test
        result = get_stk_limit(ts_code="000001.SZ", trade_date="20240625")

        # Assert
        assert not result.empty
        assert len(result) == 1
        assert result.iloc[0]["ts_code"] == "000001.SZ"
        mock_client.query.assert_called_once_with(
            "stk_limit", trade_date="20240625", ts_code="000001.SZ"
        )

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        # Test
        result = get_stk_limit(trade_date="20240625")

        # Assert: an empty API response propagates as an empty frame.
        assert result.empty

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    def test_shared_client(self, mock_client_class):
        """Test passing shared client for rate limiting."""
        # Setup mock: caller supplies its own client instance.
        shared_client = MagicMock()
        shared_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240625"],
                "pre_close": [10.0],
                "up_limit": [11.0],
                "down_limit": [9.0],
            }
        )

        # Test
        result = get_stk_limit(trade_date="20240625", client=shared_client)

        # Assert
        assert not result.empty
        shared_client.query.assert_called_once()
        # Verify new client was not created
        mock_client_class.assert_not_called()
|
||||
|
||||
|
||||
class TestStkLimitSync:
    """Test suite for StkLimitSync class."""

    @patch("src.data.api_wrappers.api_stk_limit.TushareClient")
    @patch("src.data.api_wrappers.base_sync.Storage")
    @patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache")
    def test_fetch_single_date(
        self, mock_sync_cal, mock_storage_class, mock_client_class
    ):
        """Test fetch_single_date method."""
        # Setup mock: API returns two rows for a single trading day.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240625", "20240625"],
                "pre_close": [10.0, 20.0],
                "up_limit": [11.0, 22.0],
                "down_limit": [9.0, 18.0],
            }
        )

        # Storage and trade-calendar cache are patched so constructing the
        # sync object performs no real I/O.
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.exists.return_value = True
        mock_storage.load.return_value = pd.DataFrame()

        # Test
        sync = StkLimitSync()
        result = sync.fetch_single_date("20240625")

        # Assert
        assert not result.empty
        assert len(result) == 2
        mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625")

    def test_table_schema(self):
        """Test table schema definition."""
        sync = StkLimitSync()

        # Assert table configuration: column set and composite primary key
        # used for idempotent upserts.
        assert sync.table_name == "stk_limit"
        assert "ts_code" in sync.TABLE_SCHEMA
        assert "trade_date" in sync.TABLE_SCHEMA
        assert "pre_close" in sync.TABLE_SCHEMA
        assert "up_limit" in sync.TABLE_SCHEMA
        assert "down_limit" in sync.TABLE_SCHEMA
        assert sync.PRIMARY_KEY == ("ts_code", "trade_date")
|
||||
|
||||
|
||||
class TestSyncFunctions:
    """Test suite for sync convenience functions."""

    @patch.object(StkLimitSync, "sync_all")
    def test_sync_stk_limit(self, mock_sync_all):
        """Test sync_stk_limit convenience function."""
        # Canned frame handed back by the patched sync engine.
        stub_frame = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240625"],
                "up_limit": [11.0],
                "down_limit": [9.0],
            }
        )
        mock_sync_all.return_value = stub_frame

        outcome = sync_stk_limit(force_full=True)

        # The convenience wrapper must forward its flags with the documented
        # defaults for the remaining parameters.
        assert not outcome.empty
        mock_sync_all.assert_called_once_with(
            force_full=True,
            start_date=None,
            end_date=None,
            dry_run=False,
        )

    @patch.object(StkLimitSync, "preview_sync")
    def test_preview_stk_limit_sync(self, mock_preview):
        """Test preview_stk_limit_sync convenience function."""
        # Canned preview report.
        mock_preview.return_value = {
            "sync_needed": True,
            "date_count": 10,
            "start_date": "20240601",
            "end_date": "20240610",
            "estimated_records": 5000,
            "sample_data": pd.DataFrame(),
            "mode": "incremental",
        }

        report = preview_stk_limit_sync()

        assert report["sync_needed"] is True
        assert report["mode"] == "incremental"
        # Default arguments of the wrapper are passed through explicitly.
        mock_preview.assert_called_once_with(
            force_full=False,
            start_date=None,
            end_date=None,
            sample_size=3,
        )
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Test suite for stock_st API wrapper."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync
|
||||
|
||||
|
||||
class TestStockST:
    """Test suite for stock_st API wrapper.

    The Tushare client is mocked; the canned frames mimic the ST
    (special-treatment / risk-warning) stock listing the real API returns.
    """

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_date(self, mock_client_class):
        """Test fetching ST stock list by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ", "605081.SH", "300391.SZ"],
                "name": ["*ST天山", "*ST太和", "*ST长药"],
                "trade_date": ["20240101", "20240101", "20240101"],
                "type": ["ST", "ST", "ST"],
                "type_name": ["风险警示板", "风险警示板", "风险警示板"],
            }
        )

        # Test
        result = get_stock_st(trade_date="20240101")

        # Assert
        assert not result.empty
        assert len(result) == 3
        assert "ts_code" in result.columns
        assert "name" in result.columns
        assert "trade_date" in result.columns
        assert "type" in result.columns
        assert "type_name" in result.columns
        mock_client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_stock(self, mock_client_class):
        """Test fetching ST history by stock code."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ", "300313.SZ"],
                "name": ["*ST天山", "*ST天山"],
                "trade_date": ["20240101", "20240102"],
                "type": ["ST", "ST"],
                "type_name": ["风险警示板", "风险警示板"],
            }
        )

        # Test: single stock filtered to a two-day window.
        result = get_stock_st(
            ts_code="300313.SZ", start_date="20240101", end_date="20240102"
        )

        # Assert
        assert not result.empty
        assert len(result) == 2
        mock_client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        # An empty API response must propagate as an empty frame.
        result = get_stock_st(trade_date="20240101")
        assert result.empty

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_date_range(self, mock_client_class):
        """Test fetching ST stock list by date range."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ"],
                "name": ["*ST天山"],
                "trade_date": ["20240101"],
                "type": ["ST"],
                "type_name": ["风险警示板"],
            }
        )

        # Test
        result = get_stock_st(start_date="20240101", end_date="20240131")

        # Assert
        assert not result.empty
        mock_client.query.assert_called_once()
|
||||
|
||||
|
||||
class TestStockSTSync:
    """Test suite for StockSTSync class."""

    def test_sync_class_attributes(self):
        """Test that sync class has correct attributes."""
        sync = StockSTSync()
        # Table configuration: name, earliest sync date, column schema,
        # and the composite primary key used for idempotent upserts.
        assert sync.table_name == "stock_st"
        assert sync.default_start_date == "20160101"
        assert "ts_code" in sync.TABLE_SCHEMA
        assert "trade_date" in sync.TABLE_SCHEMA
        assert "name" in sync.TABLE_SCHEMA
        assert "type" in sync.TABLE_SCHEMA
        assert "type_name" in sync.TABLE_SCHEMA
        assert sync.PRIMARY_KEY == ("trade_date", "ts_code")

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_fetch_single_date(self, mock_client_class):
        """Test fetching single date data."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ"],
                "name": ["*ST天山"],
                "trade_date": ["20240101"],
                "type": ["ST"],
                "type_name": ["风险警示板"],
            }
        )

        # Test
        sync = StockSTSync()
        result = sync.fetch_single_date("20240101")

        # Assert
        assert not result.empty
        assert len(result) == 1
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Sync 接口测试规范与实现。
|
||||
|
||||
【测试规范】
|
||||
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
|
||||
2. 只测试接口是否能正常返回数据,不测试落库逻辑
|
||||
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
|
||||
4. 使用真实 API 调用,确保接口可用性
|
||||
|
||||
【测试范围】
|
||||
- get_daily: 日线数据接口(按股票)
|
||||
- sync_all_stocks: 股票基础信息接口
|
||||
- sync_trade_cal_cache: 交易日历接口
|
||||
- sync_namechange: 名称变更接口
|
||||
- sync_bak_basic: 备用股票基础信息接口
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
# Test constants: every sync test below is restricted to this date window
# and these two stocks (see the module docstring for the test conventions).
TEST_START_DATE = "20180101"
TEST_END_DATE = "20180401"
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
|
||||
|
||||
|
||||
class TestGetDaily:
    """Tests for the daily-bar `get` interface (queried per stock).

    These tests hit the real API (no mocking) to confirm the interface works.
    """

    def test_get_daily_single_stock(self):
        """get_daily returns data for a single stock."""
        from src.data.api_wrappers.api_daily import get_daily

        result = get_daily(
            ts_code=TEST_STOCK_CODES[0],
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

        # Data must come back as a non-empty DataFrame.
        assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
        assert not result.empty, "get_daily 应返回非空数据"

    def test_get_daily_has_required_columns(self):
        """get_daily output contains the required columns."""
        from src.data.api_wrappers.api_daily import get_daily

        result = get_daily(
            ts_code=TEST_STOCK_CODES[0],
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

        # Check that the mandatory OHLC / identifier columns are present.
        required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
        for col in required_columns:
            assert col in result.columns, f"get_daily 返回应包含 {col} 列"

    def test_get_daily_multiple_stocks(self):
        """get_daily returns data for each of several stocks."""
        from src.data.api_wrappers.api_daily import get_daily

        results = {}
        for code in TEST_STOCK_CODES:
            result = get_daily(
                ts_code=code,
                start_date=TEST_START_DATE,
                end_date=TEST_END_DATE,
            )
            results[code] = result
            assert isinstance(result, pd.DataFrame), (
                f"get_daily({code}) 应返回 DataFrame"
            )
            assert not result.empty, f"get_daily({code}) 应返回非空数据"
|
||||
|
||||
|
||||
class TestSyncStockBasic:
    """Tests for the stock-basic-info sync interface (real API calls)."""

    def test_sync_all_stocks_returns_data(self):
        """sync_all_stocks returns data normally."""
        from src.data.api_wrappers.api_stock_basic import sync_all_stocks

        result = sync_all_stocks()

        # Data must come back as a non-empty DataFrame.
        assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
        assert not result.empty, "sync_all_stocks 应返回非空数据"

    def test_sync_all_stocks_has_required_columns(self):
        """sync_all_stocks output contains the required columns."""
        from src.data.api_wrappers.api_stock_basic import sync_all_stocks

        result = sync_all_stocks()

        # Only ts_code is mandated here; other columns vary by API plan.
        required_columns = ["ts_code"]
        for col in required_columns:
            assert col in result.columns, f"sync_all_stocks 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncTradeCal:
    """Tests for the trade-calendar sync interface (real API calls)."""

    def test_sync_trade_cal_cache_returns_data(self):
        """sync_trade_cal_cache returns data normally."""
        from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache

        result = sync_trade_cal_cache(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

        # Data must come back as a non-empty DataFrame.
        assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
        assert not result.empty, "sync_trade_cal_cache 应返回非空数据"

    def test_sync_trade_cal_cache_has_required_columns(self):
        """sync_trade_cal_cache output contains the required columns."""
        from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache

        result = sync_trade_cal_cache(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

        # cal_date and the is_open flag are the minimum useful calendar fields.
        required_columns = ["cal_date", "is_open"]
        for col in required_columns:
            assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncNamechange:
    """Tests for the name-change sync interface (real API call)."""

    def test_sync_namechange_returns_data(self):
        """sync_namechange returns data normally."""
        from src.data.api_wrappers.api_namechange import sync_namechange

        result = sync_namechange()

        # The frame may legitimately be empty (historical changes only),
        # so only the return type is asserted.
        assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
|
||||
|
||||
|
||||
class TestSyncBakBasic:
    """Tests for the backup stock-basic-info sync interface (real API call)."""

    def test_sync_bak_basic_returns_data(self):
        """sync_bak_basic returns data normally."""
        from src.data.api_wrappers.api_bak_basic import sync_bak_basic

        result = sync_bak_basic(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

        # Only the type is asserted:
        assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
        # NOTE: bak_basic may return empty data — that is normal.
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Tushare API 验证脚本 - 快速生成 pro 对象用于调试。"""
|
||||
|
||||
import os

# Set DATA_PATH before importing project config — presumably the config module
# reads the environment at import time (TODO confirm).
os.environ.setdefault("DATA_PATH", "data")

from src.data.config import get_config
import tushare as ts

# Pull the API token from project configuration (config/.env.local).
config = get_config()
token = config.tushare_token

if not token:
    raise ValueError("请在 config/.env.local 中配置 TUSHARE_TOKEN")

# Build the pro_api client; only a token prefix is printed to avoid leaking it.
pro = ts.pro_api(token)
print(f"pro_api 对象已创建,token: {token[:10]}...")

# Smoke query: daily bars for 000001.SZ over a fixed historical window.
df = pro.query('daily', ts_code='000001.SZ', start_date='20180702', end_date='20180718')
print(df)
|
||||
Reference in New Issue
Block a user