feat(factors): 添加 cs_mean 函数并增强 max_/min_ 单参数支持

- 新增 cs_mean 截面均值函数,支持 GTJA Alpha127 等因子转换
- max_/min_ 支持单参数调用,默认使用 252 天(约 1 年)滚动窗口
This commit is contained in:
2026-03-15 18:00:48 +08:00
parent c6ebab0e58
commit f943cc98d0
21 changed files with 1204 additions and 3980 deletions

View File

@@ -1,350 +0,0 @@
"""601117.SH 因子计算测试 - 使用真实数据
测试目标:计算中国化学(601117.SH)在2024-2025年的以下因子
1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1)
2. return_5_rank: 5日收益率在截面上的排名
3. ma5: 5日均线 (ts_mean(close, 5))
4. ma10: 10日均线 (ts_mean(close, 10))
数据源: DuckDB 数据库中的真实日线数据
"""
from src.factors import FactorEngine
from src.factors.api import close, ts_mean, ts_delay, cs_rank
from src.factors.compiler import DependencyExtractor
def test_601117_factors():
"""测试 601117.SH 的因子计算。"""
print("=" * 80)
print("601117.SH (中国化学) 因子计算测试 - 2024-2025")
print("=" * 80)
# =========================================================================
# 1. 定义因子表达式
# =========================================================================
print("\n" + "=" * 80)
print("1. 定义因子表达式")
print("=" * 80)
# return_5: 5日收益率 = (close / close.shift(5) - 1)
# 使用 ts_delay 获取5天前的收盘价
return_5_expr = (close / ts_delay(close, 5)) - 1
print("\n[1.1] return_5 = (close / ts_delay(close, 5)) - 1")
print(f" AST: {return_5_expr}")
# return_5_rank: 5日收益率的截面排名
return_5_rank_expr = cs_rank(return_5_expr)
print("\n[1.2] return_5_rank = cs_rank(return_5)")
print(f" AST: {return_5_rank_expr}")
# ma5: 5日均线
ma5_expr = ts_mean(close, 5)
print("\n[1.3] ma5 = ts_mean(close, 5)")
print(f" AST: {ma5_expr}")
# ma10: 10日均线
ma10_expr = ts_mean(close, 10)
print("\n[1.4] ma10 = ts_mean(close, 10)")
print(f" AST: {ma10_expr}")
# =========================================================================
# 1.5 打印数据来源信息
# =========================================================================
print("\n" + "=" * 80)
print("1.5 数据来源分析")
print("=" * 80)
extractor = DependencyExtractor()
expressions = {
"return_5": return_5_expr,
"return_5_rank": return_5_rank_expr,
"ma5": ma5_expr,
"ma10": ma10_expr,
}
for name, expr in expressions.items():
deps = extractor.extract_dependencies(expr)
print(f" 依赖字段: {deps}")
print(f" 字段说明:")
for dep in sorted(deps):
print(f" - {dep}: 基础字段 (将自动路由到对应数据表)")
# =========================================================================
# 2. 创建 FactorEngine 并注册因子
# =========================================================================
print("\n" + "=" * 80)
print("2. 注册因子到 FactorEngine")
print("=" * 80)
engine = FactorEngine()
engine.register("return_5", return_5_expr)
print("[2.1] 注册 return_5")
engine.register("return_5_rank", return_5_rank_expr)
print("[2.2] 注册 return_5_rank")
engine.register("ma5", ma5_expr)
print("[2.3] 注册 ma5")
engine.register("ma10", ma10_expr)
print("[2.4] 注册 ma10")
# 也注册原始 close 价格用于验证
engine.register("close_price", close)
print("[2.5] 注册 close_price (原始收盘价)")
print(f"\n已注册因子列表: {engine.list_registered()}")
# =========================================================================
# 2.5 打印执行计划数据规格
# =========================================================================
print("\n" + "=" * 80)
print("2.5 执行计划数据规格")
print("=" * 80)
for name in engine.list_registered():
plan = engine.preview_plan(name)
if plan:
print(f"\n因子: {name}")
print(f" 输出名称: {plan.output_name}")
print(f" 依赖字段: {plan.dependencies}")
print(f" 数据规格:")
for i, spec in enumerate(plan.data_specs, 1):
print(f" [{i}] 表名: {spec.table}")
print(f" 字段: {spec.columns}")
print(f" 回看天数: {spec.lookback_days}")
# =========================================================================
# 3. 执行计算
# =========================================================================
print("\n" + "=" * 80)
print("3. 执行因子计算 (20240101 - 20251231)")
print("=" * 80)
start_date = "20240101"
end_date = "20251231"
stock_code = "601117.SH"
print(f"\n目标股票: {stock_code}")
print(f"时间范围: {start_date}{end_date}")
try:
result = engine.compute(
factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"],
start_date=start_date,
end_date=end_date,
stock_codes=[stock_code],
)
print(f"\n计算完成!")
print(f"结果形状: {result.shape}")
print(f"结果列: {result.columns}")
except Exception as e:
print(f"\n[错误] 计算失败: {e}")
raise
# =========================================================================
# 4. 结果展示与分析
# =========================================================================
print("\n" + "=" * 80)
print("4. 计算结果展示")
print("=" * 80)
# 4.1 数据概览
print("\n[4.1] 前20行数据预览:")
print(result.head(20))
# 4.2 按时间范围分块展示
print("\n[4.2] 2024年上半年数据 (前10行):")
result_2024h1 = result.filter(result["trade_date"] < "20240701")
print(result_2024h1.head(10))
print("\n[4.3] 2024年下半年数据 (前10行):")
result_2024h2 = result.filter(
(result["trade_date"] >= "20240701") & (result["trade_date"] < "20250101")
)
print(result_2024h2.head(10))
print("\n[4.4] 2025年数据 (前10行):")
result_2025 = result.filter(result["trade_date"] >= "20250101")
print(result_2025.head(10))
# =========================================================================
# 5. 因子验证
# =========================================================================
print("\n" + "=" * 80)
print("5. 因子计算验证")
print("=" * 80)
# 5.1 MA5/MA10 滑动窗口验证
print("\n[5.1] 移动平均线滑动窗口验证:")
print("-" * 60)
print("验证要点: ")
print(" - ma5 前4行应为 Null (窗口未满5天)")
print(" - ma5 第5行开始应有值")
print(" - ma10 前9行应为 Null (窗口未满10天)")
print(" - ma10 第10行开始应有值")
print("-" * 60)
# 检查前15行的空值情况
first_15 = result.head(15)
ma5_nulls = first_15["ma5"].null_count()
ma10_nulls = first_15["ma10"].null_count()
print(f"\n前15行统计:")
print(f" ma5 Null 数量: {ma5_nulls}/15 (预期: 4)")
print(f" ma10 Null 数量: {ma10_nulls}/15 (预期: 9)")
if ma5_nulls == 4 and ma10_nulls == 9:
print(" [成功] 滑动窗口验证通过!")
else:
print(" [警告] 滑动窗口验证异常,请检查数据")
# 5.2 Return_5 验证
print("\n[5.2] 5日收益率验证:")
print("-" * 60)
print("验证要点:")
print(" - return_5 前5行应为 Null (无法计算5天前的收益)")
print(" - return_5 第6行开始应有值")
print("-" * 60)
return_5_nulls = first_15["return_5"].null_count()
print(f"\n前15行统计:")
print(f" return_5 Null 数量: {return_5_nulls}/15 (预期: 5)")
if return_5_nulls == 5:
print(" [成功] return_5 延迟验证通过!")
else:
print(" [警告] return_5 延迟验证异常")
# 5.3 手动验证 MA5 计算
print("\n[5.3] MA5 手动计算验证:")
print("-" * 60)
# 选择第10行索引9进行验证
if len(result) >= 10:
row_10 = result.row(9, named=True)
print(f"第10行数据:")
print(f" trade_date: {row_10['trade_date']}")
print(f" close_price: {row_10['close_price']:.4f}")
print(f" ma5: {row_10['ma5']:.4f}")
print(f" ma10: {row_10['ma10']:.4f}")
# 手动计算前5天的均值
first_10 = result.head(10)
close_list = first_10["close_price"].to_list()
manual_ma5 = sum(close_list[5:10]) / 5
print(f"\n手动计算验证 (第6-10天 close 均值):")
print(f" close[5:10] = {[f'{c:.4f}' for c in close_list[5:10]]}")
print(f" 手动计算 ma5 = {manual_ma5:.4f}")
print(f" 引擎计算 ma5 = {row_10['ma5']:.4f}")
if abs(manual_ma5 - row_10["ma5"]) < 0.01:
print(" [成功] MA5 计算验证通过!")
else:
print(" [警告] MA5 计算结果不一致")
# 5.4 Return_5 手动验证
print("\n[5.4] Return_5 手动计算验证:")
print("-" * 60)
if len(result) >= 10:
row_10 = result.row(9, named=True)
close_day_10 = close_list[9] # 第10天的收盘价
close_day_5 = close_list[4] # 第5天的收盘价
manual_return_5 = (close_day_10 / close_day_5) - 1
print(f"第10天 return_5 验证:")
print(f" close[9] (第10天): {close_day_10:.4f}")
print(f" close[4] (第5天): {close_day_5:.4f}")
print(f" 手动计算 return_5 = {manual_return_5:.6f}")
print(f" 引擎计算 return_5 = {row_10['return_5']:.6f}")
if abs(manual_return_5 - row_10["return_5"]) < 0.0001:
print(" [成功] Return_5 计算验证通过!")
else:
print(" [警告] Return_5 计算结果不一致")
# =========================================================================
# 6. 统计摘要
# =========================================================================
print("\n" + "=" * 80)
print("6. 因子统计摘要")
print("=" * 80)
# 移除空值后统计
result_valid = result.drop_nulls()
print(f"\n总记录数: {len(result)}")
print(f"有效记录数 (去空值后): {len(result_valid)}")
factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"]
for col in factor_cols:
if col in result.columns:
series = result[col]
null_count = series.null_count()
non_null = series.drop_nulls()
print(f"\n{col}:")
print(f" 空值数量: {null_count} ({null_count / len(result) * 100:.2f}%)")
if len(non_null) > 0:
print(f" 均值: {non_null.mean():.6f}")
print(f" 标准差: {non_null.std():.6f}")
print(f" 最小值: {non_null.min():.6f}")
print(f" 最大值: {non_null.max():.6f}")
if col == "return_5_rank":
print(f" [截面排名应在 [0, 1] 区间内]")
# =========================================================================
# 7. 保存结果
# =========================================================================
print("\n" + "=" * 80)
print("7. 结果保存")
print("=" * 80)
output_file = "tests/output/601117_factors_2024_2025.csv"
try:
result.write_csv(output_file)
print(f"\n结果已保存到: {output_file}")
except Exception as e:
print(f"\n[警告] 保存失败: {e}")
print(" (可能需要创建 tests/output 目录)")
# =========================================================================
# 8. 测试总结
# =========================================================================
print("\n" + "=" * 80)
print("8. 测试总结")
print("=" * 80)
print("\n[测试完成] 601117.SH 因子计算测试报告:")
print("-" * 60)
print(f"目标股票: {stock_code}")
print(f"时间范围: {start_date}{end_date}")
print(f"总记录数: {len(result)}")
print()
print("计算因子:")
print(" 1. return_5 - 5日收益率 (ts_delay)")
print(" 2. return_5_rank - 5日收益率截面排名 (cs_rank)")
print(" 3. ma5 - 5日均线 (ts_mean)")
print(" 4. ma10 - 10日均线 (ts_mean)")
print()
print("验证结果:")
print(" - 移动平均线滑动窗口: 正确 (ma5需5天, ma10需10天)")
print(" - 收益率延迟计算: 正确 (需5天前数据)")
print(" - 截面排名: 正常 (0-1区间)")
print(" - 数据完整性: 正常")
print("-" * 60)
return result
if __name__ == "__main__":
result = test_601117_factors()

View File

@@ -1,367 +0,0 @@
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
测试因子: cs_rank(ts_delay(close, 1))
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine
from src.factors.api import close, ts_delay, cs_rank
from src.factors.dsl import FunctionNode
from src.factors.engine.ast_optimizer import ExpressionFlattener
def create_mock_data(
start_date: str = "20240101",
end_date: str = "20240131",
n_stocks: int = 5,
) -> pl.DataFrame:
"""创建模拟的日线数据。"""
start = datetime.strptime(start_date, "%Y%m%d")
end = datetime.strptime(end_date, "%Y%m%d")
dates = []
current = start
while current <= end:
if current.weekday() < 5: # 周一到周五
dates.append(current.strftime("%Y%m%d"))
current += timedelta(days=1)
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
np.random.seed(42)
rows = []
for date in dates:
for stock in stocks:
base_price = 10 + np.random.randn() * 5
close_val = base_price + np.random.randn() * 0.5
open_val = close_val + np.random.randn() * 0.2
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
vol = int(1000000 + np.random.exponential(500000))
rows.append(
{
"ts_code": stock,
"trade_date": date,
"open": round(open_val, 2),
"high": round(high_val, 2),
"low": round(low_val, 2),
"close": round(close_val, 2),
"volume": vol,
}
)
return pl.DataFrame(rows)
class TestASTOptimizer:
"""AST 优化器测试类。"""
def test_flattener_basic(self):
"""测试拍平器基本功能。"""
from src.factors.api import close
flattener = ExpressionFlattener()
# 创建嵌套表达式: cs_rank(ts_delay(close, 1))
expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证临时因子被提取
assert len(tmp_factors) == 1
assert "__tmp_0" in tmp_factors
# 验证主表达式使用了 Symbol 引用
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
# 验证第一个参数是临时因子引用(通过 name 属性检查)
assert hasattr(flat_expr.args[0], "name")
assert flat_expr.args[0].name == "__tmp_0"
# 验证临时因子内容
tmp_node = tmp_factors["__tmp_0"]
assert isinstance(tmp_node, FunctionNode)
assert tmp_node.func_name == "ts_delay"
print("[PASS] 拍平器基本功能测试")
def test_flattener_no_nested(self):
"""测试非嵌套表达式不会被拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 非嵌套表达式: ts_mean(close, 20)
expr = FunctionNode("ts_mean", close, 20)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证没有临时因子被提取
assert len(tmp_factors) == 0
# 验证表达式保持不变
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "ts_mean"
print("[PASS] 非嵌套表达式测试")
def test_flattener_deeply_nested(self):
"""测试多层嵌套表达式拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5))
expr = FunctionNode(
"cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证提取了两个临时因子(修复后正确行为)
# ts_delay(close, 1) 被提取为 __tmp_0
# ts_mean(__tmp_0, 5) 被提取为 __tmp_1
assert len(tmp_factors) == 2
assert "__tmp_0" in tmp_factors
assert "__tmp_1" in tmp_factors
# 验证 __tmp_0 内容是 ts_delay(close, 1)
tmp0_node = tmp_factors["__tmp_0"]
assert isinstance(tmp0_node, FunctionNode)
assert tmp0_node.func_name == "ts_delay"
# 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5)
tmp1_node = tmp_factors["__tmp_1"]
assert isinstance(tmp1_node, FunctionNode)
assert tmp1_node.func_name == "ts_mean"
from src.factors.dsl import Symbol
assert isinstance(tmp1_node.args[0], Symbol)
assert tmp1_node.args[0].name == "__tmp_0"
# 验证主表达式引用 __tmp_1
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
assert isinstance(flat_expr.args[0], Symbol)
assert flat_expr.args[0].name == "__tmp_1"
print("[PASS] 多层嵌套表达式拍平测试")
def test_nested_window_function_engine(self):
"""测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。"""
print("\n" + "=" * 60)
print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 初始化引擎
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("引擎初始化完成")
# 3. 使用字符串表达式注册嵌套窗口函数
print("\n注册因子: cs_rank(ts_delay(close, 1))")
engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")
# 4. 检查临时因子是否被创建
registered_factors = engine.list_registered()
print(f"已注册因子: {registered_factors}")
# 验证有临时因子被创建
tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
assert len(tmp_factors) >= 1, "应该有临时因子被创建"
print(f"临时因子: {tmp_factors}")
# 5. 执行计算
print("\n执行计算...")
result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 6. 验证结果
assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"
# 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null
non_null_values = result["delayed_rank"].drop_nulls()
if len(non_null_values) > 0:
assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"
# 检查没有过多空值(考虑到开头的滞后期)
null_count = result["delayed_rank"].is_null().sum()
print(f"空值数量: {null_count}")
# 展示部分结果
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("嵌套窗口函数测试通过!")
print("=" * 60)
def test_multiple_nested_factors(self):
"""测试同时注册多个嵌套因子。"""
print("\n" + "=" * 60)
print("测试多个嵌套因子")
print("=" * 60)
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 注册多个嵌套因子(使用字符串表达式)
print("\n注册因子1: cs_rank(ts_delay(close, 1))")
engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")
print("注册因子2: ts_mean(cs_rank(close), 5)")
engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")
# 检查已注册因子
factors = engine.list_registered()
print(f"\n已注册因子: {factors}")
# 计算所有因子
result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")
assert "rank1" in result.columns
assert "rank_mean" in result.columns
print(f"\n结果行数: {len(result)}")
print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")
print("\n" + "=" * 60)
print("多个嵌套因子测试通过!")
print("=" * 60)
def test_nested_vs_native_polars(self):
"""对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。"""
print("\n" + "=" * 60)
print("对比测试cs_rank(ts_delay(close, 1)) vs 原生 Polars")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 使用 FactorEngine 计算嵌套因子
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
engine_result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"FactorEngine 结果: {len(engine_result)}")
# 3. 使用原生 Polars 计算(手动分步)
print("\n使用原生 Polars 手动计算...")
# 先计算 ts_delay(close, 1)
native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
[pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
)
# 再计算 cs_rank
native_result = native_result.with_columns(
[
(pl.col("delayed_close").rank() / pl.col("delayed_close").count())
.over("trade_date")
.alias("native_delayed_rank")
]
)
print(f"原生 Polars 结果: {len(native_result)}")
# 4. 合并结果进行对比
comparison = engine_result.join(
native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
on=["ts_code", "trade_date"],
how="inner",
)
# 5. 验证数值一致性(允许微小浮点误差)
diff = comparison.with_columns(
[
(pl.col("delayed_rank") - pl.col("native_delayed_rank"))
.abs()
.alias("diff")
]
)
max_diff = diff["diff"].max()
print(f"\n最大差异: {max_diff}")
# 过滤掉空值后比较(开头的滞后期会有空值)
non_null_diff = diff.filter(pl.col("diff").is_not_null())
assert non_null_diff["diff"].max() < 1e-10, (
f"数值差异过大: {non_null_diff['diff'].max()}"
)
print("\n" + "=" * 60)
print("数值一致性验证通过!")
print("=" * 60)
def test_factor_reference_factor(self):
"""测试因子引用另一个因子fac2 = cs_rank(fac1)。"""
print("\n" + "=" * 60)
print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
print("=" * 60)
# 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 1. 注册基础因子 fac1
print("\n注册基础因子 fac1 = ts_mean(close, 5)")
from src.factors.api import ts_mean
engine.register("fac1", ts_mean(close, 5))
# 2. 注册引用因子 fac2引用 fac1
print("注册引用因子 fac2 = cs_rank(fac1)")
engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子
# 3. 验证依赖关系
registered = engine.list_registered()
print(f"\n已注册因子: {registered}")
assert "fac1" in registered
assert "fac2" in registered
# 4. 执行计算
print("\n执行计算...")
result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 5. 验证结果
assert "fac1" in result.columns, "结果中应有 fac1 列"
assert "fac2" in result.columns, "结果中应有 fac2 列"
# fac2 是排名,应在 [0, 1] 之间
assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间"
assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间"
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("因子引用功能测试通过!")
print("=" * 60)
if __name__ == "__main__":
test = TestASTOptimizer()
test.test_flattener_basic()
test.test_flattener_no_nested()
test.test_flattener_deeply_nested()
test.test_nested_window_function_engine()
test.test_multiple_nested_factors()
test.test_nested_vs_native_polars()
test.test_factor_reference_factor()
print("\n所有测试通过!")

View File

@@ -1,144 +0,0 @@
"""测试 Bug 修复:
1. 临时因子命名冲突修复验证
2. 逻辑运算符支持验证
"""
import sys
sys.path.insert(0, "D:/PyProject/ProStock")
from src.factors.dsl import Symbol, BinaryOpNode
from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression
def test_temp_name_uniqueness():
"""测试:临时因子名称全局唯一性。"""
print("测试 1: 临时因子命名冲突修复")
print("-" * 50)
close = Symbol("close")
open_price = Symbol("open")
# 创建两个表达式拍平器实例
flattener1 = ExpressionFlattener()
flattener2 = ExpressionFlattener()
# 模拟因子 A: cs_rank(ts_delay(close, 1))
from src.factors.dsl import FunctionNode
expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
flat_a, temps_a = flattener1.flatten(expr_a)
# 模拟因子 B: cs_mean(ts_delay(open, 2))
expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2))
flat_b, temps_b = flattener2.flatten(expr_b)
# 验证临时名称不冲突
temp_names_a = set(temps_a.keys())
temp_names_b = set(temps_b.keys())
print(f"因子 A 临时名称: {temp_names_a}")
print(f"因子 B 临时名称: {temp_names_b}")
# 检查是否有名称冲突
common_names = temp_names_a & temp_names_b
if common_names:
print(f"[失败] 发现命名冲突: {common_names}")
return False
print("[通过] 临时因子名称全局唯一,无冲突")
return True
def test_logical_operators():
"""测试:逻辑运算符支持。"""
print("\n测试 2: 逻辑运算符支持")
print("-" * 50)
# 测试 DSL 层
close = Symbol("close")
open_price = Symbol("open")
# 测试 & 运算符(注意 Python 运算符优先级,需要用括号)
and_expr = (close > open_price) & (close > 0)
print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}")
assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode"
assert and_expr.op == "&", "运算符应为 &"
print("[通过] DSL 层支持 & 运算符")
# 测试 | 运算符(注意 Python 运算符优先级,需要用括号)
or_expr = (close < open_price) | (close < 0)
print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}")
assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode"
assert or_expr.op == "|", "运算符应为 |"
print("[通过] DSL 层支持 | 运算符")
# 测试字符串解析
from src.factors.parser import FormulaParser
from src.factors.registry import FunctionRegistry
parser = FormulaParser(FunctionRegistry())
# 解析包含 & 的表达式
try:
parsed_and = parser.parse("(close > open) & (volume > 0)")
print(f"解析器支持 & 运算符: {parsed_and}")
print("[通过] Parser 支持 & 运算符")
except Exception as e:
print(f"[失败] Parser 解析 & 失败: {e}")
return False
# 解析包含 | 的表达式
try:
parsed_or = parser.parse("(close < open) | (volume < 0)")
print(f"解析器支持 | 运算符: {parsed_or}")
print("[通过] Parser 支持 | 运算符")
except Exception as e:
print(f"[失败] Parser 解析 | 失败: {e}")
return False
# 测试翻译到 Polars
from src.factors.translator import PolarsTranslator
import polars as pl
translator = PolarsTranslator()
try:
polars_and = translator.translate(parsed_and)
print(f"Polars 表达式 (&): {polars_and}")
print("[通过] Translator 支持 & 运算符")
except Exception as e:
print(f"[失败] Translator 翻译 & 失败: {e}")
return False
try:
polars_or = translator.translate(parsed_or)
print(f"Polars 表达式 (|): {polars_or}")
print("[通过] Translator 支持 | 运算符")
except Exception as e:
print(f"[失败] Translator 翻译 | 失败: {e}")
return False
return True
if __name__ == "__main__":
print("=" * 60)
print("Bug 修复验证测试")
print("=" * 60)
test1_passed = test_temp_name_uniqueness()
test2_passed = test_logical_operators()
print("\n" + "=" * 60)
print("测试结果汇总")
print("=" * 60)
print(f"临时因子命名冲突修复: {'[通过]' if test1_passed else '[失败]'}")
print(f"逻辑运算符支持: {'[通过]' if test2_passed else '[失败]'}")
if test1_passed and test2_passed:
print("\n所有测试通过!")
sys.exit(0)
else:
print("\n存在失败的测试!")
sys.exit(1)

View File

@@ -1,377 +0,0 @@
"""Tests for DuckDB database manager and incremental sync."""
import pytest
import pandas as pd
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
from src.data.db_manager import (
TableManager,
IncrementalSync,
SyncManager,
ensure_table,
get_table_info,
sync_table,
)
class TestTableManager:
"""Test table creation and management."""
@pytest.fixture
def mock_storage(self):
"""Create a mock storage instance."""
storage = Mock()
storage._connection = Mock()
storage.exists = Mock(return_value=False)
return storage
@pytest.fixture
def sample_data(self):
"""Create sample DataFrame with ts_code and trade_date."""
return pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
"trade_date": ["20240101", "20240102", "20240101"],
"open": [10.0, 10.5, 20.0],
"close": [10.5, 11.0, 20.5],
"volume": [1000, 2000, 3000],
}
)
def test_create_table_from_dataframe(self, mock_storage, sample_data):
"""Test table creation from DataFrame."""
manager = TableManager(mock_storage)
result = manager.create_table_from_dataframe("daily", sample_data)
assert result is True
# Should execute CREATE TABLE
assert mock_storage._connection.execute.call_count >= 1
# Get the CREATE TABLE SQL
calls = mock_storage._connection.execute.call_args_list
create_table_call = None
for call in calls:
sql = call[0][0] if call[0] else call[1].get("sql", "")
if "CREATE TABLE" in str(sql):
create_table_call = sql
break
assert create_table_call is not None
assert "ts_code" in str(create_table_call)
assert "trade_date" in str(create_table_call)
def test_create_table_with_index(self, mock_storage, sample_data):
"""Test that composite index is created for trade_date and ts_code."""
manager = TableManager(mock_storage)
manager.create_table_from_dataframe("daily", sample_data, create_index=True)
# Check that index creation was called
calls = mock_storage._connection.execute.call_args_list
index_calls = [call for call in calls if "CREATE INDEX" in str(call)]
assert len(index_calls) > 0
def test_create_table_empty_dataframe(self, mock_storage):
"""Test that empty DataFrame is rejected."""
manager = TableManager(mock_storage)
empty_df = pd.DataFrame()
result = manager.create_table_from_dataframe("daily", empty_df)
assert result is False
mock_storage._connection.execute.assert_not_called()
def test_ensure_table_exists_creates_table(self, mock_storage, sample_data):
"""Test ensure_table_exists creates table if not exists."""
mock_storage.exists.return_value = False
manager = TableManager(mock_storage)
result = manager.ensure_table_exists("daily", sample_data)
assert result is True
mock_storage._connection.execute.assert_called()
def test_ensure_table_exists_already_exists(self, mock_storage):
"""Test ensure_table_exists returns True if table already exists."""
mock_storage.exists.return_value = True
manager = TableManager(mock_storage)
result = manager.ensure_table_exists("daily", None)
assert result is True
mock_storage._connection.execute.assert_not_called()
class TestIncrementalSync:
"""Test incremental synchronization strategies."""
@pytest.fixture
def mock_storage(self):
"""Create a mock storage instance."""
storage = Mock()
storage._connection = Mock()
storage.exists = Mock(return_value=False)
storage.get_distinct_stocks = Mock(return_value=[])
return storage
def test_sync_strategy_new_table(self, mock_storage):
"""Test strategy for non-existent table."""
mock_storage.exists.return_value = False
sync = IncrementalSync(mock_storage)
strategy, start, end, stocks = sync.get_sync_strategy(
"daily", "20240101", "20240131"
)
assert strategy == "by_date"
assert start == "20240101"
assert end == "20240131"
assert stocks is None
def test_sync_strategy_empty_table(self, mock_storage):
"""Test strategy for empty table."""
mock_storage.exists.return_value = True
sync = IncrementalSync(mock_storage)
# Mock get_table_stats to return empty
sync.get_table_stats = Mock(
return_value={
"exists": True,
"row_count": 0,
"max_date": None,
}
)
strategy, start, end, stocks = sync.get_sync_strategy(
"daily", "20240101", "20240131"
)
assert strategy == "by_date"
assert start == "20240101"
assert end == "20240131"
def test_sync_strategy_up_to_date(self, mock_storage):
"""Test strategy when table is already up-to-date."""
mock_storage.exists.return_value = True
sync = IncrementalSync(mock_storage)
# Mock get_table_stats to show table is up-to-date
sync.get_table_stats = Mock(
return_value={
"exists": True,
"row_count": 100,
"max_date": "20240131",
}
)
strategy, start, end, stocks = sync.get_sync_strategy(
"daily", "20240101", "20240131"
)
assert strategy == "none"
assert start is None
assert end is None
def test_sync_strategy_incremental_by_date(self, mock_storage):
"""Test incremental sync by date when new data available."""
mock_storage.exists.return_value = True
sync = IncrementalSync(mock_storage)
# Table has data until Jan 15
sync.get_table_stats = Mock(
return_value={
"exists": True,
"row_count": 100,
"max_date": "20240115",
}
)
strategy, start, end, stocks = sync.get_sync_strategy(
"daily", "20240101", "20240131"
)
assert strategy == "by_date"
assert start == "20240116" # Next day after last date
assert end == "20240131"
def test_sync_strategy_by_stock(self, mock_storage):
"""Test sync by stock for specific stocks."""
mock_storage.exists.return_value = True
mock_storage.get_distinct_stocks.return_value = ["000001.SZ"]
sync = IncrementalSync(mock_storage)
sync.get_table_stats = Mock(
return_value={
"exists": True,
"row_count": 100,
"max_date": "20240131",
}
)
# Request 2 stocks, but only 1 exists
strategy, start, end, stocks = sync.get_sync_strategy(
"daily", "20240101", "20240131", stock_codes=["000001.SZ", "600000.SH"]
)
assert strategy == "by_stock"
assert "600000.SH" in stocks
assert "000001.SZ" not in stocks
def test_sync_data_by_date(self, mock_storage):
"""Test syncing data by date strategy."""
mock_storage.exists.return_value = True
mock_storage.save = Mock(return_value={"status": "success", "rows": 1})
sync = IncrementalSync(mock_storage)
data = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240101"],
"close": [10.0],
}
)
result = sync.sync_data("daily", data, strategy="by_date")
assert result["status"] == "success"
def test_sync_data_empty_dataframe(self, mock_storage):
"""Test syncing empty DataFrame."""
sync = IncrementalSync(mock_storage)
empty_df = pd.DataFrame()
result = sync.sync_data("daily", empty_df)
assert result["status"] == "skipped"
class TestSyncManager:
"""Test high-level sync manager."""
@pytest.fixture
def mock_storage(self):
"""Create a mock storage instance."""
storage = Mock()
storage._connection = Mock()
storage.exists = Mock(return_value=False)
storage.save = Mock(return_value={"status": "success", "rows": 10})
storage.get_distinct_stocks = Mock(return_value=[])
return storage
def test_sync_no_sync_needed(self, mock_storage):
"""Test sync when no update is needed."""
mock_storage.exists.return_value = True
manager = SyncManager(mock_storage)
# Mock incremental_sync to return 'none' strategy
manager.incremental_sync.get_sync_strategy = Mock(
return_value=("none", None, None, None)
)
# Mock fetch function
fetch_func = Mock()
result = manager.sync("daily", fetch_func, "20240101", "20240131")
assert result["status"] == "skipped"
fetch_func.assert_not_called()
def test_sync_fetches_data(self, mock_storage):
"""Test that sync fetches data when needed."""
mock_storage.exists.return_value = False
manager = SyncManager(mock_storage)
# Mock table_manager
manager.table_manager.ensure_table_exists = Mock(return_value=True)
# Mock incremental_sync
manager.incremental_sync.get_sync_strategy = Mock(
return_value=("by_date", "20240101", "20240131", None)
)
manager.incremental_sync.sync_data = Mock(
return_value={"status": "success", "rows_inserted": 10}
)
# Mock fetch function returning data
fetch_func = Mock(
return_value=pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240101"],
}
)
)
result = manager.sync("daily", fetch_func, "20240101", "20240131")
fetch_func.assert_called_once()
assert result["status"] == "success"
def test_sync_handles_fetch_error(self, mock_storage):
"""Test error handling during data fetch."""
manager = SyncManager(mock_storage)
# Mock incremental_sync
manager.incremental_sync.get_sync_strategy = Mock(
return_value=("by_date", "20240101", "20240131", None)
)
# Mock fetch function that raises exception
fetch_func = Mock(side_effect=Exception("API Error"))
result = manager.sync("daily", fetch_func, "20240101", "20240131")
assert result["status"] == "error"
assert "API Error" in result["error"]
class TestConvenienceFunctions:
"""Test convenience functions."""
@patch("src.data.db_manager.TableManager")
def test_ensure_table(self, mock_manager_class):
"""Test ensure_table convenience function."""
mock_manager = Mock()
mock_manager.ensure_table_exists = Mock(return_value=True)
mock_manager_class.return_value = mock_manager
data = pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]})
result = ensure_table("daily", data)
assert result is True
mock_manager.ensure_table_exists.assert_called_once_with("daily", data)
@patch("src.data.db_manager.IncrementalSync")
def test_get_table_info(self, mock_sync_class):
"""Test get_table_info convenience function."""
mock_sync = Mock()
mock_sync.get_table_stats = Mock(
return_value={
"exists": True,
"row_count": 100,
}
)
mock_sync_class.return_value = mock_sync
result = get_table_info("daily")
assert result["exists"] is True
assert result["row_count"] == 100
@patch("src.data.db_manager.SyncManager")
def test_sync_table(self, mock_manager_class):
"""Test sync_table convenience function."""
mock_manager = Mock()
mock_manager.sync = Mock(return_value={"status": "success", "rows": 10})
mock_manager_class.return_value = mock_manager
fetch_func = Mock()
result = sync_table("daily", fetch_func, "20240101", "20240131")
assert result["status"] == "success"
mock_manager.sync.assert_called_once()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,160 +0,0 @@
"""FactorEngine 端到端测试。
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine, DataSpec
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
from src.factors.dsl import Symbol, FunctionNode
def create_mock_data(
start_date: str = "20240101",
end_date: str = "20240131",
n_stocks: int = 5,
) -> pl.DataFrame:
"""创建模拟的日线数据。"""
start = datetime.strptime(start_date, "%Y%m%d")
end = datetime.strptime(end_date, "%Y%m%d")
dates = []
current = start
while current <= end:
if current.weekday() < 5: # 周一到周五
dates.append(current.strftime("%Y%m%d"))
current += timedelta(days=1)
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
np.random.seed(42)
rows = []
for date in dates:
for stock in stocks:
base_price = 10 + np.random.randn() * 5
close_val = base_price + np.random.randn() * 0.5
open_val = close_val + np.random.randn() * 0.2
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
vol = int(1000000 + np.random.exponential(500000))
amt = close_val * vol
rows.append(
{
"ts_code": stock,
"trade_date": date,
"open": round(open_val, 2),
"high": round(high_val, 2),
"low": round(low_val, 2),
"close": round(close_val, 2),
"volume": vol,
"amount": round(amt, 2),
"pre_close": round(close_val - np.random.randn() * 0.3, 2),
}
)
return pl.DataFrame(rows)
class TestFactorEngineEndToEnd:
"""FactorEngine 端到端测试类。"""
@pytest.fixture
def mock_data(self):
"""提供模拟数据的 fixture。"""
return create_mock_data("20240101", "20240131", n_stocks=5)
@pytest.fixture
def engine(self, mock_data):
"""提供配置好的 FactorEngine fixture。"""
data_source = {"pro_bar": mock_data}
return FactorEngine(data_source=data_source)
def test_simple_symbol_expression(self, engine):
"""测试简单的符号表达式。"""
engine.register("close_price", close)
result = engine.compute("close_price", "20240115", "20240120")
assert "close_price" in result.columns
assert len(result) > 0
print("[PASS] 简单符号表达式测试")
def test_arithmetic_expression(self, engine):
"""测试算术表达式。"""
engine.register("returns", (close - open_sym) / open_sym)
result = engine.compute("returns", "20240115", "20240120")
assert "returns" in result.columns
print("[PASS] 算术表达式测试")
def test_cs_rank_factor(self, engine):
"""测试截面排名因子。"""
engine.register("price_rank", cs_rank(close))
result = engine.compute("price_rank", "20240115", "20240120")
assert "price_rank" in result.columns
assert result["price_rank"].min() >= 0
assert result["price_rank"].max() <= 1
print("[PASS] 截面排名因子测试")
class TestFullWorkflow:
"""完整工作流测试类。"""
def test_full_workflow_demo(self):
"""演示完整的因子计算工作流。"""
print("\n" + "=" * 60)
print("FactorEngine Full Workflow Demo")
print("=" * 60)
# 1. 准备数据
print("\nStep 1: Prepare mock data...")
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f" Generated {len(mock_data)} rows")
print(f" Stocks: {mock_data['ts_code'].n_unique()}")
# 2. 初始化引擎
print("\nStep 2: Initialize FactorEngine...")
engine = FactorEngine(data_source={"pro_bar": mock_data})
print(" Engine initialized")
# 3. 注册因子 - 使用简单因子避免回看窗口问题
print("\nStep 3: Register factors...")
engine.register("returns", (close - open_sym) / open_sym)
engine.register("price_rank", cs_rank(close))
print(" Registered: returns, price_rank")
# 4. 执行计算 - 使用完整日期范围
print("\nStep 4: Compute factors...")
result = engine.compute(
["returns", "price_rank"],
"20240115",
"20240120",
)
print(f" Computed {len(result)} rows")
# 5. 验证结果
print("\nStep 5: Verify results...")
assert "returns" in result.columns
assert "price_rank" in result.columns
assert result["price_rank"].min() >= 0
assert result["price_rank"].max() <= 1
print(" All assertions passed")
# 6. 展示样本
print("\nStep 6: Sample output...")
sample = result.select(
["ts_code", "trade_date", "close", "returns", "price_rank"]
).head(3)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("Workflow completed successfully!")
print("=" * 60)
if __name__ == "__main__":
test = TestFullWorkflow()
test.test_full_workflow_demo()
pytest.main([__file__, "-v", "--tb=short"])

View File

@@ -1,106 +0,0 @@
"""FactorEngine 与 Metadata 集成测试。
测试 add_factor_by_name 方法的功能。
"""
import pytest
from src.factors import FactorEngine
from src.factors.metadata import FactorManager
class TestFactorEngineMetadataIntegration:
"""测试 FactorEngine 与 Metadata 的集成功能。"""
@pytest.fixture
def metadata_file(self):
"""使用 data 目录下的 factors.jsonl 文件。"""
return "data/factors.jsonl"
def test_init_without_metadata(self):
"""测试不启用 metadata 时初始化引擎。"""
engine = FactorEngine()
assert engine._metadata is None
def test_init_with_metadata(self, metadata_file):
"""测试启用 metadata 时初始化引擎。"""
engine = FactorEngine(metadata_path=metadata_file)
assert engine._metadata is not None
assert isinstance(engine._metadata, FactorManager)
def test_add_factor_by_name_success(self, metadata_file):
"""测试从 metadata 成功添加因子。"""
engine = FactorEngine(metadata_path=metadata_file)
# 添加 return_5 因子
result = engine.add_factor_by_name("return_5")
# 验证链式调用返回自身
assert result is engine
# 验证因子已注册
assert "return_5" in engine.list_registered()
def test_add_factor_by_name_with_alias(self, metadata_file):
"""测试使用别名添加因子。"""
engine = FactorEngine(metadata_path=metadata_file)
# 使用不同名称注册 metadata 中的因子
engine.add_factor_by_name("my_ma", "ma_5")
# 验证使用别名注册的因子
assert "my_ma" in engine.list_registered()
assert "ma_5" not in engine.list_registered()
def test_add_factor_by_name_not_found(self, metadata_file):
"""测试添加不存在的因子时抛出异常。"""
engine = FactorEngine(metadata_path=metadata_file)
with pytest.raises(ValueError) as exc_info:
engine.add_factor_by_name("nonexistent_factor")
assert "未找到因子" in str(exc_info.value)
assert "nonexistent_factor" in str(exc_info.value)
def test_add_factor_by_name_without_metadata(self):
"""测试未配置 metadata 时调用 add_factor_by_name 抛出异常。"""
engine = FactorEngine() # 不传入 metadata_path
with pytest.raises(RuntimeError) as exc_info:
engine.add_factor_by_name("return_5")
assert "未配置 metadata 路径" in str(exc_info.value)
def test_chain_calls(self, metadata_file):
"""测试链式调用。"""
engine = FactorEngine(metadata_path=metadata_file)
# 链式添加多个因子
(
engine.add_factor_by_name("return_5")
.add_factor_by_name("ma_5")
.add_factor_by_name("custom_ma20", "ma_20")
)
# 验证所有因子都已注册
assert "return_5" in engine.list_registered()
assert "ma_5" in engine.list_registered()
assert "custom_ma20" in engine.list_registered()
def test_add_factor_by_name_preserves_existing_add_factor(self, metadata_file):
"""测试 add_factor_by_name 不影响原有的 add_factor 方法。"""
engine = FactorEngine(metadata_path=metadata_file)
# 使用 add_factor 添加字符串表达式
engine.add_factor("manual_factor", "ts_mean(close, 10)")
# 使用 add_factor_by_name 添加 metadata 因子
engine.add_factor_by_name("return_5")
# 验证两者都正常工作
assert "manual_factor" in engine.list_registered()
assert "return_5" in engine.list_registered()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,451 +0,0 @@
"""因子框架集成测试脚本
测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑
测试范围:
1. 时序因子 ts_mean - 验证滑动窗口和数据隔离
2. 截面因子 cs_rank - 验证每日独立排名和结果分布
3. 组合运算 - 验证多字段算术运算和算子嵌套
排除范围PIT 因子(使用低频财务数据)
"""
import random
from datetime import datetime
import polars as pl
from src.data.catalog import DatabaseCatalog
from src.factors.engine import FactorEngine
from src.factors.api import close, open, ts_mean, cs_rank
def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list:
"""随机选择代表性股票样本。
确保样本覆盖不同交易所:
- .SH: 上海证券交易所(主板、科创板)
- .SZ: 深圳证券交易所(主板、创业板)
Args:
catalog: 数据库目录实例
n: 需要选择的股票数量
Returns:
股票代码列表
"""
# 从 catalog 获取数据库连接
db_path = catalog.db_path.replace("duckdb://", "").lstrip("/")
import duckdb
conn = duckdb.connect(db_path, read_only=True)
try:
# 获取2023年上半年的所有股票
result = conn.execute("""
SELECT DISTINCT ts_code
FROM daily
WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30'
""").fetchall()
all_stocks = [row[0] for row in result]
# 按交易所分类
sh_stocks = [s for s in all_stocks if s.endswith(".SH")]
sz_stocks = [s for s in all_stocks if s.endswith(".SZ")]
# 选择样本:确保覆盖两个交易所
sample = []
# 从上海市场选择 (包含主板600/601/603/605和科创板688)
sh_main = [
s for s in sh_stocks if s.startswith("6") and not s.startswith("688")
]
sh_kcb = [s for s in sh_stocks if s.startswith("688")]
# 从深圳市场选择 (包含主板000/001/002和创业板300/301)
sz_main = [s for s in sz_stocks if s.startswith("0")]
sz_cyb = [s for s in sz_stocks if s.startswith("300") or s.startswith("301")]
# 每类选择部分股票
if sh_main:
sample.extend(random.sample(sh_main, min(2, len(sh_main))))
if sh_kcb:
sample.extend(random.sample(sh_kcb, min(2, len(sh_kcb))))
if sz_main:
sample.extend(random.sample(sz_main, min(2, len(sz_main))))
if sz_cyb:
sample.extend(random.sample(sz_cyb, min(2, len(sz_cyb))))
# 如果还不够,随机补充
while len(sample) < n and len(sample) < len(all_stocks):
remaining = [s for s in all_stocks if s not in sample]
if remaining:
sample.append(random.choice(remaining))
else:
break
return sorted(sample[:n])
finally:
conn.close()
def run_factor_integration_test():
"""执行因子框架集成测试。"""
print("=" * 80)
print("因子框架集成测试 - DuckDB 真实数据验证")
print("=" * 80)
# =========================================================================
# 1. 测试环境准备
# =========================================================================
print("\n" + "=" * 80)
print("1. 测试环境准备")
print("=" * 80)
# 数据库配置
db_path = "data/prostock.db"
db_uri = f"duckdb:///{db_path}"
print(f"\n数据库路径: {db_path}")
print(f"数据库URI: {db_uri}")
# 时间范围
start_date = "20230101"
end_date = "20230630"
print(f"\n测试时间范围: {start_date}{end_date}")
# 创建 DatabaseCatalog 并发现表结构
print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...")
catalog = DatabaseCatalog(db_path)
print(f"发现表数量: {len(catalog.tables)}")
for table_name, metadata in catalog.tables.items():
print(
f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})"
)
# 选择样本股票
print("\n[1.2] 选择样本股票...")
sample_stocks = select_sample_stocks(catalog, n=8)
print(f"选中 {len(sample_stocks)} 只代表性股票:")
for stock in sample_stocks:
exchange = "上交所" if stock.endswith(".SH") else "深交所"
board = ""
if stock.startswith("688"):
board = "科创板"
elif (
stock.startswith("600")
or stock.startswith("601")
or stock.startswith("603")
):
board = "主板"
elif stock.startswith("300") or stock.startswith("301"):
board = "创业板"
elif (
stock.startswith("000")
or stock.startswith("001")
or stock.startswith("002")
):
board = "主板"
print(f" - {stock} ({exchange} {board})")
# =========================================================================
# 2. 因子定义
# =========================================================================
print("\n" + "=" * 80)
print("2. 因子定义")
print("=" * 80)
# 创建 FactorEngine
print("\n[2.1] 创建 FactorEngine...")
engine = FactorEngine(catalog)
# 因子 A: 时序均线 ts_mean(close, 10)
print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)")
print(" 验证重点: 10日滑动窗口是否正确是否存在'数据串户'")
factor_a = ts_mean(close, 10)
engine.add_factor("factor_a_ts_mean_10", factor_a)
print(f" AST: {factor_a}")
# 因子 B: 截面排名 cs_rank(close)
print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)")
print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间")
factor_b = cs_rank(close)
engine.add_factor("factor_b_cs_rank", factor_b)
print(f" AST: {factor_b}")
# 因子 C: 组合运算 ts_mean(close, 5) / open
print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open")
print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性")
factor_c = ts_mean(close, 5) / open
engine.add_factor("factor_c_composite", factor_c)
print(f" AST: {factor_c}")
# 同时注册原始字段用于验证
engine.add_factor("close_price", close)
engine.add_factor("open_price", open)
print(f"\n已注册因子列表: {engine.list_factors()}")
# =========================================================================
# 3. 计算执行
# =========================================================================
print("\n" + "=" * 80)
print("3. 计算执行")
print("=" * 80)
print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...")
result_df = engine.compute(
start_date=start_date,
end_date=end_date,
db_uri=db_uri,
)
print(f"\n计算完成!")
print(f"结果形状: {result_df.shape}")
print(f"结果列: {result_df.columns}")
# =========================================================================
# 4. 调试信息:打印 Context LazyFrame 前5行
# =========================================================================
print("\n" + "=" * 80)
print("4. 调试信息DataLoader 拼接后的数据预览")
print("=" * 80)
print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
from src.data.catalog import build_context_lazyframe
context_lf = build_context_lazyframe(
required_fields=["close", "open"],
start_date=start_date,
end_date=end_date,
db_uri=db_uri,
catalog=catalog,
)
print("\nContext LazyFrame 前 5 行:")
print(context_lf.fetch(5))
# =========================================================================
# 5. 时序切片检查
# =========================================================================
print("\n" + "=" * 80)
print("5. 时序切片检查")
print("=" * 80)
# 选择特定股票进行时序验证
target_stock = sample_stocks[0] if sample_stocks else "000001.SZ"
print(f"\n[5.1] 筛选股票: {target_stock}")
stock_df = result_df.filter(pl.col("ts_code") == target_stock)
print(f"该股票数据行数: {len(stock_df)}")
print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):")
print("-" * 80)
print("人工核查点:")
print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null滑动窗口未满")
print(" - 第 10 行开始应该有值")
print("-" * 80)
display_cols = [
"ts_code",
"trade_date",
"close_price",
"open_price",
"factor_a_ts_mean_10",
]
available_cols = [c for c in display_cols if c in stock_df.columns]
print(stock_df.select(available_cols).head(15))
# 验证滑动窗口
print("\n[5.3] 滑动窗口验证:")
stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list()
null_count_first_9 = sum(1 for x in stock_list[:9] if x is None)
non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None)
print(f" 前 9 行 Null 值数量: {null_count_first_9}/9")
print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6")
if null_count_first_9 == 9 and non_null_from_10 == 6:
print(" ✅ 滑动窗口验证通过!")
else:
print(" ⚠️ 滑动窗口验证异常,请检查数据")
# =========================================================================
# 6. 截面切片检查
# =========================================================================
print("\n" + "=" * 80)
print("6. 截面切片检查")
print("=" * 80)
# 选择特定交易日
target_date = "20230301"
print(f"\n[6.1] 筛选交易日: {target_date}")
date_df = result_df.filter(pl.col("trade_date") == target_date)
print(f"该交易日股票数量: {len(date_df)}")
print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:")
print("-" * 80)
print("人工核查点:")
print(" - close 最高的股票其 cs_rank 应该接近 1.0")
print(" - close 最低的股票其 cs_rank 应该接近 0.0")
print(" - cs_rank 值应该严格分布在 [0, 1] 区间")
print("-" * 80)
# 按 close 排序显示
display_df = date_df.select(
["ts_code", "trade_date", "close_price", "factor_b_cs_rank"]
)
display_df = display_df.sort("close_price", descending=True)
print(display_df)
# 验证截面排名
print("\n[6.3] 截面排名验证:")
rank_values = date_df.select("factor_b_cs_rank").to_series().to_list()
rank_values = [x for x in rank_values if x is not None]
if rank_values:
min_rank = min(rank_values)
max_rank = max(rank_values)
print(f" cs_rank 最小值: {min_rank:.6f}")
print(f" cs_rank 最大值: {max_rank:.6f}")
print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]")
# 验证 close 最高的股票 rank 是否为 1.0
highest_close_row = date_df.sort("close_price", descending=True).head(1)
if len(highest_close_row) > 0:
highest_rank = highest_close_row.select("factor_b_cs_rank").item()
print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}")
if abs(highest_rank - 1.0) < 0.01:
print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)")
else:
print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})")
# =========================================================================
# 7. 数据完整性统计
# =========================================================================
print("\n" + "=" * 80)
print("7. 数据完整性统计")
print("=" * 80)
factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"]
print("\n[7.1] 各因子的空值数量和描述性统计:")
print("-" * 80)
for col in factor_cols:
if col in result_df.columns:
series = result_df.select(col).to_series()
null_count = series.null_count()
total_count = len(series)
print(f"\n因子: {col}")
print(f" 总记录数: {total_count}")
print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)")
# 描述性统计(排除空值)
non_null_series = series.drop_nulls()
if len(non_null_series) > 0:
print(f" 描述性统计:")
print(f" Mean: {non_null_series.mean():.6f}")
print(f" Std: {non_null_series.std():.6f}")
print(f" Min: {non_null_series.min():.6f}")
print(f" Max: {non_null_series.max():.6f}")
# =========================================================================
# 8. 综合验证
# =========================================================================
print("\n" + "=" * 80)
print("8. 综合验证")
print("=" * 80)
print("\n[8.1] 数据串户检查:")
# 检查不同股票的数据是否正确隔离
print(" 验证方法: 检查不同股票的 trade_date 序列是否独立")
stock_dates = {}
for stock in sample_stocks[:3]: # 检查前3只股票
stock_data = (
result_df.filter(pl.col("ts_code") == stock)
.select("trade_date")
.to_series()
.to_list()
)
stock_dates[stock] = stock_data[:5] # 前5个日期
print(f" {stock} 前5个交易日期: {stock_data[:5]}")
# 检查日期序列是否一致(应该一致,因为是同一时间段)
dates_match = all(
dates == list(stock_dates.values())[0] for dates in stock_dates.values()
)
if dates_match:
print(" ✅ 日期序列一致,数据对齐正确")
else:
print(" ⚠️ 日期序列不一致,请检查数据对齐")
print("\n[8.2] 因子 C 组合运算验证:")
# 手动计算几行验证组合运算
sample_row = result_df.filter(
(pl.col("ts_code") == target_stock)
& (pl.col("factor_a_ts_mean_10").is_not_null())
).head(1)
if len(sample_row) > 0:
close_val = sample_row.select("close_price").item()
open_val = sample_row.select("open_price").item()
factor_c_val = sample_row.select("factor_c_composite").item()
# 手动计算 ts_mean(close, 5) / open
# 注意:这里只是验证表达式结构,不是精确计算
print(f" 样本数据:")
print(f" close: {close_val:.4f}")
print(f" open: {open_val:.4f}")
print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}")
# 验证 factor_c 是否合理(应该接近 close/open 的某个均值)
ratio = close_val / open_val if open_val != 0 else 0
print(f" close/open 比值: {ratio:.6f}")
print(f" ✅ 组合运算结果已生成")
# =========================================================================
# 9. 测试总结
# =========================================================================
print("\n" + "=" * 80)
print("9. 测试总结")
print("=" * 80)
print("\n测试完成! 以下是关键验证点总结:")
print("-" * 80)
print("✅ 因子 A (ts_mean):")
print(" - 10日滑动窗口计算正确")
print(" - 前9行为Null第10行开始有值")
print(" - 不同股票数据隔离over(ts_code)")
print()
print("✅ 因子 B (cs_rank):")
print(" - 每日独立排名over(trade_date)")
print(" - 结果分布在 [0, 1] 区间")
print(" - 最高close股票rank接近1.0")
print()
print("✅ 因子 C (组合运算):")
print(" - 多字段算术运算正常")
print(" - 时序算子嵌套稳定")
print()
print("✅ 数据完整性:")
print(f" - 总记录数: {len(result_df)}")
print(f" - 样本股票数: {len(sample_stocks)}")
print(f" - 时间范围: {start_date}{end_date}")
print("-" * 80)
return result_df
if __name__ == "__main__":
# 设置随机种子以确保可重复性
random.seed(42)
# 运行测试
result = run_factor_integration_test()

View File

@@ -1,351 +0,0 @@
"""财务数据与行情数据拼接测试。
测试场景:
1. 普通财务数据:正常公告,之后无修改
2. 隔日修改:公告后几天发布修正版
3. 当日修改:同一天发布多版,取 update_flag=1 的
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
"""
import polars as pl
from datetime import date
from src.data.financial_loader import FinancialLoader
def create_mock_price_data() -> pl.DataFrame:
"""创建模拟行情数据。"""
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * 12,
"trade_date": [
"20240101",
"20240102",
"20240103",
"20240104",
"20240105",
"20240108",
"20240109",
"20240110",
"20240111",
"20240112",
# 添加2024-04-30之后的日期用于测试同日不同报告期场景
"20240501",
"20240502",
],
"close": [
10.0,
10.2,
10.3,
10.1,
10.5,
10.6,
10.4,
10.7,
10.8,
10.9,
11.0,
11.1,
],
}
)
def create_mock_financial_data() -> pl.DataFrame:
"""创建模拟财务数据(覆盖多种场景)。
场景说明:
1. 2024-01-02 发布 2023Q3 报告end_date=20230930
2. 2024-01-02 发布 2023Q3 更正版update_flag=1
3. 2024-04-30 同时发布 2023年报end_date=20231231和 2024Q1季报end_date=20240331
4. 2024-04-30 发布 2023年报更正版
预期结果:
- 2024-01-02 保留 2023Q3 更正版
- 2024-04-30 保留 2024Q1 季报end_date 最新)
注意f_ann_date 必须是 Date 类型(与数据库保持一致)。
"""
return pl.DataFrame(
{
"ts_code": [
"000001.SZ",
"000001.SZ",
"000001.SZ",
"000001.SZ",
"000001.SZ",
],
"f_ann_date": [
date(2024, 1, 2),
date(2024, 1, 2), # 同日多版
date(2024, 4, 30),
date(2024, 4, 30),
date(2024, 4, 30), # 同日不同报告期
],
"end_date": [
"20230930",
"20230930", # 2023Q3
"20231231",
"20240331",
"20231231", # 年报和季报同一天发布
],
"report_type": [1, 1, 1, 1, 1], # 整数类型(与数据库一致)
"update_flag": [0, 1, 0, 0, 1], # 年报也有更正版
"net_profit": [
1000000.0,
1100000.0, # 2023Q3
5000000.0,
1500000.0,
5500000.0, # 年报更正后550万季报150万
],
"revenue": [
5000000.0,
5200000.0, # 2023Q3
20000000.0,
8000000.0,
22000000.0,
],
}
)
def test_financial_data_cleaning():
"""测试财务数据清洗逻辑 - 确保同日多报告期时选 end_date 最新的。"""
print("=== 测试 1: 财务数据清洗 ===")
df_finance = create_mock_financial_data()
print("原始财务数据:")
print(df_finance)
loader = FinancialLoader()
# 手动执行新的清洗逻辑
df = df_finance.filter(pl.col("report_type") == 1)
# 添加辅助列
df = df.with_columns(
[
pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
pl.col("update_flag")
.fill_null("0")
.cast(pl.Int32, strict=False)
.fill_null(0)
.alias("update_flag_int"),
]
)
# 确定性排序
df = df.sort(["ts_code", "f_ann_date", "end_date_int", "update_flag_int"])
# 累积最大报告期
df = df.with_columns(
pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
)
# 过滤历史包袱
df = df.filter(pl.col("end_date_int") == pl.col("max_end_date_seen"))
# 去重保留最后一条end_date 最大的)
df = df.unique(subset=["ts_code", "f_ann_date"], keep="last")
# 清理辅助列
df = df.drop(["end_date_int", "update_flag_int", "max_end_date_seen"])
df = df.sort(["ts_code", "f_ann_date"])
print("\n清洗后的财务数据:")
print(df)
# 验证应该有2条记录2024-01-02 和 2024-04-30
assert len(df) == 2, f"清洗后应该有2条记录实际有 {len(df)}"
# 验证2024-01-02 的 end_date 应该是 20230930
row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
assert len(row_jan02) == 1
assert row_jan02["end_date"][0] == "20230930"
assert row_jan02["update_flag"][0] == 1
print("[验证 1] 2024-01-02 正确保留了 2023Q3 更正版")
# 验证2024-04-30 应该保留 2024Q1end_date=20240331而不是年报
row_apr30 = df.filter(pl.col("f_ann_date") == date(2024, 4, 30))
assert len(row_apr30) == 1
assert row_apr30["end_date"][0] == "20240331", (
f"2024-04-30 应该保留 end_date 最新的 20240331"
f"实际为 {row_apr30['end_date'][0]}"
)
assert row_apr30["net_profit"][0] == 1500000.0
print("[验证 2] 2024-04-30 正确保留了 2024Q1 季报end_date 最新)")
print("\n[通过] 财务数据清洗测试通过!")
return df
def test_financial_price_merge():
"""测试财务数据拼接逻辑(无未来函数验证)。"""
print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
df_price = create_mock_price_data()
df_finance_raw = create_mock_financial_data()
loader = FinancialLoader()
# 步骤1: 清洗财务数据(手动执行新的清洗逻辑)
# 注意f_ann_date 已经是 Date 类型,不需要转换
df_finance = df_finance_raw.filter(pl.col("report_type") == 1)
# 添加辅助列
df_finance = df_finance.with_columns(
[
pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
pl.col("update_flag")
.fill_null("0")
.cast(pl.Int32, strict=False)
.fill_null(0)
.alias("update_flag_int"),
]
)
# 确定性排序
df_finance = df_finance.sort(
["ts_code", "f_ann_date", "end_date_int", "update_flag_int"]
)
# 累积最大报告期
df_finance = df_finance.with_columns(
pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
)
# 过滤历史包袱
df_finance = df_finance.filter(
pl.col("end_date_int") == pl.col("max_end_date_seen")
)
# 去重保留最后一条end_date 最大的)
df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="last")
# 清理辅助列
df_finance = df_finance.drop(
["end_date_int", "update_flag_int", "max_end_date_seen"]
)
df_finance = df_finance.sort(["ts_code", "f_ann_date"])
print("清洗后的财务数据:")
print(df_finance)
# 步骤2: 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 步骤3: 拼接
financial_cols = ["net_profit", "revenue"]
merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)
# 步骤4: 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
print("\n拼接结果:")
print(merged)
# 验证无未来函数:
# 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告)
jan01 = merged.filter(pl.col("trade_date") == "20240101")
assert jan01["net_profit"].is_null().all(), (
"2024-01-01 不应有 2023Q3 数据(尚未公告)"
)
print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
# 20240102 及之后应该看到 net_profit=1100000update_flag=1 的版本)
jan02 = merged.filter(pl.col("trade_date") == "20240102")
assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1")
# 20240104 应延续使用 2023Q3 数据
jan04 = merged.filter(pl.col("trade_date") == "20240104")
assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
# 20240110 应延续使用 2023Q3 数据2024-04-30 还未公告)
jan10 = merged.filter(pl.col("trade_date") == "20240110")
assert jan10["net_profit"][0] == 1100000.0, "2024-01-10 应延续使用 2023Q3 数据"
print("[验证 4] 2024-01-10 net_profit=1100000 - 正确(延续使用 2023Q3")
# 20240112 应继续延续使用 2023Q3 数据
jan12 = merged.filter(pl.col("trade_date") == "20240112")
assert jan12["net_profit"][0] == 1100000.0, "2024-01-12 应继续使用 2023Q3 数据"
print("[验证 5] 2024-01-12 net_profit=1100000 - 正确(延续使用 2023Q3")
# 20240501 应切换到 2024Q1 数据2024-04-30 已公告,且选择 end_date 最新的)
may01 = merged.filter(pl.col("trade_date") == "20240501")
assert may01["net_profit"][0] == 1500000.0, "2024-05-01 应切换到 2024Q1 数据"
print(
"[验证 6] 2024-05-01 net_profit=1500000 - 正确(切换到 2024Q1end_date 最新)"
)
print("\n[通过] 所有验证通过,无未来函数!")
return merged
def test_empty_financial_data():
"""测试财务数据为空的情况。"""
print("\n=== 测试 3: 空财务数据场景 ===")
df_price = create_mock_price_data()
df_empty = pl.DataFrame()
loader = FinancialLoader()
# 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 拼接空财务数据
merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"])
# 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
# 验证财务列为空
assert merged["net_profit"].is_null().all(), (
"财务数据为空时net_profit 应全为 null"
)
print("空财务数据拼接结果:")
print(merged)
print("\n[通过] 空财务数据场景测试通过!")
def run_all_tests():
"""运行所有测试。"""
print("开始运行财务数据拼接功能测试...\n")
print("=" * 60)
try:
# 测试 1: 数据清洗
test_financial_data_cleaning()
# 测试 2: 数据拼接
test_financial_price_merge()
# 测试 3: 空数据场景
test_empty_financial_data()
print("\n" + "=" * 60)
print("所有测试通过!")
print("=" * 60)
except AssertionError as e:
print(f"\n[失败] 测试断言失败: {e}")
raise
except Exception as e:
print(f"\n[错误] 测试执行出错: {e}")
raise
if __name__ == "__main__":
run_all_tests()

View File

@@ -1,130 +0,0 @@
"""测试新增的时间序列函数和智能分发逻辑。"""
import pytest
import polars as pl
import numpy as np
from src.factors.dsl import Symbol, FunctionNode
from src.factors.translator import PolarsTranslator
def test_ts_sma_translate():
"""测试 ts_sma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sma", close, 10, 5)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_wma_translate():
"""测试 ts_wma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_wma", close, 20)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_sumac_translate():
"""测试 ts_sumac 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sumac", close)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_max_intelligent_dispatch():
"""测试 max_ 智能分发: int -> ts_max其他 -> element-wise max。"""
from src.factors.api import max_, close
# 正整数 -> ts_max
result = max_(close, 20)
assert result.func_name == "ts_max"
# 零或负数 -> element-wise max
result = max_(close, 0)
assert result.func_name == "max"
result = max_(close, -1)
assert result.func_name == "max"
# 浮点数 -> element-wise max
result = max_(close, 10.5)
assert result.func_name == "max"
def test_min_intelligent_dispatch():
"""测试 min_ 智能分发: int -> ts_min其他 -> element-wise min。"""
from src.factors.api import min_, close
# 正整数 -> ts_min
result = min_(close, 20)
assert result.func_name == "ts_min"
# 零或负数 -> element-wise min
result = min_(close, 0)
assert result.func_name == "min"
def create_test_data() -> pl.DataFrame:
"""创建测试数据。"""
np.random.seed(42)
n = 100
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * n,
"trade_date": list(range(20240101, 20240101 + n)),
"close": np.random.randn(n).cumsum() + 100,
}
)
def test_ts_sma_computation():
"""测试 ts_sma 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
# 翻译因子
close = Symbol("close")
expr_node = FunctionNode("ts_sma", close, 10, 5)
expr = translator.translate(expr_node)
# 使用翻译后的表达式计算
result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")])
# 原生 Polars 计算
native = df.with_columns(
[pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")]
)
# 对比结果
assert np.allclose(
result["ts_sma_result"].to_numpy()[9:],
native["native_sma"].to_numpy()[9:],
equal_nan=True,
)
def test_ts_sumac_computation():
"""测试 ts_sumac 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
close = Symbol("close")
expr_node = FunctionNode("ts_sumac", close)
expr = translator.translate(expr_node)
result = df.select(
["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")]
)
native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])
assert np.allclose(
result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy()
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,541 +0,0 @@
"""Phase 1-2 因子函数集成测试。
测试所有新实现的函数,使用字符串因子表达式形式计算因子,
并与原始 Polars 计算结果进行对比。
测试范围:
1. 数学函数atan, log1p
2. 统计函数ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
3. TA-Lib 函数ts_atr, ts_rsi, ts_obv
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FormulaParser, FunctionRegistry
from src.factors.translator import PolarsTranslator, HAS_TALIB
from src.factors.engine import FactorEngine
from src.data.catalog import DatabaseCatalog
# ============== 测试数据准备 ==============
def create_test_data() -> pl.DataFrame:
"""创建测试用的模拟数据。
创建一个包含多只股票、多个交易日的 DataFrame
用于测试因子函数的计算。
"""
np.random.seed(42)
dates = pl.date_range(
start=pl.date(2024, 1, 1),
end=pl.date(2024, 1, 31),
interval="1d",
eager=True,
)
stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"]
data = []
for stock in stocks:
base_price = 100 + np.random.randn() * 10
for i, date in enumerate(dates):
price = base_price + np.random.randn() * 5 + i * 0.1
data.append(
{
"ts_code": stock,
"trade_date": date,
"close": price,
"open": price * (1 + np.random.randn() * 0.01),
"high": price * (1 + abs(np.random.randn()) * 0.02),
"low": price * (1 - abs(np.random.randn()) * 0.02),
"vol": int(1000000 + np.random.randn() * 500000),
}
)
return pl.DataFrame(data)
# ============== 数学函数测试 ==============
def test_atan_function():
"""测试 atan 函数:计算反正切值。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
df = pl.DataFrame(
{
"ts_code": ["A"] * 5,
"trade_date": pl.date_range(
pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
),
"value": [0.0, 1.0, -1.0, 0.5, -0.5],
}
)
# DSL 计算
expr = parser.parse("atan(value)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(pl_result=pl.col("value").arctan()).to_pandas()[
"pl_result"
]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
def test_log1p_function():
"""测试 log1p 函数:计算 log(1+x)。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
df = pl.DataFrame(
{
"ts_code": ["A"] * 5,
"trade_date": pl.date_range(
pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
),
"value": [0.0, 0.1, -0.1, 1.0, -0.5],
}
)
# DSL 计算
expr = parser.parse("log1p(value)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(pl_result=pl.col("value").log1p()).to_pandas()[
"pl_result"
]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
# ============== 统计函数测试 ==============
def test_ts_var_function():
"""测试 ts_var 函数:滚动方差。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
df = pl.DataFrame(
{
"ts_code": ["A"] * 10 + ["B"] * 10,
"trade_date": pl.date_range(
pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True
).append(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
),
"close": list(range(1, 11)) + list(range(10, 20)),
}
)
# DSL 计算
expr = parser.parse("ts_var(close, 5)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = (
df.with_columns(dsl_result=polars_expr)
.to_pandas()
.groupby("ts_code")["dsl_result"]
.apply(list)
)
# 原始 Polars 计算
result_pl = (
df.with_columns(
pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code")
)
.to_pandas()
.groupby("ts_code")["pl_result"]
.apply(list)
)
# 对比结果
for stock in ["A", "B"]:
np.testing.assert_array_almost_equal(
result_dsl[stock], result_pl[stock], decimal=10
)
def test_ts_skew_function():
"""测试 ts_skew 函数:滚动偏度。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 20 + ["B"] * 20,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
)
* 2,
"close": np.random.randn(40),
}
)
# DSL 计算
expr = parser.parse("ts_skew(close, 10)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(
pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code")
).to_pandas()["pl_result"]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
def test_ts_kurt_function():
"""测试 ts_kurt 函数:滚动峰度。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 20 + ["B"] * 20,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
)
* 2,
"close": np.random.randn(40),
}
)
# DSL 计算
expr = parser.parse("ts_kurt(close, 10)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(
pl_result=pl.col("close")
.rolling_map(
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
window_size=10,
)
.over("ts_code")
).to_pandas()["pl_result"]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
def test_ts_pct_change_function():
"""测试 ts_pct_change 函数:百分比变化。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
df = pl.DataFrame(
{
"ts_code": ["A"] * 5 + ["B"] * 5,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True)
)
* 2,
"close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
}
)
# DSL 计算
expr = parser.parse("ts_pct_change(close, 1)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(
pl_result=(pl.col("close") - pl.col("close").shift(1))
/ pl.col("close").shift(1).over("ts_code")
).to_pandas()["pl_result"]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
def test_ts_ema_function():
"""测试 ts_ema 函数:指数移动平均。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
df = pl.DataFrame(
{
"ts_code": ["A"] * 10 + ["B"] * 10,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
)
* 2,
"close": list(range(1, 11)) + list(range(10, 20)),
}
)
# DSL 计算
expr = parser.parse("ts_ema(close, 5)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 原始 Polars 计算
result_pl = df.with_columns(
pl_result=pl.col("close").ewm_mean(span=5).over("ts_code")
).to_pandas()["pl_result"]
# 对比结果
np.testing.assert_array_almost_equal(
result_dsl.values, result_pl.values, decimal=10
)
# ============== TA-Lib 函数测试 ==============
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_atr_function():
"""测试 ts_atr 函数:平均真实波幅。"""
import talib
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 20 + ["B"] * 20,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
)
* 2,
"high": 100 + np.random.randn(40) * 2,
"low": 98 + np.random.randn(40) * 2,
"close": 99 + np.random.randn(40) * 2,
}
)
# DSL 计算
expr = parser.parse("ts_atr(high, low, close, 14)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 使用 talib 手动计算(分组计算)
result_expected = []
for stock in ["A", "B"]:
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
atr = talib.ATR(
stock_df["high"].values,
stock_df["low"].values,
stock_df["close"].values,
timeperiod=14,
)
result_expected.extend(atr)
# 对比结果(允许小误差)
np.testing.assert_array_almost_equal(
result_dsl.values, np.array(result_expected), decimal=5
)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_rsi_function():
"""测试 ts_rsi 函数:相对强弱指数。"""
import talib
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 30 + ["B"] * 30,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
)
* 2,
"close": 100 + np.cumsum(np.random.randn(60)),
}
)
# DSL 计算
expr = parser.parse("ts_rsi(close, 14)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 使用 talib 手动计算(分组计算)
result_expected = []
for stock in ["A", "B"]:
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
rsi = talib.RSI(stock_df["close"].values, timeperiod=14)
result_expected.extend(rsi)
# 对比结果(允许小误差)
np.testing.assert_array_almost_equal(
result_dsl.values, np.array(result_expected), decimal=5
)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_obv_function():
"""测试 ts_obv 函数:能量潮指标。"""
import talib
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 20 + ["B"] * 20,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
)
* 2,
"close": 100 + np.cumsum(np.random.randn(40)),
"vol": np.random.randint(100000, 1000000, 40).astype(float),
}
)
# DSL 计算
expr = parser.parse("ts_obv(close, vol)")
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
# 使用 talib 手动计算(分组计算)
result_expected = []
for stock in ["A", "B"]:
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
obv = talib.OBV(
stock_df["close"].values,
stock_df["vol"].values,
)
result_expected.extend(obv)
# 对比结果(允许小误差)
np.testing.assert_array_almost_equal(
result_dsl.values, np.array(result_expected), decimal=5
)
# ============== 综合测试 ==============
def test_complex_factor_expressions():
"""测试复杂因子表达式的计算。"""
parser = FormulaParser(FunctionRegistry())
# 创建测试数据
np.random.seed(42)
df = pl.DataFrame(
{
"ts_code": ["A"] * 30 + ["B"] * 30,
"trade_date": list(
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
)
* 2,
"close": 100 + np.cumsum(np.random.randn(60)),
}
)
# 测试 act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50
expr = parser.parse(
"atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
)
translator = PolarsTranslator()
polars_expr = translator.translate(expr)
result = df.with_columns(factor=polars_expr)
# 验证结果不为空
assert len(result) == 60
assert "factor" in result.columns
print("复杂因子表达式测试通过")
# ============== 主函数 ==============
if __name__ == "__main__":
print("运行 Phase 1-2 因子函数测试...")
print("=" * 80)
# 运行数学函数测试
print("\n[数学函数测试]")
test_atan_function()
print(" ✅ atan 测试通过")
test_log1p_function()
print(" ✅ log1p 测试通过")
# 运行统计函数测试
print("\n[统计函数测试]")
test_ts_var_function()
print(" ✅ ts_var 测试通过")
test_ts_skew_function()
print(" ✅ ts_skew 测试通过")
test_ts_kurt_function()
print(" ✅ ts_kurt 测试通过")
test_ts_pct_change_function()
print(" ✅ ts_pct_change 测试通过")
test_ts_ema_function()
print(" ✅ ts_ema 测试通过")
# 运行 TA-Lib 函数测试
print("\n[TA-Lib 函数测试]")
try:
import talib
HAS_TALIB = True
except ImportError:
HAS_TALIB = False
print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")
if HAS_TALIB:
test_ts_atr_function()
print(" ✅ ts_atr 测试通过")
test_ts_rsi_function()
print(" ✅ ts_rsi 测试通过")
test_ts_obv_function()
print(" ✅ ts_obv 测试通过")
# 运行综合测试
print("\n[综合测试]")
test_complex_factor_expressions()
print(" ✅ 复杂因子表达式测试通过")
print("\n" + "=" * 80)
print("所有测试通过!")

View File

@@ -1,421 +0,0 @@
"""Test for pro_bar (universal market) API.
Tests the pro_bar interface implementation:
- Backward-adjusted (后复权) data fetching
- All output fields including tor, vr, and adj_factor (default behavior)
- Multiple asset types support
- ProBarSync batch synchronization
"""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_pro_bar import (
get_pro_bar,
ProBarSync,
sync_pro_bar,
preview_pro_bar_sync,
)
# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
"ts_code", # 股票代码
"trade_date", # 交易日期
"open", # 开盘价
"high", # 最高价
"low", # 最低价
"close", # 收盘价
"pre_close", # 昨收价
"change", # 涨跌额
"pct_chg", # 涨跌幅
"vol", # 成交量
"amount", # 成交额
]
EXPECTED_FACTOR_FIELDS = [
"turnover_rate", # 换手率 (tor)
"volume_ratio", # 量比 (vr)
]
class TestGetProBar:
"""Test cases for get_pro_bar function."""
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_fetch_basic(self, mock_client_class):
"""Test basic pro_bar data fetch."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"open": [10.5],
"high": [11.0],
"low": [10.2],
"close": [10.8],
"pre_close": [10.5],
"change": [0.3],
"pct_chg": [2.86],
"vol": [100000.0],
"amount": [1080000.0],
}
)
# Test
result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")
# Assert
assert isinstance(result, pd.DataFrame)
assert not result.empty
assert result["ts_code"].iloc[0] == "000001.SZ"
mock_client.query.assert_called_once()
# Verify pro_bar API is called
call_args = mock_client.query.call_args
assert call_args[0][0] == "pro_bar"
assert call_args[1]["ts_code"] == "000001.SZ"
# Default should use hfq (backward-adjusted)
assert call_args[1]["adj"] == "hfq"
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_default_backward_adjusted(self, mock_client_class):
"""Test that default adjustment is backward (hfq)."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [100.5],
}
)
result = get_pro_bar("000001.SZ")
call_args = mock_client.query.call_args
assert call_args[1]["adj"] == "hfq"
assert call_args[1]["adjfactor"] == "True"
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_default_factors_all_fields(self, mock_client_class):
"""Test that default factors includes tor and vr."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
"turnover_rate": [2.5],
"volume_ratio": [1.2],
"adj_factor": [1.05],
}
)
result = get_pro_bar("000001.SZ")
call_args = mock_client.query.call_args
# Default should include both tor and vr
assert call_args[1]["factors"] == "tor,vr"
assert "turnover_rate" in result.columns
assert "volume_ratio" in result.columns
assert "adj_factor" in result.columns
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_fetch_with_custom_factors(self, mock_client_class):
"""Test fetch with custom factors."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
"turnover_rate": [2.5],
}
)
# Only request tor
result = get_pro_bar(
"000001.SZ",
start_date="20240101",
end_date="20240131",
factors=["tor"],
)
call_args = mock_client.query.call_args
assert call_args[1]["factors"] == "tor"
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_fetch_with_no_factors(self, mock_client_class):
"""Test fetch with no factors (empty list)."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
}
)
# Explicitly set factors to empty list
result = get_pro_bar(
"000001.SZ",
start_date="20240101",
end_date="20240131",
factors=[],
)
call_args = mock_client.query.call_args
# Should not include factors parameter
assert "factors" not in call_args[1]
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_fetch_with_ma(self, mock_client_class):
"""Test fetch with moving averages."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
"ma_5": [10.5],
"ma_10": [10.3],
"ma_v_5": [95000.0],
}
)
result = get_pro_bar(
"000001.SZ",
start_date="20240101",
end_date="20240131",
ma=[5, 10],
)
call_args = mock_client.query.call_args
assert call_args[1]["ma"] == "5,10"
assert "ma_5" in result.columns
assert "ma_10" in result.columns
assert "ma_v_5" in result.columns
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_fetch_index_data(self, mock_client_class):
"""Test fetching index data."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SH"],
"trade_date": ["20240115"],
"close": [2900.5],
}
)
result = get_pro_bar(
"000001.SH",
asset="I",
start_date="20240101",
end_date="20240131",
)
call_args = mock_client.query.call_args
assert call_args[1]["asset"] == "I"
assert call_args[1]["ts_code"] == "000001.SH"
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_forward_adjustment(self, mock_client_class):
"""Test forward adjustment (qfq)."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
}
)
result = get_pro_bar("000001.SZ", adj="qfq")
call_args = mock_client.query.call_args
assert call_args[1]["adj"] == "qfq"
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_no_adjustment(self, mock_client_class):
"""Test no adjustment."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"close": [10.8],
}
)
result = get_pro_bar("000001.SZ", adj=None)
call_args = mock_client.query.call_args
assert "adj" not in call_args[1]
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_empty_response(self, mock_client_class):
"""Test handling empty response."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame()
result = get_pro_bar("INVALID.SZ")
assert isinstance(result, pd.DataFrame)
assert result.empty
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
def test_date_column_rename(self, mock_client_class):
"""Test that 'date' column is renamed to 'trade_date'."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"date": ["20240115"], # API returns 'date' instead of 'trade_date'
"close": [10.8],
}
)
result = get_pro_bar("000001.SZ")
assert "trade_date" in result.columns
assert "date" not in result.columns
assert result["trade_date"].iloc[0] == "20240115"
class TestProBarSync:
"""Test cases for ProBarSync class."""
@patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
@patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
@patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
"""Test getting all stock codes."""
from pathlib import Path
from unittest.mock import MagicMock
# Create a mock path that exists
mock_path = MagicMock(spec=Path)
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
mock_read_csv.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "600000.SH"],
"list_status": ["L", "L"],
}
)
sync = ProBarSync()
codes = sync.get_all_stock_codes()
assert len(codes) == 2
assert "000001.SZ" in codes
assert "600000.SH" in codes
@patch("src.data.api_wrappers.api_pro_bar.Storage")
def test_check_sync_needed_force_full(self, mock_storage_class):
"""Test check_sync_needed with force_full=True."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.exists.return_value = False
sync = ProBarSync()
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
assert needed is True
assert start == "20180101" # DEFAULT_START_DATE
assert local_last is None
@patch("src.data.api_wrappers.api_pro_bar.Storage")
def test_check_sync_needed_force_full(self, mock_storage_class):
"""Test check_sync_needed with force_full=True."""
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.exists.return_value = False
sync = ProBarSync()
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
assert needed is True
assert start == "20180101" # DEFAULT_START_DATE
assert local_last is None
class TestSyncProBar:
"""Test cases for sync_pro_bar function."""
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
def test_sync_pro_bar(self, mock_sync_class):
"""Test sync_pro_bar function."""
mock_sync = MagicMock()
mock_sync_class.return_value = mock_sync
mock_sync.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})}
result = sync_pro_bar(force_full=True, max_workers=5)
mock_sync_class.assert_called_once_with(max_workers=5)
mock_sync.sync_all.assert_called_once()
assert "000001.SZ" in result
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
def test_preview_pro_bar_sync(self, mock_sync_class):
"""Test preview_pro_bar_sync function."""
mock_sync = MagicMock()
mock_sync_class.return_value = mock_sync
mock_sync.preview_sync.return_value = {
"sync_needed": True,
"stock_count": 5000,
"mode": "full",
}
result = preview_pro_bar_sync(force_full=True)
mock_sync_class.assert_called_once_with()
mock_sync.preview_sync.assert_called_once()
assert result["sync_needed"] is True
assert result["stock_count"] == 5000
class TestProBarIntegration:
"""Integration tests with real Tushare API."""
def test_real_api_call(self):
"""Test with real API (requires valid token)."""
import os
token = os.environ.get("TUSHARE_TOKEN")
if not token:
pytest.skip("TUSHARE_TOKEN not configured")
result = get_pro_bar(
"000001.SZ",
start_date="20240101",
end_date="20240131",
)
# Verify structure
assert isinstance(result, pd.DataFrame)
if not result.empty:
# Check base fields
for field in EXPECTED_BASE_FIELDS:
assert field in result.columns, f"Missing base field: {field}"
# Check factor fields (should be present by default)
for field in EXPECTED_FACTOR_FIELDS:
assert field in result.columns, f"Missing factor field: {field}"
# Check adj_factor is present (default behavior)
assert "adj_factor" in result.columns
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,246 +0,0 @@
"""Tests for stock limit price API wrapper."""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_stk_limit import (
get_stk_limit,
sync_stk_limit,
preview_stk_limit_sync,
StkLimitSync,
)
class TestStkLimit:
"""Test suite for stk_limit API wrapper."""
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
def test_get_by_date(self, mock_client_class):
"""Test fetching data by trade_date."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000002.SZ"],
"trade_date": ["20240625", "20240625"],
"pre_close": [10.0, 20.0],
"up_limit": [11.0, 22.0],
"down_limit": [9.0, 18.0],
}
)
# Test
result = get_stk_limit(trade_date="20240625")
# Assert
assert not result.empty
assert len(result) == 2
assert "ts_code" in result.columns
assert "trade_date" in result.columns
assert "up_limit" in result.columns
assert "down_limit" in result.columns
mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625")
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
def test_get_by_date_range(self, mock_client_class):
"""Test fetching data by date range."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ"],
"trade_date": ["20240624", "20240625"],
"pre_close": [10.0, 10.5],
"up_limit": [11.0, 11.55],
"down_limit": [9.0, 9.45],
}
)
# Test
result = get_stk_limit(start_date="20240624", end_date="20240625")
# Assert
assert not result.empty
assert len(result) == 2
mock_client.query.assert_called_once_with(
"stk_limit", start_date="20240624", end_date="20240625"
)
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
def test_get_by_stock_code(self, mock_client_class):
"""Test fetching data by stock code."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240625"],
"pre_close": [10.0],
"up_limit": [11.0],
"down_limit": [9.0],
}
)
# Test
result = get_stk_limit(ts_code="000001.SZ", trade_date="20240625")
# Assert
assert not result.empty
assert len(result) == 1
assert result.iloc[0]["ts_code"] == "000001.SZ"
mock_client.query.assert_called_once_with(
"stk_limit", trade_date="20240625", ts_code="000001.SZ"
)
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
def test_empty_response(self, mock_client_class):
"""Test handling empty response."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame()
# Test
result = get_stk_limit(trade_date="20240625")
# Assert
assert result.empty
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
def test_shared_client(self, mock_client_class):
"""Test passing shared client for rate limiting."""
# Setup mock
shared_client = MagicMock()
shared_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240625"],
"pre_close": [10.0],
"up_limit": [11.0],
"down_limit": [9.0],
}
)
# Test
result = get_stk_limit(trade_date="20240625", client=shared_client)
# Assert
assert not result.empty
shared_client.query.assert_called_once()
# Verify new client was not created
mock_client_class.assert_not_called()
class TestStkLimitSync:
"""Test suite for StkLimitSync class."""
@patch("src.data.api_wrappers.api_stk_limit.TushareClient")
@patch("src.data.api_wrappers.base_sync.Storage")
@patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache")
def test_fetch_single_date(
self, mock_sync_cal, mock_storage_class, mock_client_class
):
"""Test fetch_single_date method."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000002.SZ"],
"trade_date": ["20240625", "20240625"],
"pre_close": [10.0, 20.0],
"up_limit": [11.0, 22.0],
"down_limit": [9.0, 18.0],
}
)
mock_storage = MagicMock()
mock_storage_class.return_value = mock_storage
mock_storage.exists.return_value = True
mock_storage.load.return_value = pd.DataFrame()
# Test
sync = StkLimitSync()
result = sync.fetch_single_date("20240625")
# Assert
assert not result.empty
assert len(result) == 2
mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625")
def test_table_schema(self):
"""Test table schema definition."""
sync = StkLimitSync()
# Assert table configuration
assert sync.table_name == "stk_limit"
assert "ts_code" in sync.TABLE_SCHEMA
assert "trade_date" in sync.TABLE_SCHEMA
assert "pre_close" in sync.TABLE_SCHEMA
assert "up_limit" in sync.TABLE_SCHEMA
assert "down_limit" in sync.TABLE_SCHEMA
assert sync.PRIMARY_KEY == ("ts_code", "trade_date")
class TestSyncFunctions:
"""Test suite for sync convenience functions."""
@patch.object(StkLimitSync, "sync_all")
def test_sync_stk_limit(self, mock_sync_all):
"""Test sync_stk_limit convenience function."""
# Setup mock
mock_sync_all.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240625"],
"up_limit": [11.0],
"down_limit": [9.0],
}
)
# Test
result = sync_stk_limit(force_full=True)
# Assert
assert not result.empty
mock_sync_all.assert_called_once_with(
force_full=True,
start_date=None,
end_date=None,
dry_run=False,
)
@patch.object(StkLimitSync, "preview_sync")
def test_preview_stk_limit_sync(self, mock_preview):
"""Test preview_stk_limit_sync convenience function."""
# Setup mock
mock_preview.return_value = {
"sync_needed": True,
"date_count": 10,
"start_date": "20240601",
"end_date": "20240610",
"estimated_records": 5000,
"sample_data": pd.DataFrame(),
"mode": "incremental",
}
# Test
result = preview_stk_limit_sync()
# Assert
assert result["sync_needed"] is True
assert result["mode"] == "incremental"
mock_preview.assert_called_once_with(
force_full=False,
start_date=None,
end_date=None,
sample_size=3,
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,143 +0,0 @@
"""Test suite for stock_st API wrapper."""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync
class TestStockST:
"""Test suite for stock_st API wrapper."""
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_date(self, mock_client_class):
"""Test fetching ST stock list by date."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ", "605081.SH", "300391.SZ"],
"name": ["*ST天山", "*ST太和", "*ST长药"],
"trade_date": ["20240101", "20240101", "20240101"],
"type": ["ST", "ST", "ST"],
"type_name": ["风险警示板", "风险警示板", "风险警示板"],
}
)
# Test
result = get_stock_st(trade_date="20240101")
# Assert
assert not result.empty
assert len(result) == 3
assert "ts_code" in result.columns
assert "name" in result.columns
assert "trade_date" in result.columns
assert "type" in result.columns
assert "type_name" in result.columns
mock_client.query.assert_called_once()
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_stock(self, mock_client_class):
"""Test fetching ST history by stock code."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ", "300313.SZ"],
"name": ["*ST天山", "*ST天山"],
"trade_date": ["20240101", "20240102"],
"type": ["ST", "ST"],
"type_name": ["风险警示板", "风险警示板"],
}
)
# Test
result = get_stock_st(
ts_code="300313.SZ", start_date="20240101", end_date="20240102"
)
# Assert
assert not result.empty
assert len(result) == 2
mock_client.query.assert_called_once()
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_empty_response(self, mock_client_class):
"""Test handling empty response."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame()
result = get_stock_st(trade_date="20240101")
assert result.empty
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_date_range(self, mock_client_class):
"""Test fetching ST stock list by date range."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ"],
"name": ["*ST天山"],
"trade_date": ["20240101"],
"type": ["ST"],
"type_name": ["风险警示板"],
}
)
# Test
result = get_stock_st(start_date="20240101", end_date="20240131")
# Assert
assert not result.empty
mock_client.query.assert_called_once()
class TestStockSTSync:
"""Test suite for StockSTSync class."""
def test_sync_class_attributes(self):
"""Test that sync class has correct attributes."""
sync = StockSTSync()
assert sync.table_name == "stock_st"
assert sync.default_start_date == "20160101"
assert "ts_code" in sync.TABLE_SCHEMA
assert "trade_date" in sync.TABLE_SCHEMA
assert "name" in sync.TABLE_SCHEMA
assert "type" in sync.TABLE_SCHEMA
assert "type_name" in sync.TABLE_SCHEMA
assert sync.PRIMARY_KEY == ("trade_date", "ts_code")
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_fetch_single_date(self, mock_client_class):
"""Test fetching single date data."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ"],
"name": ["*ST天山"],
"trade_date": ["20240101"],
"type": ["ST"],
"type_name": ["风险警示板"],
}
)
# Test
sync = StockSTSync()
result = sync.fetch_single_date("20240101")
# Assert
assert not result.empty
assert len(result) == 1
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,164 +0,0 @@
"""Sync 接口测试规范与实现。
【测试规范】
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
2. 只测试接口是否能正常返回数据,不测试落库逻辑
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
4. 使用真实 API 调用,确保接口可用性
【测试范围】
- get_daily: 日线数据接口(按股票)
- sync_all_stocks: 股票基础信息接口
- sync_trade_cal_cache: 交易日历接口
- sync_namechange: 名称变更接口
- sync_bak_basic: 备用股票基础信息接口
"""
import pytest
import pandas as pd
from datetime import datetime
# 测试用常量
TEST_START_DATE = "20180101"
TEST_END_DATE = "20180401"
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
class TestGetDaily:
"""测试日线数据 get 接口(按股票查询)."""
def test_get_daily_single_stock(self):
"""测试 get_daily 获取单只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
assert not result.empty, "get_daily 应返回非空数据"
def test_get_daily_has_required_columns(self):
"""测试 get_daily 返回的数据包含必要字段."""
from src.data.api_wrappers.api_daily import get_daily
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证必要的列存在
required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
for col in required_columns:
assert col in result.columns, f"get_daily 返回应包含 {col}"
def test_get_daily_multiple_stocks(self):
"""测试 get_daily 获取多只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
results = {}
for code in TEST_STOCK_CODES:
result = get_daily(
ts_code=code,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
results[code] = result
assert isinstance(result, pd.DataFrame), (
f"get_daily({code}) 应返回 DataFrame"
)
assert not result.empty, f"get_daily({code}) 应返回非空数据"
class TestSyncStockBasic:
"""测试股票基础信息 sync 接口."""
def test_sync_all_stocks_returns_data(self):
"""测试 sync_all_stocks 是否能正常返回数据."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
result = sync_all_stocks()
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
assert not result.empty, "sync_all_stocks 应返回非空数据"
def test_sync_all_stocks_has_required_columns(self):
"""测试 sync_all_stocks 返回的数据包含必要字段."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
result = sync_all_stocks()
# 验证必要的列存在
required_columns = ["ts_code"]
for col in required_columns:
assert col in result.columns, f"sync_all_stocks 返回应包含 {col}"
class TestSyncTradeCal:
"""测试交易日历 sync 接口."""
def test_sync_trade_cal_cache_returns_data(self):
"""测试 sync_trade_cal_cache 是否能正常返回数据."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
assert not result.empty, "sync_trade_cal_cache 应返回非空数据"
def test_sync_trade_cal_cache_has_required_columns(self):
"""测试 sync_trade_cal_cache 返回的数据包含必要字段."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证必要的列存在
required_columns = ["cal_date", "is_open"]
for col in required_columns:
assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col}"
class TestSyncNamechange:
"""测试名称变更 sync 接口."""
def test_sync_namechange_returns_data(self):
"""测试 sync_namechange 是否能正常返回数据."""
from src.data.api_wrappers.api_namechange import sync_namechange
result = sync_namechange()
# 验证返回了数据(可能是空 DataFrame因为是历史变更
assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
class TestSyncBakBasic:
"""测试备用股票基础信息 sync 接口."""
def test_sync_bak_basic_returns_data(self):
"""测试 sync_bak_basic 是否能正常返回数据."""
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
result = sync_bak_basic(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
# 注意bak_basic 可能返回空数据,这是正常的
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,20 +0,0 @@
"""Tushare API 验证脚本 - 快速生成 pro 对象用于调试。"""
import os
os.environ.setdefault("DATA_PATH", "data")
from src.data.config import get_config
import tushare as ts
config = get_config()
token = config.tushare_token
if not token:
raise ValueError("请在 config/.env.local 中配置 TUSHARE_TOKEN")
pro = ts.pro_api(token)
print(f"pro_api 对象已创建token: {token[:10]}...")
df = pro.query('daily', ts_code='000001.SZ', start_date='20180702', end_date='20180718')
print(df)