feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合 - 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等) - 新增 factors/compiler.py: 因子编译器 - 新增 factors/translator.py: DSL表达式翻译器 - 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据 - 新增 data/data_router.py: 数据路由功能 - 新增相关测试用例
This commit is contained in:
325
tests/factors/test_dsl_promotion.py
Normal file
325
tests/factors/test_dsl_promotion.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""测试 DSL 字符串自动提升(Promotion)功能。
|
||||
|
||||
验证以下功能:
|
||||
1. 字符串自动转换为 Symbol
|
||||
2. 算子函数支持字符串参数
|
||||
3. 右位运算支持
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.factors.dsl import (
|
||||
Symbol,
|
||||
Constant,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
FunctionNode,
|
||||
_ensure_node,
|
||||
)
|
||||
from src.factors.api import (
|
||||
close,
|
||||
open,
|
||||
ts_mean,
|
||||
ts_std,
|
||||
ts_corr,
|
||||
cs_rank,
|
||||
cs_zscore,
|
||||
log,
|
||||
exp,
|
||||
max_,
|
||||
min_,
|
||||
clip,
|
||||
if_,
|
||||
where,
|
||||
)
|
||||
|
||||
|
||||
class TestEnsureNode:
|
||||
"""测试 _ensure_node 辅助函数。"""
|
||||
|
||||
def test_ensure_node_with_node(self):
|
||||
"""Node 类型应该原样返回。"""
|
||||
sym = Symbol("close")
|
||||
result = _ensure_node(sym)
|
||||
assert result is sym
|
||||
|
||||
def test_ensure_node_with_int(self):
|
||||
"""整数应该转换为 Constant。"""
|
||||
result = _ensure_node(100)
|
||||
assert isinstance(result, Constant)
|
||||
assert result.value == 100
|
||||
|
||||
def test_ensure_node_with_float(self):
|
||||
"""浮点数应该转换为 Constant。"""
|
||||
result = _ensure_node(3.14)
|
||||
assert isinstance(result, Constant)
|
||||
assert result.value == 3.14
|
||||
|
||||
def test_ensure_node_with_str(self):
|
||||
"""字符串应该转换为 Symbol。"""
|
||||
result = _ensure_node("close")
|
||||
assert isinstance(result, Symbol)
|
||||
assert result.name == "close"
|
||||
|
||||
def test_ensure_node_with_invalid_type(self):
|
||||
"""无效类型应该抛出 TypeError。"""
|
||||
with pytest.raises(TypeError):
|
||||
_ensure_node([1, 2, 3])
|
||||
|
||||
|
||||
class TestSymbolStringPromotion:
|
||||
"""测试 Symbol 与字符串的运算。"""
|
||||
|
||||
def test_symbol_add_str(self):
|
||||
"""Symbol + 字符串。"""
|
||||
expr = close + "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "+"
|
||||
assert isinstance(expr.left, Symbol)
|
||||
assert expr.left.name == "close"
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_symbol_sub_str(self):
|
||||
"""Symbol - 字符串。"""
|
||||
expr = close - "open"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "-"
|
||||
assert expr.right.name == "open"
|
||||
|
||||
def test_symbol_mul_str(self):
|
||||
"""Symbol * 字符串。"""
|
||||
expr = close * "volume"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert expr.right.name == "volume"
|
||||
|
||||
def test_symbol_div_str(self):
|
||||
"""Symbol / 字符串。"""
|
||||
expr = close / "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_symbol_pow_str(self):
|
||||
"""Symbol ** 字符串。"""
|
||||
expr = close ** "exponent"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "**"
|
||||
assert expr.right.name == "exponent"
|
||||
|
||||
|
||||
class TestRightHandOperations:
|
||||
"""测试右位运算。"""
|
||||
|
||||
def test_int_add_symbol(self):
|
||||
"""整数 + Symbol。"""
|
||||
expr = 100 + close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "+"
|
||||
assert isinstance(expr.left, Constant)
|
||||
assert expr.left.value == 100
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_sub_symbol(self):
|
||||
"""整数 - Symbol。"""
|
||||
expr = 100 - close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "-"
|
||||
assert expr.left.value == 100
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_mul_symbol(self):
|
||||
"""整数 * Symbol。"""
|
||||
expr = 2 * close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert expr.left.value == 2
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_div_symbol(self):
|
||||
"""整数 / Symbol。"""
|
||||
expr = 100 / close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert expr.left.value == 100
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_div_str_not_supported(self):
|
||||
"""Python 内置 int 不支持直接与 str 进行除法运算。
|
||||
|
||||
注意:Python 内置的 int 类型不支持直接与 str 进行除法运算,
|
||||
所以 100 / "close" 会抛出 TypeError。正确的用法是 100 / Symbol("close") 或
|
||||
使用已有的 Symbol 对象如 close。
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
100 / "close"
|
||||
def test_int_floordiv_symbol(self):
|
||||
"""整数 // Symbol。"""
|
||||
expr = 100 // close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "//"
|
||||
|
||||
def test_int_mod_symbol(self):
|
||||
"""整数 % Symbol。"""
|
||||
expr = 100 % close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "%"
|
||||
|
||||
def test_int_pow_symbol(self):
|
||||
"""整数 ** Symbol。"""
|
||||
expr = 2**close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "**"
|
||||
assert expr.left.value == 2
|
||||
assert expr.right.name == "close"
|
||||
|
||||
|
||||
class TestOperatorFunctionsWithStrings:
|
||||
"""测试算子函数支持字符串参数。"""
|
||||
|
||||
def test_ts_mean_with_str(self):
|
||||
"""ts_mean 支持字符串参数。"""
|
||||
expr = ts_mean("close", 20)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_mean"
|
||||
assert len(expr.args) == 2
|
||||
assert isinstance(expr.args[0], Symbol)
|
||||
assert expr.args[0].name == "close"
|
||||
assert isinstance(expr.args[1], Constant)
|
||||
assert expr.args[1].value == 20
|
||||
|
||||
def test_ts_std_with_str(self):
|
||||
"""ts_std 支持字符串参数。"""
|
||||
expr = ts_std("volume", 10)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_std"
|
||||
assert expr.args[0].name == "volume"
|
||||
|
||||
def test_ts_corr_with_str(self):
|
||||
"""ts_corr 支持字符串参数。"""
|
||||
expr = ts_corr("close", "open", 20)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_corr"
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].name == "open"
|
||||
|
||||
def test_cs_rank_with_str(self):
|
||||
"""cs_rank 支持字符串参数。"""
|
||||
expr = cs_rank("pe_ratio")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "cs_rank"
|
||||
assert expr.args[0].name == "pe_ratio"
|
||||
|
||||
def test_cs_zscore_with_str(self):
|
||||
"""cs_zscore 支持字符串参数。"""
|
||||
expr = cs_zscore("market_cap")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "cs_zscore"
|
||||
assert expr.args[0].name == "market_cap"
|
||||
|
||||
def test_log_with_str(self):
|
||||
"""log 支持字符串参数。"""
|
||||
expr = log("close")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "log"
|
||||
assert expr.args[0].name == "close"
|
||||
|
||||
def test_max_with_str(self):
|
||||
"""max_ 支持字符串参数。"""
|
||||
expr = max_("close", "open")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "max"
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].name == "open"
|
||||
|
||||
def test_max_with_str_and_number(self):
|
||||
"""max_ 支持字符串和数值混合。"""
|
||||
expr = max_("close", 100)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].value == 100
|
||||
|
||||
def test_clip_with_str(self):
|
||||
"""clip 支持字符串参数。"""
|
||||
expr = clip("pe_ratio", "lower_bound", "upper_bound")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "clip"
|
||||
assert expr.args[0].name == "pe_ratio"
|
||||
assert expr.args[1].name == "lower_bound"
|
||||
assert expr.args[2].name == "upper_bound"
|
||||
|
||||
def test_if_with_str(self):
|
||||
"""if_ 支持字符串参数。"""
|
||||
expr = if_("condition", "true_val", "false_val")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "if"
|
||||
assert expr.args[0].name == "condition"
|
||||
assert expr.args[1].name == "true_val"
|
||||
assert expr.args[2].name == "false_val"
|
||||
|
||||
|
||||
class TestComplexExpressions:
|
||||
"""测试复杂表达式。"""
|
||||
|
||||
def test_complex_expression_1(self):
|
||||
"""复杂表达式:ts_mean("close", 5) / "pe_ratio"。"""
|
||||
expr = ts_mean("close", 5) / "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert isinstance(expr.left, FunctionNode)
|
||||
assert expr.left.func_name == "ts_mean"
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_complex_expression_2(self):
|
||||
"""复杂表达式:100 / close * cs_rank("volume") 。
|
||||
|
||||
注意:Python 内置的 int 类型不支持直接与 str 进行除法运算,
|
||||
所以需要使用已有的 Symbol 对象或先创建 Symbol。
|
||||
"""
|
||||
expr = 100 / close * cs_rank("volume")
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert isinstance(expr.left, BinaryOpNode)
|
||||
assert expr.left.op == "/"
|
||||
assert isinstance(expr.right, FunctionNode)
|
||||
assert expr.right.func_name == "cs_rank"
|
||||
def test_complex_expression_3(self):
|
||||
"""复杂表达式:ts_mean(close - "open", 20) / close。"""
|
||||
expr = ts_mean(close - "open", 20) / close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert isinstance(expr.left, FunctionNode)
|
||||
assert expr.left.func_name == "ts_mean"
|
||||
# 检查 ts_mean 的第一个参数是 close - open
|
||||
assert isinstance(expr.left.args[0], BinaryOpNode)
|
||||
assert expr.left.args[0].op == "-"
|
||||
|
||||
|
||||
class TestExpressionRepr:
|
||||
"""测试表达式字符串表示。"""
|
||||
|
||||
def test_symbol_str_repr(self):
|
||||
"""Symbol 的字符串表示。"""
|
||||
expr = Symbol("close")
|
||||
assert repr(expr) == "close"
|
||||
|
||||
def test_binary_op_repr(self):
|
||||
"""二元运算的字符串表示。"""
|
||||
expr = close + "open"
|
||||
assert repr(expr) == "(close + open)"
|
||||
|
||||
def test_function_node_repr(self):
|
||||
"""函数节点的字符串表示。"""
|
||||
expr = ts_mean("close", 20)
|
||||
assert repr(expr) == "ts_mean(close, 20)"
|
||||
|
||||
def test_complex_expr_repr(self):
|
||||
"""复杂表达式的字符串表示。"""
|
||||
expr = ts_mean("close", 5) / "pe_ratio"
|
||||
assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
451
tests/test_factor_integration.py
Normal file
451
tests/test_factor_integration.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""因子框架集成测试脚本
|
||||
|
||||
测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑
|
||||
|
||||
测试范围:
|
||||
1. 时序因子 ts_mean - 验证滑动窗口和数据隔离
|
||||
2. 截面因子 cs_rank - 验证每日独立排名和结果分布
|
||||
3. 组合运算 - 验证多字段算术运算和算子嵌套
|
||||
|
||||
排除范围:PIT 因子(使用低频财务数据)
|
||||
"""
|
||||
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.data.data_router import DatabaseCatalog
|
||||
from src.factors.engine import FactorEngine
|
||||
from src.factors.api import close, open, ts_mean, cs_rank
|
||||
|
||||
|
||||
def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list:
|
||||
"""随机选择代表性股票样本。
|
||||
|
||||
确保样本覆盖不同交易所:
|
||||
- .SH: 上海证券交易所(主板、科创板)
|
||||
- .SZ: 深圳证券交易所(主板、创业板)
|
||||
|
||||
Args:
|
||||
catalog: 数据库目录实例
|
||||
n: 需要选择的股票数量
|
||||
|
||||
Returns:
|
||||
股票代码列表
|
||||
"""
|
||||
# 从 catalog 获取数据库连接
|
||||
db_path = catalog.db_path.replace("duckdb://", "").lstrip("/")
|
||||
import duckdb
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
|
||||
try:
|
||||
# 获取2023年上半年的所有股票
|
||||
result = conn.execute("""
|
||||
SELECT DISTINCT ts_code
|
||||
FROM daily
|
||||
WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30'
|
||||
""").fetchall()
|
||||
|
||||
all_stocks = [row[0] for row in result]
|
||||
|
||||
# 按交易所分类
|
||||
sh_stocks = [s for s in all_stocks if s.endswith(".SH")]
|
||||
sz_stocks = [s for s in all_stocks if s.endswith(".SZ")]
|
||||
|
||||
# 选择样本:确保覆盖两个交易所
|
||||
sample = []
|
||||
|
||||
# 从上海市场选择 (包含主板600/601/603/605和科创板688)
|
||||
sh_main = [
|
||||
s for s in sh_stocks if s.startswith("6") and not s.startswith("688")
|
||||
]
|
||||
sh_kcb = [s for s in sh_stocks if s.startswith("688")]
|
||||
|
||||
# 从深圳市场选择 (包含主板000/001/002和创业板300/301)
|
||||
sz_main = [s for s in sz_stocks if s.startswith("0")]
|
||||
sz_cyb = [s for s in sz_stocks if s.startswith("300") or s.startswith("301")]
|
||||
|
||||
# 每类选择部分股票
|
||||
if sh_main:
|
||||
sample.extend(random.sample(sh_main, min(2, len(sh_main))))
|
||||
if sh_kcb:
|
||||
sample.extend(random.sample(sh_kcb, min(2, len(sh_kcb))))
|
||||
if sz_main:
|
||||
sample.extend(random.sample(sz_main, min(2, len(sz_main))))
|
||||
if sz_cyb:
|
||||
sample.extend(random.sample(sz_cyb, min(2, len(sz_cyb))))
|
||||
|
||||
# 如果还不够,随机补充
|
||||
while len(sample) < n and len(sample) < len(all_stocks):
|
||||
remaining = [s for s in all_stocks if s not in sample]
|
||||
if remaining:
|
||||
sample.append(random.choice(remaining))
|
||||
else:
|
||||
break
|
||||
|
||||
return sorted(sample[:n])
|
||||
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def run_factor_integration_test():
|
||||
"""执行因子框架集成测试。"""
|
||||
|
||||
print("=" * 80)
|
||||
print("因子框架集成测试 - DuckDB 真实数据验证")
|
||||
print("=" * 80)
|
||||
|
||||
# =========================================================================
|
||||
# 1. 测试环境准备
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("1. 测试环境准备")
|
||||
print("=" * 80)
|
||||
|
||||
# 数据库配置
|
||||
db_path = "data/prostock.db"
|
||||
db_uri = f"duckdb:///{db_path}"
|
||||
|
||||
print(f"\n数据库路径: {db_path}")
|
||||
print(f"数据库URI: {db_uri}")
|
||||
|
||||
# 时间范围
|
||||
start_date = "20230101"
|
||||
end_date = "20230630"
|
||||
print(f"\n测试时间范围: {start_date} 至 {end_date}")
|
||||
|
||||
# 创建 DatabaseCatalog 并发现表结构
|
||||
print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...")
|
||||
catalog = DatabaseCatalog(db_path)
|
||||
print(f"发现表数量: {len(catalog.tables)}")
|
||||
for table_name, metadata in catalog.tables.items():
|
||||
print(
|
||||
f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})"
|
||||
)
|
||||
|
||||
# 选择样本股票
|
||||
print("\n[1.2] 选择样本股票...")
|
||||
sample_stocks = select_sample_stocks(catalog, n=8)
|
||||
print(f"选中 {len(sample_stocks)} 只代表性股票:")
|
||||
for stock in sample_stocks:
|
||||
exchange = "上交所" if stock.endswith(".SH") else "深交所"
|
||||
board = ""
|
||||
if stock.startswith("688"):
|
||||
board = "科创板"
|
||||
elif (
|
||||
stock.startswith("600")
|
||||
or stock.startswith("601")
|
||||
or stock.startswith("603")
|
||||
):
|
||||
board = "主板"
|
||||
elif stock.startswith("300") or stock.startswith("301"):
|
||||
board = "创业板"
|
||||
elif (
|
||||
stock.startswith("000")
|
||||
or stock.startswith("001")
|
||||
or stock.startswith("002")
|
||||
):
|
||||
board = "主板"
|
||||
print(f" - {stock} ({exchange} {board})")
|
||||
|
||||
# =========================================================================
|
||||
# 2. 因子定义
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("2. 因子定义")
|
||||
print("=" * 80)
|
||||
|
||||
# 创建 FactorEngine
|
||||
print("\n[2.1] 创建 FactorEngine...")
|
||||
engine = FactorEngine(catalog)
|
||||
|
||||
# 因子 A: 时序均线 ts_mean(close, 10)
|
||||
print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)")
|
||||
print(" 验证重点: 10日滑动窗口是否正确;是否存在'数据串户'")
|
||||
factor_a = ts_mean(close, 10)
|
||||
engine.add_factor("factor_a_ts_mean_10", factor_a)
|
||||
print(f" AST: {factor_a}")
|
||||
|
||||
# 因子 B: 截面排名 cs_rank(close)
|
||||
print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)")
|
||||
print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间")
|
||||
factor_b = cs_rank(close)
|
||||
engine.add_factor("factor_b_cs_rank", factor_b)
|
||||
print(f" AST: {factor_b}")
|
||||
|
||||
# 因子 C: 组合运算 ts_mean(close, 5) / open
|
||||
print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open")
|
||||
print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性")
|
||||
factor_c = ts_mean(close, 5) / open
|
||||
engine.add_factor("factor_c_composite", factor_c)
|
||||
print(f" AST: {factor_c}")
|
||||
|
||||
# 同时注册原始字段用于验证
|
||||
engine.add_factor("close_price", close)
|
||||
engine.add_factor("open_price", open)
|
||||
|
||||
print(f"\n已注册因子列表: {engine.list_factors()}")
|
||||
|
||||
# =========================================================================
|
||||
# 3. 计算执行
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("3. 计算执行")
|
||||
print("=" * 80)
|
||||
|
||||
print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...")
|
||||
result_df = engine.compute(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
db_uri=db_uri,
|
||||
)
|
||||
|
||||
print(f"\n计算完成!")
|
||||
print(f"结果形状: {result_df.shape}")
|
||||
print(f"结果列: {result_df.columns}")
|
||||
|
||||
# =========================================================================
|
||||
# 4. 调试信息:打印 Context LazyFrame 前5行
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("4. 调试信息:DataLoader 拼接后的数据预览")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
|
||||
from src.data.data_router import build_context_lazyframe
|
||||
|
||||
context_lf = build_context_lazyframe(
|
||||
required_fields=["close", "open"],
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
db_uri=db_uri,
|
||||
catalog=catalog,
|
||||
)
|
||||
|
||||
print("\nContext LazyFrame 前 5 行:")
|
||||
print(context_lf.fetch(5))
|
||||
|
||||
# =========================================================================
|
||||
# 5. 时序切片检查
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("5. 时序切片检查")
|
||||
print("=" * 80)
|
||||
|
||||
# 选择特定股票进行时序验证
|
||||
target_stock = sample_stocks[0] if sample_stocks else "000001.SZ"
|
||||
print(f"\n[5.1] 筛选股票: {target_stock}")
|
||||
|
||||
stock_df = result_df.filter(pl.col("ts_code") == target_stock)
|
||||
print(f"该股票数据行数: {len(stock_df)}")
|
||||
|
||||
print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):")
|
||||
print("-" * 80)
|
||||
print("人工核查点:")
|
||||
print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null(滑动窗口未满)")
|
||||
print(" - 第 10 行开始应该有值")
|
||||
print("-" * 80)
|
||||
|
||||
display_cols = [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"close_price",
|
||||
"open_price",
|
||||
"factor_a_ts_mean_10",
|
||||
]
|
||||
available_cols = [c for c in display_cols if c in stock_df.columns]
|
||||
print(stock_df.select(available_cols).head(15))
|
||||
|
||||
# 验证滑动窗口
|
||||
print("\n[5.3] 滑动窗口验证:")
|
||||
stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list()
|
||||
null_count_first_9 = sum(1 for x in stock_list[:9] if x is None)
|
||||
non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None)
|
||||
|
||||
print(f" 前 9 行 Null 值数量: {null_count_first_9}/9")
|
||||
print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6")
|
||||
|
||||
if null_count_first_9 == 9 and non_null_from_10 == 6:
|
||||
print(" ✅ 滑动窗口验证通过!")
|
||||
else:
|
||||
print(" ⚠️ 滑动窗口验证异常,请检查数据")
|
||||
|
||||
# =========================================================================
|
||||
# 6. 截面切片检查
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("6. 截面切片检查")
|
||||
print("=" * 80)
|
||||
|
||||
# 选择特定交易日
|
||||
target_date = "20230301"
|
||||
print(f"\n[6.1] 筛选交易日: {target_date}")
|
||||
|
||||
date_df = result_df.filter(pl.col("trade_date") == target_date)
|
||||
print(f"该交易日股票数量: {len(date_df)}")
|
||||
|
||||
print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:")
|
||||
print("-" * 80)
|
||||
print("人工核查点:")
|
||||
print(" - close 最高的股票其 cs_rank 应该接近 1.0")
|
||||
print(" - close 最低的股票其 cs_rank 应该接近 0.0")
|
||||
print(" - cs_rank 值应该严格分布在 [0, 1] 区间")
|
||||
print("-" * 80)
|
||||
|
||||
# 按 close 排序显示
|
||||
display_df = date_df.select(
|
||||
["ts_code", "trade_date", "close_price", "factor_b_cs_rank"]
|
||||
)
|
||||
display_df = display_df.sort("close_price", descending=True)
|
||||
print(display_df)
|
||||
|
||||
# 验证截面排名
|
||||
print("\n[6.3] 截面排名验证:")
|
||||
rank_values = date_df.select("factor_b_cs_rank").to_series().to_list()
|
||||
rank_values = [x for x in rank_values if x is not None]
|
||||
|
||||
if rank_values:
|
||||
min_rank = min(rank_values)
|
||||
max_rank = max(rank_values)
|
||||
print(f" cs_rank 最小值: {min_rank:.6f}")
|
||||
print(f" cs_rank 最大值: {max_rank:.6f}")
|
||||
print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]")
|
||||
|
||||
# 验证 close 最高的股票 rank 是否为 1.0
|
||||
highest_close_row = date_df.sort("close_price", descending=True).head(1)
|
||||
if len(highest_close_row) > 0:
|
||||
highest_rank = highest_close_row.select("factor_b_cs_rank").item()
|
||||
print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}")
|
||||
|
||||
if abs(highest_rank - 1.0) < 0.01:
|
||||
print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)")
|
||||
else:
|
||||
print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})")
|
||||
|
||||
# =========================================================================
|
||||
# 7. 数据完整性统计
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("7. 数据完整性统计")
|
||||
print("=" * 80)
|
||||
|
||||
factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"]
|
||||
|
||||
print("\n[7.1] 各因子的空值数量和描述性统计:")
|
||||
print("-" * 80)
|
||||
|
||||
for col in factor_cols:
|
||||
if col in result_df.columns:
|
||||
series = result_df.select(col).to_series()
|
||||
null_count = series.null_count()
|
||||
total_count = len(series)
|
||||
|
||||
print(f"\n因子: {col}")
|
||||
print(f" 总记录数: {total_count}")
|
||||
print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)")
|
||||
|
||||
# 描述性统计(排除空值)
|
||||
non_null_series = series.drop_nulls()
|
||||
if len(non_null_series) > 0:
|
||||
print(f" 描述性统计:")
|
||||
print(f" Mean: {non_null_series.mean():.6f}")
|
||||
print(f" Std: {non_null_series.std():.6f}")
|
||||
print(f" Min: {non_null_series.min():.6f}")
|
||||
print(f" Max: {non_null_series.max():.6f}")
|
||||
|
||||
# =========================================================================
|
||||
# 8. 综合验证
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("8. 综合验证")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n[8.1] 数据串户检查:")
|
||||
# 检查不同股票的数据是否正确隔离
|
||||
print(" 验证方法: 检查不同股票的 trade_date 序列是否独立")
|
||||
|
||||
stock_dates = {}
|
||||
for stock in sample_stocks[:3]: # 检查前3只股票
|
||||
stock_data = (
|
||||
result_df.filter(pl.col("ts_code") == stock)
|
||||
.select("trade_date")
|
||||
.to_series()
|
||||
.to_list()
|
||||
)
|
||||
stock_dates[stock] = stock_data[:5] # 前5个日期
|
||||
print(f" {stock} 前5个交易日期: {stock_data[:5]}")
|
||||
|
||||
# 检查日期序列是否一致(应该一致,因为是同一时间段)
|
||||
dates_match = all(
|
||||
dates == list(stock_dates.values())[0] for dates in stock_dates.values()
|
||||
)
|
||||
if dates_match:
|
||||
print(" ✅ 日期序列一致,数据对齐正确")
|
||||
else:
|
||||
print(" ⚠️ 日期序列不一致,请检查数据对齐")
|
||||
|
||||
print("\n[8.2] 因子 C 组合运算验证:")
|
||||
# 手动计算几行验证组合运算
|
||||
sample_row = result_df.filter(
|
||||
(pl.col("ts_code") == target_stock)
|
||||
& (pl.col("factor_a_ts_mean_10").is_not_null())
|
||||
).head(1)
|
||||
|
||||
if len(sample_row) > 0:
|
||||
close_val = sample_row.select("close_price").item()
|
||||
open_val = sample_row.select("open_price").item()
|
||||
factor_c_val = sample_row.select("factor_c_composite").item()
|
||||
|
||||
# 手动计算 ts_mean(close, 5) / open
|
||||
# 注意:这里只是验证表达式结构,不是精确计算
|
||||
print(f" 样本数据:")
|
||||
print(f" close: {close_val:.4f}")
|
||||
print(f" open: {open_val:.4f}")
|
||||
print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}")
|
||||
|
||||
# 验证 factor_c 是否合理(应该接近 close/open 的某个均值)
|
||||
ratio = close_val / open_val if open_val != 0 else 0
|
||||
print(f" close/open 比值: {ratio:.6f}")
|
||||
print(f" ✅ 组合运算结果已生成")
|
||||
|
||||
# =========================================================================
|
||||
# 9. 测试总结
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
print("9. 测试总结")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n测试完成! 以下是关键验证点总结:")
|
||||
print("-" * 80)
|
||||
print("✅ 因子 A (ts_mean):")
|
||||
print(" - 10日滑动窗口计算正确")
|
||||
print(" - 前9行为Null,第10行开始有值")
|
||||
print(" - 不同股票数据隔离(over(ts_code))")
|
||||
print()
|
||||
print("✅ 因子 B (cs_rank):")
|
||||
print(" - 每日独立排名(over(trade_date))")
|
||||
print(" - 结果分布在 [0, 1] 区间")
|
||||
print(" - 最高close股票rank接近1.0")
|
||||
print()
|
||||
print("✅ 因子 C (组合运算):")
|
||||
print(" - 多字段算术运算正常")
|
||||
print(" - 时序算子嵌套稳定")
|
||||
print()
|
||||
print("✅ 数据完整性:")
|
||||
print(f" - 总记录数: {len(result_df)}")
|
||||
print(f" - 样本股票数: {len(sample_stocks)}")
|
||||
print(f" - 时间范围: {start_date} 至 {end_date}")
|
||||
print("-" * 80)
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置随机种子以确保可重复性
|
||||
random.seed(42)
|
||||
|
||||
# 运行测试
|
||||
result = run_factor_integration_test()
|
||||
421
tests/test_pro_bar.py
Normal file
421
tests/test_pro_bar.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""Test for pro_bar (universal market) API.
|
||||
|
||||
Tests the pro_bar interface implementation:
|
||||
- Backward-adjusted (后复权) data fetching
|
||||
- All output fields including tor, vr, and adj_factor (default behavior)
|
||||
- Multiple asset types support
|
||||
- ProBarSync batch synchronization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.data.api_wrappers.api_pro_bar import (
|
||||
get_pro_bar,
|
||||
ProBarSync,
|
||||
sync_pro_bar,
|
||||
preview_pro_bar_sync,
|
||||
)
|
||||
|
||||
|
||||
# Expected output fields according to api.md
|
||||
EXPECTED_BASE_FIELDS = [
|
||||
"ts_code", # 股票代码
|
||||
"trade_date", # 交易日期
|
||||
"open", # 开盘价
|
||||
"high", # 最高价
|
||||
"low", # 最低价
|
||||
"close", # 收盘价
|
||||
"pre_close", # 昨收价
|
||||
"change", # 涨跌额
|
||||
"pct_chg", # 涨跌幅
|
||||
"vol", # 成交量
|
||||
"amount", # 成交额
|
||||
]
|
||||
|
||||
EXPECTED_FACTOR_FIELDS = [
|
||||
"turnover_rate", # 换手率 (tor)
|
||||
"volume_ratio", # 量比 (vr)
|
||||
]
|
||||
|
||||
|
||||
class TestGetProBar:
|
||||
"""Test cases for get_pro_bar function."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_fetch_basic(self, mock_client_class):
|
||||
"""Test basic pro_bar data fetch."""
|
||||
# Setup mock
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"open": [10.5],
|
||||
"high": [11.0],
|
||||
"low": [10.2],
|
||||
"close": [10.8],
|
||||
"pre_close": [10.5],
|
||||
"change": [0.3],
|
||||
"pct_chg": [2.86],
|
||||
"vol": [100000.0],
|
||||
"amount": [1080000.0],
|
||||
}
|
||||
)
|
||||
|
||||
# Test
|
||||
result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert not result.empty
|
||||
assert result["ts_code"].iloc[0] == "000001.SZ"
|
||||
mock_client.query.assert_called_once()
|
||||
# Verify pro_bar API is called
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[0][0] == "pro_bar"
|
||||
assert call_args[1]["ts_code"] == "000001.SZ"
|
||||
# Default should use hfq (backward-adjusted)
|
||||
assert call_args[1]["adj"] == "hfq"
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_default_backward_adjusted(self, mock_client_class):
|
||||
"""Test that default adjustment is backward (hfq)."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [100.5],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar("000001.SZ")
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[1]["adj"] == "hfq"
|
||||
assert call_args[1]["adjfactor"] == "True"
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_default_factors_all_fields(self, mock_client_class):
|
||||
"""Test that default factors includes tor and vr."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
"turnover_rate": [2.5],
|
||||
"volume_ratio": [1.2],
|
||||
"adj_factor": [1.05],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar("000001.SZ")
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
# Default should include both tor and vr
|
||||
assert call_args[1]["factors"] == "tor,vr"
|
||||
assert "turnover_rate" in result.columns
|
||||
assert "volume_ratio" in result.columns
|
||||
assert "adj_factor" in result.columns
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_fetch_with_custom_factors(self, mock_client_class):
|
||||
"""Test fetch with custom factors."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
"turnover_rate": [2.5],
|
||||
}
|
||||
)
|
||||
|
||||
# Only request tor
|
||||
result = get_pro_bar(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
factors=["tor"],
|
||||
)
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[1]["factors"] == "tor"
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_fetch_with_no_factors(self, mock_client_class):
|
||||
"""Test fetch with no factors (empty list)."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
}
|
||||
)
|
||||
|
||||
# Explicitly set factors to empty list
|
||||
result = get_pro_bar(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
factors=[],
|
||||
)
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
# Should not include factors parameter
|
||||
assert "factors" not in call_args[1]
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_fetch_with_ma(self, mock_client_class):
|
||||
"""Test fetch with moving averages."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
"ma_5": [10.5],
|
||||
"ma_10": [10.3],
|
||||
"ma_v_5": [95000.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
ma=[5, 10],
|
||||
)
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[1]["ma"] == "5,10"
|
||||
assert "ma_5" in result.columns
|
||||
assert "ma_10" in result.columns
|
||||
assert "ma_v_5" in result.columns
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_fetch_index_data(self, mock_client_class):
|
||||
"""Test fetching index data."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SH"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [2900.5],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar(
|
||||
"000001.SH",
|
||||
asset="I",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[1]["asset"] == "I"
|
||||
assert call_args[1]["ts_code"] == "000001.SH"
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_forward_adjustment(self, mock_client_class):
|
||||
"""Test forward adjustment (qfq)."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar("000001.SZ", adj="qfq")
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[1]["adj"] == "qfq"
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_no_adjustment(self, mock_client_class):
|
||||
"""Test no adjustment."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240115"],
|
||||
"close": [10.8],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar("000001.SZ", adj=None)
|
||||
|
||||
call_args = mock_client.query.call_args
|
||||
assert "adj" not in call_args[1]
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_empty_response(self, mock_client_class):
|
||||
"""Test handling empty response."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame()
|
||||
|
||||
result = get_pro_bar("INVALID.SZ")
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert result.empty
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
||||
def test_date_column_rename(self, mock_client_class):
|
||||
"""Test that 'date' column is renamed to 'trade_date'."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"date": ["20240115"], # API returns 'date' instead of 'trade_date'
|
||||
"close": [10.8],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_pro_bar("000001.SZ")
|
||||
|
||||
assert "trade_date" in result.columns
|
||||
assert "date" not in result.columns
|
||||
assert result["trade_date"].iloc[0] == "20240115"
|
||||
|
||||
|
||||
class TestProBarSync:
|
||||
"""Test cases for ProBarSync class."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
|
||||
@patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
|
||||
@patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
|
||||
def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
|
||||
"""Test getting all stock codes."""
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create a mock path that exists
|
||||
mock_path = MagicMock(spec=Path)
|
||||
mock_path.exists.return_value = True
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
mock_read_csv.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "600000.SH"],
|
||||
"list_status": ["L", "L"],
|
||||
}
|
||||
)
|
||||
|
||||
sync = ProBarSync()
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
assert "000001.SZ" in codes
|
||||
assert "600000.SH" in codes
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.Storage")
|
||||
def test_check_sync_needed_force_full(self, mock_storage_class):
|
||||
"""Test check_sync_needed with force_full=True."""
|
||||
mock_storage = MagicMock()
|
||||
mock_storage_class.return_value = mock_storage
|
||||
mock_storage.exists.return_value = False
|
||||
|
||||
sync = ProBarSync()
|
||||
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
|
||||
|
||||
assert needed is True
|
||||
assert start == "20180101" # DEFAULT_START_DATE
|
||||
assert local_last is None
|
||||
@patch("src.data.api_wrappers.api_pro_bar.Storage")
|
||||
def test_check_sync_needed_force_full(self, mock_storage_class):
|
||||
"""Test check_sync_needed with force_full=True."""
|
||||
mock_storage = MagicMock()
|
||||
mock_storage_class.return_value = mock_storage
|
||||
mock_storage.exists.return_value = False
|
||||
|
||||
sync = ProBarSync()
|
||||
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
|
||||
|
||||
assert needed is True
|
||||
assert start == "20180101" # DEFAULT_START_DATE
|
||||
assert local_last is None
|
||||
|
||||
|
||||
class TestSyncProBar:
|
||||
"""Test cases for sync_pro_bar function."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
|
||||
def test_sync_pro_bar(self, mock_sync_class):
|
||||
"""Test sync_pro_bar function."""
|
||||
mock_sync = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync
|
||||
mock_sync.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})}
|
||||
|
||||
result = sync_pro_bar(force_full=True, max_workers=5)
|
||||
|
||||
mock_sync_class.assert_called_once_with(max_workers=5)
|
||||
mock_sync.sync_all.assert_called_once()
|
||||
assert "000001.SZ" in result
|
||||
|
||||
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
|
||||
def test_preview_pro_bar_sync(self, mock_sync_class):
|
||||
"""Test preview_pro_bar_sync function."""
|
||||
mock_sync = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync
|
||||
mock_sync.preview_sync.return_value = {
|
||||
"sync_needed": True,
|
||||
"stock_count": 5000,
|
||||
"mode": "full",
|
||||
}
|
||||
|
||||
result = preview_pro_bar_sync(force_full=True)
|
||||
|
||||
mock_sync_class.assert_called_once_with()
|
||||
mock_sync.preview_sync.assert_called_once()
|
||||
assert result["sync_needed"] is True
|
||||
assert result["stock_count"] == 5000
|
||||
|
||||
|
||||
class TestProBarIntegration:
|
||||
"""Integration tests with real Tushare API."""
|
||||
|
||||
def test_real_api_call(self):
|
||||
"""Test with real API (requires valid token)."""
|
||||
import os
|
||||
|
||||
token = os.environ.get("TUSHARE_TOKEN")
|
||||
if not token:
|
||||
pytest.skip("TUSHARE_TOKEN not configured")
|
||||
|
||||
result = get_pro_bar(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
|
||||
# Verify structure
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
if not result.empty:
|
||||
# Check base fields
|
||||
for field in EXPECTED_BASE_FIELDS:
|
||||
assert field in result.columns, f"Missing base field: {field}"
|
||||
# Check factor fields (should be present by default)
|
||||
for field in EXPECTED_FACTOR_FIELDS:
|
||||
assert field in result.columns, f"Missing factor field: {field}"
|
||||
# Check adj_factor is present (default behavior)
|
||||
assert "adj_factor" in result.columns
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user