From f943cc98d0fc1ec2698eddbacb8057c7f759c529 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sun, 15 Mar 2026 18:00:48 +0800 Subject: [PATCH] =?UTF-8?q?feat(factors):=20=E6=B7=BB=E5=8A=A0=20cs=5Fmean?= =?UTF-8?q?=20=E5=87=BD=E6=95=B0=E5=B9=B6=E5=A2=9E=E5=BC=BA=20max=5F/min?= =?UTF-8?q?=5F=20=E5=8D=95=E5=8F=82=E6=95=B0=E6=94=AF=E6=8C=81=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20cs=5Fmean=20=E6=88=AA=E9=9D=A2=E5=9D=87?= =?UTF-8?q?=E5=80=BC=E5=87=BD=E6=95=B0=EF=BC=8C=E6=94=AF=E6=8C=81=20GTJA?= =?UTF-8?q?=20Alpha127=20=E7=AD=89=E5=9B=A0=E5=AD=90=E8=BD=AC=E6=8D=A2=20-?= =?UTF-8?q?=20max=5F/min=5F=20=E6=94=AF=E6=8C=81=E5=8D=95=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=B0=83=E7=94=A8=EF=BC=8C=E9=BB=98=E8=AE=A4=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20252=20=E5=A4=A9=EF=BC=88=E7=BA=A6=201=20=E5=B9=B4?= =?UTF-8?q?=EF=BC=89=E6=BB=9A=E5=8A=A8=E7=AA=97=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/experiment/common.py | 4 +- src/factors/api.py | 48 +- src/factors/translator.py | 13 + src/scripts/GtjaConvertor/__init__.py | 26 + src/scripts/GtjaConvertor/converter.py | 849 ++++++++++++++++++++++ src/scripts/GtjaConvertor/preprocessor.py | 273 +++++++ tests/test_601117_factors.py | 350 --------- tests/test_ast_optimizer.py | 367 ---------- tests/test_bugfixes.py | 144 ---- tests/test_db_manager.py | 377 ---------- tests/test_factor_engine.py | 160 ---- tests/test_factor_engine_metadata.py | 106 --- tests/test_factor_integration.py | 451 ------------ tests/test_financial_price_merge.py | 351 --------- tests/test_new_ts_functions.py | 130 ---- tests/test_phase1_2_factors.py | 541 -------------- tests/test_pro_bar.py | 421 ----------- tests/test_stk_limit.py | 246 ------- tests/test_stock_st.py | 143 ---- tests/test_sync.py | 164 ----- tests/test_tushare_api.py | 20 - 21 files changed, 1204 insertions(+), 3980 deletions(-) create mode 100644 src/scripts/GtjaConvertor/__init__.py create mode 100644 src/scripts/GtjaConvertor/converter.py create mode 100644 src/scripts/GtjaConvertor/preprocessor.py delete mode 100644 tests/test_601117_factors.py delete mode 100644 tests/test_ast_optimizer.py delete mode 100644 tests/test_bugfixes.py delete mode 100644 tests/test_db_manager.py delete mode 100644 tests/test_factor_engine.py delete mode 100644 tests/test_factor_engine_metadata.py delete mode 100644 tests/test_factor_integration.py delete mode 100644 tests/test_financial_price_merge.py delete mode 100644 tests/test_new_ts_functions.py delete mode 100644 tests/test_phase1_2_factors.py delete mode 100644 tests/test_pro_bar.py delete mode 100644 tests/test_stk_limit.py delete mode 100644 tests/test_stock_st.py delete mode 100644 tests/test_sync.py delete mode 100644 tests/test_tushare_api.py diff --git a/src/experiment/common.py b/src/experiment/common.py index 621a16f..39b8114 100644 --- a/src/experiment/common.py +++ b/src/experiment/common.py @@ -86,7 +86,9 @@ SELECTED_FACTORS = [ ] # 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子) -FACTOR_DEFINITIONS = {} +FACTOR_DEFINITIONS = { + 'test': '[([(col("close")) - (col("close").shift([dyn int: 5]).over([col("ts_code")]))]) / (col("close").shift([dyn int: 5]).over([col("ts_code")]))]' +} def get_label_factor(label_name: str) -> dict: diff --git a/src/factors/api.py b/src/factors/api.py index 99bb229..eed5dd5 100644 --- a/src/factors/api.py +++ b/src/factors/api.py @@ -418,6 +418,26 @@ def cs_demean(x: Union[Node, str]) -> FunctionNode: return FunctionNode("cs_demean", x) +def cs_mean(x: Union[Node, str]) -> FunctionNode: + """截面均值。 + + 计算因子在横截面上的平均值。 + + Args: + x: 输入因子表达式或字段名字符串 + + Returns: + FunctionNode: 函数调用节点 + + Example: + >>> from src.factors.api import close, cs_mean + >>> expr = cs_mean((close - 100) ** 2) + >>> print(expr) + cs_mean(((close - 100) ** 2)) + """ + return FunctionNode("cs_mean", x) + + # ==================== 数学函数 ==================== @@ -507,41 +527,53 @@ def abs(x: Union[Node, str]) -> FunctionNode: return FunctionNode("abs", x) -def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode: - """逐元素最大值。 +def max_( + x: Union[Node, str], y: Union[Node, str, int, float, None] = None +) -> FunctionNode: + """最大值。 智能分发逻辑: + - 单参数:调用 ts_max(x, 252) 计算滚动窗口最大值(默认 252 天≈1年) - 如果 y 是正整数 (y > 0),调用 ts_max(x, y) 滚动窗口最大值 - 否则,调用逐元素 max(x, y) 注意:避免 MAX(CLOSE - DELAY(CLOSE, 1), 0) 这类场景被错误路由到 ts_max Args: - x: 第一个因子表达式或字段名字符串 - y: 第二个因子表达式、字段名字符串或正整数(窗口大小) + x: 第一个因子表达式或字段名字符串,或单参数时的输入序列 + y: 可选,第二个因子表达式、字段名字符串或正整数(窗口大小) Returns: FunctionNode: 函数调用节点 """ + if y is None: + # 单参数:默认使用 252 天(约 1 年交易日)窗口 + return ts_max(x, 252) if isinstance(y, int) and y > 0: return ts_max(x, y) return FunctionNode("max", x, _ensure_node(y)) -def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode: - """逐元素最小值。 +def min_( + x: Union[Node, str], y: Union[Node, str, int, float, None] = None +) -> FunctionNode: + """最小值。 智能分发逻辑: + - 单参数:调用 ts_min(x, 252) 计算滚动窗口最小值(默认 252 天≈1年) - 如果 y 是正整数 (y > 0),调用 ts_min(x, y) 滚动窗口最小值 - 否则,调用逐元素 min(x, y) Args: - x: 第一个因子表达式或字段名字符串 - y: 第二个因子表达式、字段名字符串或正整数(窗口大小) + x: 第一个因子表达式或字段名字符串,或单参数时的输入序列 + y: 可选,第二个因子表达式、字段名字符串或正整数(窗口大小) Returns: FunctionNode: 函数调用节点 """ + if y is None: + # 单参数:默认使用 252 天(约 1 年交易日)窗口 + return ts_min(x, 252) if isinstance(y, int) and y > 0: return ts_min(x, y) return FunctionNode("min", x, _ensure_node(y)) diff --git a/src/factors/translator.py b/src/factors/translator.py index 1270f5b..e41fbc6 100644 --- a/src/factors/translator.py +++ b/src/factors/translator.py @@ -88,6 +88,7 @@ class PolarsTranslator: self.register_handler("cs_rank", self._handle_cs_rank) self.register_handler("cs_zscore", self._handle_cs_zscore) self.register_handler("cs_neutral", self._handle_cs_neutral) + self.register_handler("cs_mean", self._handle_cs_mean) # 元素级数学函数 (element_wise) self.register_handler("abs", self._handle_abs) @@ -681,6 +682,18 @@ class PolarsTranslator: # 简单实现:减去截面均值(可在未来扩展为分组中性化) return expr - expr.mean() + @cross_section + def _handle_cs_mean(self, node: FunctionNode) -> pl.Expr: + """处理 cs_mean(expr) -> 截面均值。 + + 计算因子在横截面上的平均值,常用于 Alpha127 等因子。 + 例如:MEAN((100*(CLOSE-MAX(CLOSE,12))/(MAX(CLOSE,12)))^2) 中的 MEAN + """ + if len(node.args) != 1: + raise ValueError("cs_mean 需要 1 个参数: (expr)") + expr = self.translate(node.args[0]) + return expr.mean() + # ==================== 元素级数学函数 (element_wise) ==================== # 这些函数对每个元素独立计算,不添加 over diff --git a/src/scripts/GtjaConvertor/__init__.py b/src/scripts/GtjaConvertor/__init__.py new file mode 100644 index 0000000..8e2a235 --- /dev/null +++ b/src/scripts/GtjaConvertor/__init__.py @@ -0,0 +1,26 @@ +"""GTJA Alpha191 因子转换器。 + +将国泰君安的 Alpha191 因子公式转换为框架可识别的 DSL 字符串表达式。 + +模块结构: + - preprocessor: GTJA 语法清洗工具 + - converter: 转换主程序 + +使用示例: + >>> from src.scripts.GtjaConvertor import GtjaConverter + >>> converter = GtjaConverter() + >>> # 输入 GTJA 原始表达式 + >>> dsl_str = converter.convert("(-1 * CORR(RANK(DELTA(LOG(VOLUME), 1)), RANK(((CLOSE - OPEN) / OPEN)), 6))") + >>> print(dsl_str) + (-1 * ts_corr(cs_rank(ts_delta(log(vol), 1)), cs_rank(((close - open) / open)), 6)) +""" + +from .preprocessor import clean_gtja_formula +from .converter import convert_to_dsl, GtjaConverter, parse_multiline_formulas + +__all__ = [ + "clean_gtja_formula", + "convert_to_dsl", + "GtjaConverter", + "parse_multiline_formulas", +] diff --git a/src/scripts/GtjaConvertor/converter.py b/src/scripts/GtjaConvertor/converter.py new file mode 100644 index 0000000..5f09318 --- /dev/null +++ b/src/scripts/GtjaConvertor/converter.py @@ -0,0 +1,849 @@ +"""GTJA 公式转换器。 + +将 GTJA 原始公式转换为框架可识别的 DSL 字符串表达式。 +转换过程中会验证公式是否能被正确解析为 DSL 节点。 +""" + +import re +from pathlib import Path +from typing import Any +from src.factors.dsl import Node, FunctionNode +from src.factors.api import ( + close, + open, + high, + low, + vol, + amount, + pre_close, + change, + pct_chg, + ts_mean, + ts_std, + ts_max, + ts_min, + ts_sum, + ts_delay, + ts_delta, + ts_corr, + ts_cov, + ts_var, + ts_skew, + ts_kurt, + ts_pct_change, + ts_ema, + ts_atr, + ts_rsi, + ts_obv, + ts_rank, + ts_sma, + ts_wma, + ts_decay_linear, + ts_argmax, + ts_argmin, + ts_count, + ts_prod, + ts_sumac, + cs_rank, + cs_zscore, + cs_neutralize, + cs_winsorize, + cs_demean, + log, + exp, + sqrt, + sign, + cos, + sin, + abs, + max_, + min_, + clip, + atan, + log1p, + if_, + where, +) + +# 动态补充缺失的 cs_mean +try: + from src.factors.api import cs_mean +except ImportError: + + def cs_mean(x): + return FunctionNode("cs_mean", x) + + +try: + from .preprocessor import clean_gtja_formula, filter_unsupported_formulas +except ImportError: + from preprocessor import clean_gtja_formula, filter_unsupported_formulas + + +class GtjaConverter: + # 安全的函数命名空间,用于验证公式语法的合理性 + SAFE_NAMESPACE: dict[str, Any] = { + "close": close, + "open": open, + "high": high, + "low": low, + "vol": vol, + "volume": vol, + "amount": amount, + "pre_close": pre_close, + "change": change, + "pct_chg": pct_chg, + "ts_mean": ts_mean, + "ts_std": ts_std, + "ts_max": ts_max, + "ts_min": ts_min, + "ts_sum": ts_sum, + "ts_delay": ts_delay, + "ts_delta": ts_delta, + "ts_corr": ts_corr, + "ts_cov": ts_cov, + "ts_var": ts_var, + "ts_skew": ts_skew, + "ts_kurt": ts_kurt, + "ts_pct_change": ts_pct_change, + "ts_ema": ts_ema, + "ts_atr": ts_atr, + "ts_rsi": ts_rsi, + "ts_obv": ts_obv, + "ts_rank": ts_rank, + "ts_sma": ts_sma, + "ts_wma": ts_wma, + "ts_decay_linear": ts_decay_linear, + "ts_argmax": ts_argmax, + "ts_argmin": ts_argmin, + "ts_count": ts_count, + "ts_prod": ts_prod, + "ts_sumac": ts_sumac, + "cs_rank": cs_rank, + "cs_zscore": cs_zscore, + "cs_neutralize": cs_neutralize, + "cs_winsorize": cs_winsorize, + "cs_demean": cs_demean, + "cs_mean": cs_mean, + "log": log, + "exp": exp, + "sqrt": sqrt, + "sign": sign, + "cos": cos, + "sin": sin, + "abs": abs, + "max_": max_, + "min_": min_, + "clip": clip, + "atan": atan, + "log1p": log1p, + "if_": if_, + "where": where, + } + + def __init__(self): + self.errors: list[str] = [] + self.warnings: list[str] = [] + self._registration_results: list[dict[str, Any]] = [] + + def convert(self, formula: str) -> str | None: + if not filter_unsupported_formulas(formula): + self.warnings.append( + f"包含暂不支持的算子/循环依赖,已跳过: {formula[:50]}..." + ) + return None + + clean_formula = clean_gtja_formula(formula) + + try: + # 使用 AST api 执行,验证所有函数的输入/参数类型是否有效 + self._validate_formula(clean_formula) + return clean_formula + except Exception as e: + self.errors.append( + f"语法节点构建失败: {formula[:50]}... \n\t-> 解析所得: {clean_formula}\n\t-> 错误: {e}" + ) + return None + + def _validate_formula(self, formula: str) -> Node: + # __builtins__: {} 禁止任何外部危险执行,彻底保证安全 + return eval(formula, {"__builtins__": {}}, self.SAFE_NAMESPACE) + + def convert_batch( + self, + formulas: dict[str, str], + auto_register: bool = False, + output_path: Path | None = None, + ) -> dict[str, str | None]: + """批量转换公式。 + + Args: + formulas: 公式字典,key 为因子名(如 "Alpha1"),value 为公式字符串 + auto_register: 是否自动注册成功的因子到因子库 + output_path: 因子库文件路径,默认使用 data/factors.jsonl + + Returns: + 转换结果字典,key 为因子名,value 为 DSL 表达式或 None + """ + results = {} + self._registration_results = [] + + for name, formula in formulas.items(): + result = self.convert(formula) + results[name] = result + + # 自动注册成功的因子 + if auto_register and result is not None: + reg_result = register_gtja_factor(name, result, output_path) + self._registration_results.append({"alpha_name": name, **reg_result}) + + return results + + def get_registration_report(self) -> dict[str, Any]: + """获取注册报告。 + + Returns: + 包含注册统计信息的字典 + """ + if not hasattr(self, "_registration_results"): + return { + "total": 0, + "success": 0, + "skipped": 0, + "failed": 0, + "details": [], + } + + success = sum(1 for r in self._registration_results if r["status"] == "success") + skipped = sum(1 for r in self._registration_results if r["status"] == "skipped") + failed = sum(1 for r in self._registration_results if r["status"] == "failed") + + return { + "total": len(self._registration_results), + "success": success, + "skipped": skipped, + "failed": failed, + "details": self._registration_results, + } + + def get_stats(self) -> dict[str, Any]: + return { + "errors": len(self.errors), + "warnings": len(self.warnings), + "error_details": self.errors, + "warning_details": self.warnings, + } + + +def convert_to_dsl(formula_str: str) -> str | None: + converter = GtjaConverter() + return converter.convert(formula_str) + + +def parse_multiline_formulas(text: str) -> dict[str, str]: + formulas = {} + for line in text.strip().split("\n"): + line = line.strip() + if not line: + continue + if ":" in line: + name, expr = line.split(":", 1) + if name.strip() and expr.strip(): + formulas[name.strip()] = expr.strip() + return formulas + + +def get_next_factor_id(filepath: Path) -> str: + """生成下一个 factor_id。 + + 从现有文件中提取最大序号,生成新的 F_XXX 格式 ID。 + + Args: + filepath: JSONL 文件路径 + + Returns: + 新的 factor_id,如 "F_001" + """ + import builtins + import json + + if not filepath.exists(): + return "F_001" + + try: + with builtins.open(filepath, "r", encoding="utf-8") as f: + lines = f.readlines() + except Exception: + return "F_001" + + max_num = 0 + pattern = re.compile(r"^F_(\d+)$") + + for line in lines: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + factor_id = data.get("factor_id", "") + match = pattern.match(factor_id) + if match: + num = int(match.group(1)) + max_num = max(max_num, num) + except (json.JSONDecodeError, ValueError): + continue + + return f"F_{max_num + 1:03d}" + + +def extract_alpha_number(alpha_name: str) -> int | None: + """从 Alpha 名称中提取数字。 + + Args: + alpha_name: 如 "Alpha1", "Alpha123" + + Returns: + 数字部分,如 1, 123;如果无法解析返回 None + """ + match = re.match(r"[Aa]lpha(\d+)", alpha_name) + if match: + return int(match.group(1)) + return None + + +def register_gtja_factor( + alpha_name: str, + dsl_expr: str, + output_path: Path | None = None, +) -> dict[str, Any]: + """注册单个 GTJA 因子到因子库。 + + Args: + alpha_name: 原始 Alpha 名称,如 "Alpha1" + dsl_expr: DSL 表达式字符串 + output_path: 因子库文件路径,默认使用 data/factors.jsonl + + Returns: + 注册结果字典,包含 status 和 message + """ + from src.factors.metadata import FactorManager + from src.factors.metadata.exceptions import DuplicateFactorError, ValidationError + from src.config.settings import settings + + # 提取数字并构建标准化名称 + alpha_num = extract_alpha_number(alpha_name) + if alpha_num is None: + return { + "status": "failed", + "message": f"无法从 '{alpha_name}' 提取数字编号", + } + + # 标准化名称: GTJA_alpha001, GTJA_alpha123 + factor_name = f"GTJA_alpha{alpha_num:03d}" + + # 确定输出路径 + if output_path is None: + output_path = settings.data_path_resolved / "factors.jsonl" + + # 初始化 FactorManager + manager = FactorManager(str(output_path)) + + try: + # 检查是否已存在(处理空文件的情况) + try: + existing = manager.get_factors_by_name(factor_name) + if len(existing) > 0: + return { + "status": "skipped", + "message": f"因子 '{factor_name}' 已存在", + } + except Exception: + # 如果查询失败(如文件为空),继续尝试注册 + pass + + # 生成 factor_id + factor_id = get_next_factor_id(output_path) + + # 构建因子记录 + factor_record = { + "factor_id": factor_id, + "name": factor_name, + "desc": f"GTJA {alpha_name} 因子", + "dsl": dsl_expr, + "category": "gtja_alpha", + "source": "GTJA191", + } + + # 注册因子 + manager.add_factor(factor_record) + + return { + "status": "success", + "message": f"{factor_id}: {factor_name}", + "factor_id": factor_id, + "factor_name": factor_name, + } + + except DuplicateFactorError as e: + return { + "status": "failed", + "message": f"因子 ID 重复: {e}", + } + except ValidationError as e: + return { + "status": "failed", + "message": f"验证失败: {e}", + } + except Exception as e: + return { + "status": "failed", + "message": f"注册失败: {e}", + } + + +if __name__ == "__main__": + # 使用示例:多行字符串输入 + converter = GtjaConverter() + + # 多行字符串,格式为 "因子名: 表达式",支持空行 + test_input = """ +Alpha1: (-1 * CORR(RANK(DELTA(LOG(VOLUME), 1)), RANK(((CLOSE -OPEN) / OPEN)), 6)) + +Alpha2: (-1 * DELTA((((CLOSE -LOW) -(HIGH -CLOSE)) / (HIGH -LOW)), 1)) + +Alpha3: SUM((CLOSE=DELAY(CLOSE,1)?0:CLOSE-(CLOSE>DELAY(CLOSE,1)?MIN(LOW,DELAY(CLOSE,1)):MAX(HIGH,DELAY(CLOSE,1)))),6) + +Alpha4: ((((SUM(CLOSE, 8) / 8) + STD(CLOSE, 8)) < (SUM(CLOSE, 2) / 2)) ? (-1 * 1) : (((SUM(CLOSE, 2) / 2) <((SUM(CLOSE, 8) / 8) -STD(CLOSE, 8))) ? 1 : (((1 < (VOLUME / MEAN(VOLUME,20))) || ((VOLUME /MEAN(VOLUME,20)) == 1)) ? 1 : (-1 * 1)))) + +Alpha5: (-1 * TSMAX(CORR(TSRANK(VOLUME, 5), TSRANK(HIGH, 5), 5), 3)) + +Alpha6: (RANK(SIGN(DELTA((((OPEN * 0.85) + (HIGH * 0.15))), 4)))* -1) + +Alpha7: ((RANK(MAX((VWAP -CLOSE), 3)) + RANK(MIN((VWAP -CLOSE), 3))) * RANK(DELTA(VOLUME, 3))) + +Alpha8: RANK(DELTA(((((HIGH + LOW) / 2) * 0.2) + (VWAP * 0.8)), 4) * -1) + +Alpha9: SMA(((HIGH+LOW)/2-(DELAY(HIGH,1)+DELAY(LOW,1))/2)*(HIGH-LOW)/VOLUME,7,2) + +Alpha10: (RANK(MAX(((RET < 0) ? STD(RET, 20) : CLOSE)^2),5)) + +Alpha11: SUM(((CLOSE-LOW)-(HIGH-CLOSE))./(HIGH-LOW).*VOLUME,6) + +Alpha12: (RANK((OPEN -(SUM(VWAP, 10) / 10)))) * (-1 * (RANK(ABS((CLOSE -VWAP))))) + +Alpha13: (((HIGH * LOW)^0.5) -VWAP) + +Alpha14: CLOSE-DELAY(CLOSE,5) + +Alpha15: OPEN/DELAY(CLOSE,1)-1 + +Alpha16: (-1 * TSMAX(RANK(CORR(RANK(VOLUME), RANK(VWAP), 5)), 5)) + +Alpha17: RANK((VWAP -MAX(VWAP, 15)))^DELTA(CLOSE, 5) + +Alpha18: CLOSE/DELAY(CLOSE,5) + +Alpha19: (CLOSEDELAY(CLOSE,1)?STD(CLOSE:20),0),20,1)/(SMA((CLOSE>DELAY(CLOSE,1)?STD(CLOSE,20):0),20,1)+SMA((CLOSE<=DELAY(CLOSE,1)?STD(CLOSE,20):0),20,1))*100 + +Alpha24: SMA(CLOSE-DELAY(CLOSE,5),5,1) + +Alpha25: ((-1 * RANK((DELTA(CLOSE, 7) * (1 -RANK(DECAYLINEAR((VOLUME/MEAN(VOLUME,20)), 9)))))) * (1 +RANK(SUM(RET, 250)))) + +Alpha26: ((((SUM(CLOSE, 7) / 7) -CLOSE)) + ((CORR(VWAP, DELAY(CLOSE, 5), 230)))) + +Alpha27: WMA((CLOSE-DELAY(CLOSE,3))/DELAY(CLOSE,3)*100+(CLOSE-DELAY(CLOSE,6))/DELAY(CLOSE,6)*100,12) + +Alpha28: 3*SMA((CLOSE-TSMIN(LOW,9))/(TSMAX(HIGH,9)-TSMIN(LOW,9))*100,3,1)-2*SMA(SMA((CLOSE-TSMIN(LOW,9))/(MAX(HIGH,9)-TSMAX(LOW,9))*100,3,1),3,1) + +Alpha29: (CLOSE-DELAY(CLOSE,6))/DELAY(CLOSE,6)*VOLUME + +Alpha30: WMA((REGRESI(CLOSE/DELAY(CLOSE)-1,MKT,SMB,HML,60))^2,20) + +Alpha31: (CLOSE-MEAN(CLOSE,12))/MEAN(CLOSE,12)*100 + +Alpha32: (-1 * SUM(RANK(CORR(RANK(HIGH), RANK(VOLUME), 3)), 3)) + +Alpha33: ((((-1 * TSMIN(LOW, 5)) + DELAY(TSMIN(LOW, 5), 5)) * RANK(((SUM(RET, 240) -SUM(RET, 20)) / 220))) *TSRANK(VOLUME, 5)) + +Alpha34: MEAN(CLOSE,12)/CLOSE + +Alpha35: (MIN(RANK(DECAYLINEAR(DELTA(OPEN, 1), 15)), RANK(DECAYLINEAR(CORR((VOLUME), ((OPEN * 0.65) +(OPEN *0.35)), 17),7))) * -1) + +Alpha36: RANK(SUM(CORR(RANK(VOLUME), RANK(VWAP)), 6), 2) + +Alpha37: (-1 * RANK(((SUM(OPEN, 5) * SUM(RET, 5)) -DELAY((SUM(OPEN,5) * SUM(RET, 5)), 10)))) + +Alpha38: (((SUM(HIGH, 20) / 20) < HIGH) ? (-1 * DELTA(HIGH, 2)) : 0) + +Alpha39: ((RANK(DECAYLINEAR(DELTA((CLOSE), 2),8)) -RANK(DECAYLINEAR(CORR(((VWAP * 0.3) + (OPEN * 0.7)),SUM(MEAN(VOLUME,180), 37), 14), 12))) * -1) + +Alpha40: SUM((CLOSE>DELAY(CLOSE,1)?VOLUME:0),26)/SUM((CLOSE<=DELAY(CLOSE,1)?VOLUME:0),26)*100 + +Alpha41: (RANK(MAX(DELTA((VWAP), 3), 5))* -1) + +Alpha42: ((-1 * RANK(STD(HIGH, 10))) * CORR(HIGH, VOLUME, 10)) + +Alpha43: SUM((CLOSE>DELAY(CLOSE,1)?VOLUME:(CLOSE=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)/(SUM(((HIGH+LOW)>=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)+SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)) + +Alpha50: SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)/(SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)+SUM(((HIGH+LOW)>=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12))-SUM(((HIGH+LOW)>=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)/(SUM(((HIGH+LOW)>=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)+SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)) + +Alpha51: SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)/(SUM(((HIGH+LOW)<=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)+SUM(((HIGH+LOW)>=(DELAY(HIGH,1)+DELAY(LOW,1))?0:MAX(ABS(HIGH-DELAY(HIGH,1)),ABS(LOW-DELAY(LOW,1)))),12)) + +Alpha52: SUM(MAX(0,HIGH-DELAY((HIGH+LOW+CLOSE)/3,1)),26)/SUM(MAX(0,DELAY((HIGH+LOW+CLOSE)/3,1)-L),26)*100 + +Alpha53: COUNT(CLOSE>DELAY(CLOSE,1),12)/12*100 + +Alpha54: (-1 * RANK((STD(ABS(CLOSE -OPEN)) + (CLOSE -OPEN)) + CORR(CLOSE, OPEN,10))) + +Alpha55: SUM(16*(CLOSE-DELAY(CLOSE,1)+(CLOSE-OPEN)/2+DELAY(CLOSE,1)-DELAY(OPEN,1))/((ABS(HIGH-DELAY(CLOSE,1))>ABS(LOW-DELAY(CLOSE,1))&ABS(HIGH-DELAY(CLOSE,1))>ABS(HIGH-DELAY(LOW,1))?ABS(HIGH-DELAY(CLOSE,1))+ABS(LOW-DELAY(CLOSE,1))/2+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4:(ABS(LOW-DELAY(CLOSE,1))>ABS(HIGH-DELAY(LOW,1))&ABS(LOW-DELAY(CLOSE,1))>ABS(HIGH-DELAY(CLOSE,1))?ABS(LOW-DELAY(CLOSE,1))+ABS(HIGH-DELAY(CLOSE,1))/2+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4:ABS(HIGH-DELAY(LOW,1))+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4)))*MAX(ABS(HIGH-DELAY(CLOSE,1)),ABS(LOW-DELAY(CLOSE,1))),20) + +Alpha56: (RANK((OPEN -TSMIN(OPEN, 12))) < RANK((RANK(CORR(SUM(((HIGH + LOW) / 2), 19), SUM(MEAN(VOLUME,40), 19), 13))^5))) + +Alpha57: SMA((CLOSE-TSMIN(LOW,9))/(TSMAX(HIGH,9)-TSMIN(LOW,9))*100,3,1) + +Alpha58: COUNT(CLOSE>DELAY(CLOSE,1),20)/20*100 + +Alpha59: SUM((CLOSE=DELAY(CLOSE,1)?0:CLOSE-(CLOSE>DELAY(CLOSE,1)?MIN(LOW,DELAY(CLOSE,1)):MAX(HIGH,DELAY(CLOSE,1)))),20) + +Alpha60: SUM(((CLOSE-LOW)-(HIGH-CLOSE))./(HIGH-LOW).*VOLUME,20) + +Alpha61: (MAX(RANK(DECAYLINEAR(DELTA(VWAP, 1), 12)),RANK(DECAYLINEAR(RANK(CORR((LOW),MEAN(VOLUME,80), 8)), 17))) * -1) + +Alpha62: (-1 * CORR(HIGH, RANK(VOLUME), 5)) + +Alpha63: SMA(MAX(CLOSE-DELAY(CLOSE,1),0),6,1)/SMA(ABS(CLOSE-DELAY(CLOSE,1)),6,1)*100 + +Alpha64: (MAX(RANK(DECAYLINEAR(CORR(RANK(VWAP), RANK(VOLUME), 4), 4)),RANK(DECAYLINEAR(MAX(CORR(RANK(CLOSE), RANK(MEAN(VOLUME,60)), 4), 13), 14))) * -1) + +Alpha65: MEAN(CLOSE,6)/CLOSE + +Alpha66: (CLOSE-MEAN(CLOSE,6))/MEAN(CLOSE,6)*100 + +Alpha67: SMA(MAX(CLOSE-DELAY(CLOSE,1),0),24,1)/SMA(ABS(CLOSE-DELAY(CLOSE,1)),24,1)*100 + +Alpha68: SMA(((HIGH+LOW)/2-(DELAY(HIGH,1)+DELAY(LOW,1))/2)*(HIGH-LOW)/VOLUME,15,2) + +Alpha69: (SUM(DTM,20)>SUM(DBM,20)?(SUM(DTM,20)-SUM(DBM,20))/SUM(DTM,20): (SUM(DTM,20)=SUM(DBM,20)? 0: (SUM(DTM,20)-SUM(DBM,20))/SUM(DBM,20))) + +Alpha70: STD(AMOUNT,6) + +Alpha71: (CLOSE-MEAN(CLOSE,24))/MEAN(CLOSE,24)*100 + +Alpha72: SMA((TSMAX(HIGH,6)-CLOSE)/(TSMAX(HIGH,6)-TSMIN(LOW,6))*100,15,1) + +Alpha73: ((TSRANK(DECAYLINEAR(DECAYLINEAR(CORR((CLOSE), VOLUME, 10), 16), 4), 5) - RANK(DECAYLINEAR(CORR(VWAP, MEAN(VOLUME,30), 4),3))) * -1) + +Alpha74: (RANK(CORR(SUM(((LOW * 0.35) + (VWAP * 0.65)), 20), SUM(MEAN(VOLUME,40), 20), 7)) + RANK(CORR(RANK(VWAP), RANK(VOLUME), 6))) + +Alpha75: COUNT(CLOSE>OPEN &BANCHMARKINDEXCLOSEDELAY(CLOSE,1)?VOLUME:(CLOSE=DELAY(OPEN,1)?0:MAX((OPEN-LOW),(OPEN-DELAY(OPEN,1)))),20) + +Alpha94: SUM((CLOSE>DELAY(CLOSE,1)?VOLUME:(CLOSE0?CLOSE-DELAY(CLOSE,1):0),12)-SUM((CLOSE-DELAY(CLOSE,1)<0?ABS(CLOSE-DELAY(CLOSE,1)):0),12))/(SUM((CLOSE-DELAY(CLOSE,1)>0?CLOSE-DELAY(CLOSE,1):0),12)+SUM((CLOSE-DELAY(CLOSE,1)<0?ABS(CLOSE-DELAY(CLOSE,1)):0),12))*100 + +Alpha113: (-1 * ((RANK((SUM(DELAY(CLOSE, 5), 20) / 20)) * CORR(CLOSE, VOLUME, 2)) *RANK(CORR(SUM(CLOSE, 5),SUM(CLOSE, 20), 2)))) + +Alpha114: ((RANK(DELAY(((HIGH -LOW) / (SUM(CLOSE, 5) / 5)), 2)) * RANK(RANK(VOLUME))) / (((HIGH -LOW) /(SUM(CLOSE, 5) / 5)) / (VWAP -CLOSE))) + +Alpha115: (RANK(CORR(((HIGH * 0.9) + (CLOSE * 0.1)), MEAN(VOLUME,30), 10))^RANK(CORR(TSRANK(((HIGH + LOW) /2), 4), TSRANK(VOLUME, 10), 7))) + +Alpha116: REGBETA(CLOSE,SEQUENCE,20) + +Alpha117: ((TSRANK(VOLUME, 32) * (1 -TSRANK(((CLOSE + HIGH) -LOW), 16))) * (1 -TSRANK(RET, 32))) + +Alpha118: SUM(HIGH-OPEN,20)/SUM(OPEN-LOW,20)*100 + +Alpha119: (RANK(DECAYLINEAR(CORR(VWAP, SUM(MEAN(VOLUME,5), 26), 5), 7)) -RANK(DECAYLINEAR(TSRANK(MIN(CORR(RANK(OPEN), RANK(MEAN(VOLUME,15)), 21), 9), 7), 8))) + +Alpha120: (RANK((VWAP -CLOSE)) / RANK((VWAP + CLOSE))) + +Alpha121: ((RANK((VWAP -MIN(VWAP, 12)))^TSRANK(CORR(TSRANK(VWAP, 20), TSRANK(MEAN(VOLUME,60), 2), 18), 3)) * -1) + +Alpha122: (SMA(SMA(SMA(LOG(CLOSE),13,2),13,2),13,2)-DELAY(SMA(SMA(SMA(LOG(CLOSE),13,2),13,2),13,2),1))/DELAY(SMA(SMA(SMA(LOG(CLOSE),13,2),13,2),13,2),1) + +Alpha123: ((RANK(CORR(SUM(((HIGH + LOW) / 2), 20), SUM(MEAN(VOLUME,60), 20), 9))DELAY((HIGH+LOW+CLOSE)/3,1)?(HIGH+LOW+CLOSE)/3*VOLUME:0),14)/SUM(((HIGH+LOW+CLOSE)/3ABS(LOW-DELAY(CLOSE,1)) &ABS(HIGH-DELAY(CLOSE,1))>ABS(HIGH-DELAY(LOW,1))?ABS(HIGH-DELAY(CLOSE,1))+ABS(LOW-DELAY(CLOSE,1))/2+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4:(ABS(LOW-DELAY(CLOSE,1))>ABS(HIGH-DELAY(LOW,1)) &ABS(LOW-DELAY(CLOSE,1))>ABS(HIGH-DELAY(CLOSE,1))?ABS(LOW-DELAY(CLOSE,1))+ABS(HIGH-DELAY(CLOSE,1))/2+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4:ABS(HIGH-DELAY(LOW,1))+ABS(DELAY(CLOSE,1)-DELAY(OPEN,1))/4)))*MAX(ABS(HIGH-DELAY(CLOSE,1)),ABS(LOW-DELAY(CLOSE,1))) + +Alpha138: ((RANK(DECAYLINEAR(DELTA((((LOW * 0.7) + (VWAP *0.3))), 3), 20)) -TSRANK(DECAYLINEAR(TSRANK(CORR(TSRANK(LOW, 8), TSRANK(MEAN(VOLUME,60), 17), 5), 19), 16), 7)) * -1) + +Alpha139: (-1 * CORR(OPEN, VOLUME, 10)) + +Alpha140: MIN(RANK(DECAYLINEAR(((RANK(OPEN) + RANK(LOW)) -(RANK(HIGH) + RANK(CLOSE))), 8)),TSRANK(DECAYLINEAR(CORR(TSRANK(CLOSE, 8), TSRANK(MEAN(VOLUME,60), 20), 8), 7), 3)) + +Alpha141: (RANK(CORR(RANK(HIGH), RANK(MEAN(VOLUME,15)), 9))* -1) + +Alpha142: (((-1 * RANK(TSRANK(CLOSE, 10))) * RANK(DELTA(DELTA(CLOSE, 1), 1))) *RANK(TSRANK((VOLUME/MEAN(VOLUME,20)), 5))) + +Alpha143: CLOSE>DELAY(CLOSE,1)?(CLOSE-DELAY(CLOSE,1))/DELAY(CLOSE,1)*SELF:SELF + +Alpha144: SUMIF(ABS(CLOSE/DELAY(CLOSE,1)-1)/AMOUNT,20,CLOSEDELAY(CLOSE,1))?1/(CLOSE-DELAY(CLOSE,1)):1)-MIN(((CLOSE>DELAY(CLOSE,1))?1/(CLOSE-DELAY(CLOSE,1)):1),12))/(HIGH-LOW)*100,13,2) + +Alpha165: MAX(SUMAC(CLOSE-MEAN(CLOSE,48)))-MIN(SUMAC(CLOSE-MEAN(CLOSE,48)))/STD(CLOSE,48) + +Alpha166: -20*(20-1 )^1.5*SUM(CLOSE/DELAY(CLOSE,1)-1-MEAN(CLOSE/DELAY(CLOSE,1)-1,20),20)/((20-1)*(20-2)(SUM((CLOSE/DELAY(CLOSE,1),20)^2,20))^1.5) + +Alpha167: SUM((CLOSE-DELAY(CLOSE,1)>0?CLOSE-DELAY(CLOSE,1):0),12) + +Alpha168: (-1*VOLUME/MEAN(VOLUME,20)) + +Alpha169: SMA(MEAN(DELAY(SMA(CLOSE-DELAY(CLOSE,1),9,1),1),12)-MEAN(DELAY(SMA(CLOSE-DELAY(CLOSE,1),9,1),1),26),10,1) + +Alpha170: ((((RANK((1 / CLOSE)) * VOLUME) / MEAN(VOLUME,20)) * ((HIGH * RANK((HIGH -CLOSE))) / (SUM(HIGH, 5) /5))) -RANK((VWAP -DELAY(VWAP, 5)))) + +Alpha171: ((-1 * ((LOW -CLOSE) * (OPEN^5))) / ((CLOSE -HIGH) * (CLOSE^5))) + +Alpha172: MEAN(ABS(SUM((LD>0 & LD>HD)?LD:0,14)*100/SUM(TR,14)-SUM((HD>0 &HD>LD)?HD:0,14)*100/SUM(TR,14))/(SUM((LD>0 & LD>HD)?LD:0,14)*100/SUM(TR,14)+SUM((HD>0 &HD>LD)?HD:0,14)*100/SUM(TR,14))*100,6) + +Alpha173: 3*SMA(CLOSE,13,2)-2*SMA(SMA(CLOSE,13,2),13,2)+SMA(SMA(SMA(LOG(CLOSE),13,2),13,2),13,2) + +Alpha174: SMA((CLOSE>DELAY(CLOSE,1)?STD(CLOSE,20):0),20,1) + +Alpha175: MEAN(MAX(MAX((HIGH-LOW),ABS(DELAY(CLOSE,1)-HIGH)),ABS(DELAY(CLOSE,1)-LOW)),6) + +Alpha176: CORR(RANK(((CLOSE -TSMIN(LOW, 12)) / (TSMAX(HIGH, 12) -TSMIN(LOW,12)))),RANK(VOLUME), 6) + +Alpha177: ((20-HIGHDAY(HIGH,20))/20)*100 + +Alpha178: (CLOSE-DELAY(CLOSE,1))/DELAY(CLOSE,1)*VOLUME + +Alpha179: (RANK(CORR(VWAP, VOLUME, 4)) *RANK(CORR(RANK(LOW),RANK(MEAN(VOLUME,50)), 12))) + +Alpha180: ((MEAN(VOLUME,20)OPEN & BANCHMARKINDEXCLOSE>BANCHMARKINDEXOPEN)OR(CLOSE0 & LD>HD)?LD:0,14)*100/SUM(TR,14)-SUM((HD>0 &HD>LD)?HD:0,14)*100/SUM(TR,14))/(SUM((LD>0 &LD>HD)?LD:0,14)*100/SUM(TR,14)+SUM((HD>0 &HD>LD)?HD:0,14)*100/SUM(TR,14))*100,6)+DELAY(MEAN(ABS(SUM((LD>0 &LD>HD)?LD:0,14)*100/SUM(TR,14)-SUM((HD>0 & HD>LD)?HD:0,14)*100/SUM(TR,14))/(SUM((LD>0&LD>HD)?LD:0,14)*100/SUM(TR,14)+SUM((HD>0 & HD>LD)?HD:0,14)*100/SUM(TR,14))*100,6),6))/2 + +Alpha187: SUM((OPEN<=DELAY(OPEN,1)?0:MAX((HIGH-OPEN),(OPEN-DELAY(OPEN,1)))),20) + +Alpha188: ((HIGH-LOW–SMA(HIGH-LOW,11,2))/SMA(HIGH-LOW,11,2))*100 + +Alpha189: MEAN(ABS(CLOSE-MEAN(CLOSE,6)),6) + +Alpha190: LOG((COUNT(CLOSE/DELAY(CLOSE)-1>((CLOSE/DELAY(CLOSE,19))^(1/20)-1),20)-1)*(SUMIF(((CLOSE/DELAY(CLOSE)-1-(CLOSE/DELAY(CLOSE,19))^(1/20)-1))^2,20,CLOSE/DELAY(CLOSE)-1<(CLOSE/DELAY(CLOSE,19))^(1/20)-1))/((COUNT((CLOSE/DELAY(CLOSE)-1<(CLOSE/DELAY(CLOSE,19))^(1/20)-1),20))*(SUMIF((CLOSE/DELAY(CLOSE)-1-((CLOSE/DELAY(CLOSE,19))^(1/20)-1))^2,20,CLOSE/DELAY(CLOSE)-1>(CLOSE/DELAY(CLOSE,19))^(1/20)-1)))) + +Alpha191: ((CORR(MEAN(VOLUME,20), LOW, 5) + ((HIGH + LOW) / 2)) -CLOSE) + """ + + print("=" * 60) + print("GTJA Alpha191 因子转换测试(带自动注册)") + print("=" * 60) + + # 解析多行字符串 + formulas = parse_multiline_formulas(test_input) + print(f"\n共解析到 {len(formulas)} 个因子\n") + + # 使用批量转换并自动注册 + # auto_register=True 会自动将转换成功的因子注册到因子库 + results = converter.convert_batch( + formulas, + auto_register=True, # 启用自动注册 + ) + + # 显示每个因子的转换和注册结果 + for name, dsl_str in results.items(): + print(f"因子名称: {name}") + if dsl_str: + print(f"DSL 表达式: {dsl_str}") + else: + print("转换失败或包含不支持的算子") + print() + + # 打印转换统计 + stats = converter.get_stats() + print("\n" + "=" * 60) + print("转换统计:") + print(f" 错误: {stats['errors']}") + print(f" 警告: {stats['warnings']}(暂不支持的因子)") + + if stats["errors"] > 0: + print("\n错误详情:") + for error in stats["error_details"]: + print(f" - {error}") + + if stats["warnings"] > 0: + print("\n警告详情(这些因子不会被注册):") + for warning in stats["warning_details"]: + print(f" - {warning}") + + # 打印注册报告 + reg_report = converter.get_registration_report() + if reg_report["total"] > 0: + print("\n" + "=" * 60) + print("因子注册报告:") + print(f" 总计尝试: {reg_report['total']}") + print(f" 成功注册: {reg_report['success']}") + print(f" 已存在跳过: {reg_report['skipped']}") + print(f" 注册失败: {reg_report['failed']}") + + # 打印成功的因子 + success_items = [d for d in reg_report["details"] if d["status"] == "success"] + if success_items: + print("\n成功注册的因子:") + for item in success_items: + print(f" - {item['message']}") diff --git a/src/scripts/GtjaConvertor/preprocessor.py b/src/scripts/GtjaConvertor/preprocessor.py new file mode 100644 index 0000000..64b7673 --- /dev/null +++ b/src/scripts/GtjaConvertor/preprocessor.py @@ -0,0 +1,273 @@ +"""GTJA 公式预处理器。 + +将 GTJA 原始语法清洗为框架可识别的 DSL 语法。 +修复了原版公式中的拼写错误、歧义重载、嵌套三元运算符等问题。 +""" + +import re + +def clean_gtja_formula(formula: str) -> str: + """将 GTJA 原始语法清洗为框架可识别的 DSL 语法。""" + + formula = formula.strip() + + # 0. 清洗中文标点符号和空格 + formula = formula.replace("(", "(").replace(")", ")") + formula = formula.replace(",", ",").replace("–", "-") + formula = formula.replace("【", "[").replace("】", "]") + + # 1. 替换基础算术运算符和逻辑运算符 + formula = formula.replace("./", "/").replace(".*", "*").replace("^", "**") + formula = formula.replace("||", "|").replace("&&", "&") + + # 2. 宏替换 (基础衍生宏) + replacements = { + r"\bRET\b": "(CLOSE / DELAY(CLOSE, 1) - 1)", + r"\bVWAP\b": "(AMOUNT / VOLUME)", + } + for old, new in replacements.items(): + formula = re.sub(old, new, formula, flags=re.IGNORECASE) + + # 3. 修复原版 GTJA 公式库中的各处天坑笔误 (Typo) + typo_mapping = { + r"\bHGIH\b": "HIGH", # Alpha 159 拼写错误 + r"\bDELAT\b": "DELTA", # Alpha 131 拼写错误 + r"\?STD\(CLOSE\s*:\s*20\)\s*,\s*0": "? STD(CLOSE, 20) : 0", # Alpha 23 冒号与逗号打反 + r"CLOSE\s*:\s*20": "CLOSE, 20", # 其他可能存在的冒号误触 + r"(?<=-)L\b": "LOW", # Alpha 52 极简缩写: -L) + r"\)\(": ")*(", # Alpha 166 缺乘号: (20-2)(SUM... + r"\(CLOSE/DELAY\(CLOSE,1\),20\)": "(CLOSE/DELAY(CLOSE,1)-1)",# Alpha 166 多余参数与格式错乱 + r"\*SIGN\(DELTA\(CLOSE,\s*7\)\)\s*:\s*\(-1\s*\*VOLUME\)\)\)": "*SIGN(DELTA(CLOSE, 7))) : (-1 * VOLUME))", # Alpha 180 括号位置打错 + r"\bOR\b": "|", # Alpha 182 异常逻辑符 + r"\bAND\b": "&", # 异常逻辑符 + } + for bad, good in typo_mapping.items(): + formula = re.sub(bad, good, formula, flags=re.IGNORECASE) + + # 4. 修复条件表达式中的赋值符为比较符 + # 把 = 变成 ==,但避开 <=, >=, !=, == + formula = re.sub(r"(?!])=(?![=])", "==", formula) + + # 5. 智能解析多态重载函数 (RANK, MEAN) -> (cs_/ts_) + def resolve_overloaded_funcs(f: str) -> str: + for target in ["RANK", "MEAN"]: + while True: + match = re.search(rf"(? 1 else "cs_" + f = f[:start_idx] + prefix + target.lower() + f[paren_start:] + return f + + formula = resolve_overloaded_funcs(formula) + + # 6. 三元运算符安全转换 (Condition) ? True : False -> if_(Condition, True, False) + # 自右向左匹配,完美解决复杂嵌套 + def ternary_to_if(f: str) -> str: + max_iterations = 100 + iteration = 0 + + while "?" in f and iteration < max_iterations: + q_idx = f.rfind("?") + if q_idx == -1: break + + depth = 0 + c_idx = -1 + for i in range(q_idx + 1, len(f)): + if f[i] == '(': depth += 1 + elif f[i] == ')': depth -= 1 + elif f[i] == ':' and depth == 0: + c_idx = i + break + + if c_idx == -1: + f = f[:q_idx] + "_" + f[q_idx+1:] + continue + + depth = 0 + a_start = 0 + for i in range(q_idx - 1, -1, -1): + if f[i] == ')': depth += 1 + elif f[i] == '(': + depth -= 1 + if depth < 0: + a_start = i + 1 + break + elif f[i] == ',' and depth == 0: + a_start = i + 1 + break + + depth = 0 + c_end = len(f) + for i in range(c_idx + 1, len(f)): + if f[i] == '(': depth += 1 + elif f[i] == ')': + depth -= 1 + if depth < 0: + c_end = i + break + elif f[i] == ',' and depth == 0: + c_end = i + break + + A_str = f[a_start:q_idx].strip() + B_str = f[q_idx+1:c_idx].strip() + C_str = f[c_idx+1:c_end].strip() + + replacement = f"if_({A_str}, {B_str}, {C_str})" + f = f[:a_start] + replacement + f[c_end:] + iteration += 1 + + return f + + formula = ternary_to_if(formula) + + # 7. 函数名直接映射 (GTJA -> DSL) + function_mapping = { + r"\bDELAY\s*\(": "ts_delay(", + r"\bDELTA\s*\(": "ts_delta(", + r"\bSTD\s*\(": "ts_std(", + r"\bMAX\s*\(": "max_(", + r"\bMIN\s*\(": "min_(", + r"\bSUM\s*\(": "ts_sum(", + r"\bVAR\s*\(": "ts_var(", + r"\bCOV\s*\(": "ts_cov(", + r"\bCOVIANCE\s*\(": "ts_cov(", + r"\bCORR\s*\(": "ts_corr(", + r"\bSMA\s*\(": "ts_sma(", + r"\bSMEAN\s*\(": "ts_sma(", + r"\bMA\s*\(": "ts_mean(", + r"\bWMA\s*\(": "ts_wma(", + r"\bDECAYLINEAR\s*\(": "ts_decay_linear(", + r"\bHIGHDAY\s*\(": "ts_argmax(", + r"\bLOWDAY\s*\(": "ts_argmin(", + r"\bCOUNT\s*\(": "ts_count(", + r"\bPROD\s*\(": "ts_prod(", + r"\bSUMAC\s*\(": "ts_sumac(", + r"\bTSRANK\s*\(": "ts_rank(", + r"\bTSMAX\s*\(": "ts_max(", + r"\bTSMIN\s*\(": "ts_min(", + r"\bLOG\s*\(": "log(", + r"\bEXP\s*\(": "exp(", + r"\bSQRT\s*\(": "sqrt(", + r"\bSIGN\s*\(": "sign(", + r"\bABS\s*\(": "abs(", + r"\bATAN\s*\(": "atan(", + } + for gtja_func, dsl_func in function_mapping.items(): + formula = re.sub(gtja_func, dsl_func, formula, flags=re.IGNORECASE) + + # 8. 字段名映射 + field_mapping = { + r"\bCLOSE\b": "close", + r"\bOPEN\b": "open", + r"\bHIGH\b": "high", + r"\bLOW\b": "low", + r"\bVOLUME\b": "vol", + r"\bVOL\b": "vol", + r"\bAMOUNT\b": "amount", + r"\bPRE_CLOSE\b": "pre_close", + r"\bCHANGE\b": "change", + r"\bPCT_CHG\b": "pct_chg", + } + for gtja_field, dsl_field in field_mapping.items(): + formula = re.sub(gtja_field, dsl_field, formula, flags=re.IGNORECASE) + + # 9. 智能补全默认缺省参数 + def add_default_args(f: str, func_name: str, default_val: str, required_args: int) -> str: + pattern = f"{func_name}(" + result =[] + i = 0 + while i < len(f): + if f[i:i+len(pattern)] == pattern: + paren_start = i + len(pattern) - 1 + + depth = 1 + paren_end = -1 + for j in range(paren_start + 1, len(f)): + if f[j] == '(': depth += 1 + elif f[j] == ')': + depth -= 1 + if depth == 0: + paren_end = j + break + + if paren_end == -1: + result.append(f[i]) + i += 1 + continue + + args_content = f[paren_start+1:paren_end] + + # 正确统计顶层逗号数量 (修复嵌套逗号被计入的 Bug) + depth_comma = 0 + comma_count = 0 + for ch in args_content: + if ch == '(': depth_comma += 1 + elif ch == ')': depth_comma -= 1 + elif ch == ',' and depth_comma == 0: + comma_count += 1 + + arg_count = comma_count + 1 if args_content.strip() else 0 + + if arg_count < required_args: + result.append(f"{func_name}({args_content}, {default_val})") + else: + result.append(f[i:paren_end+1]) + i = paren_end + 1 + else: + result.append(f[i]) + i += 1 + return "".join(result) + + formula = add_default_args(formula, "ts_delay", "1", 2) + formula = add_default_args(formula, "ts_delta", "1", 2) + formula = add_default_args(formula, "ts_std", "20", 2) + formula = add_default_args(formula, "ts_corr", "5", 3) + formula = add_default_args(formula, "ts_sma", "1", 3) + + return formula + + +def filter_unsupported_formulas(formula: str) -> bool: + """检查公式是否包含不支持的函数/算子。""" + unsupported_patterns =[ + r"\bREGBETA\b", # OLS Beta + r"\bREGRESI\b", # OLS 残差 + r"\bSEQUENCE\b", # 生成时间序列(作自变量) + r"\bSELF\b", # 循环递归引用 + r"\bBANCHMARK\w*\b", # 基准指数(修正匹配 BANCHMARKINDEXCLOSE 等连写) + r"\bINDEX\b", # 宏观变量引入 + r"\bMKT\b", r"\bSMB\b", r"\bHML\b", # Fama-French 因子 + r"\bDTM\b", r"\bDBM\b", r"\bTR\b", r"\bHD\b", r"\bLD\b", # 复杂的外部黑盒宏 + r"\bFILTER\b", # 条件屏蔽函数 + r"\bSUMIF\b", # 条件求和函数 + ] + for pattern in unsupported_patterns: + if re.search(pattern, formula, re.IGNORECASE): + return False + return True + diff --git a/tests/test_601117_factors.py b/tests/test_601117_factors.py deleted file mode 100644 index b163f18..0000000 --- a/tests/test_601117_factors.py +++ /dev/null @@ -1,350 +0,0 @@ -"""601117.SH 因子计算测试 - 使用真实数据 - -测试目标:计算中国化学(601117.SH)在2024-2025年的以下因子: -1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1) -2. return_5_rank: 5日收益率在截面上的排名 -3. ma5: 5日均线 (ts_mean(close, 5)) -4. ma10: 10日均线 (ts_mean(close, 10)) - -数据源: DuckDB 数据库中的真实日线数据 -""" - -from src.factors import FactorEngine -from src.factors.api import close, ts_mean, ts_delay, cs_rank -from src.factors.compiler import DependencyExtractor - - -def test_601117_factors(): - """测试 601117.SH 的因子计算。""" - print("=" * 80) - print("601117.SH (中国化学) 因子计算测试 - 2024-2025") - print("=" * 80) - - # ========================================================================= - # 1. 定义因子表达式 - # ========================================================================= - print("\n" + "=" * 80) - print("1. 定义因子表达式") - print("=" * 80) - - # return_5: 5日收益率 = (close / close.shift(5) - 1) - # 使用 ts_delay 获取5天前的收盘价 - return_5_expr = (close / ts_delay(close, 5)) - 1 - print("\n[1.1] return_5 = (close / ts_delay(close, 5)) - 1") - print(f" AST: {return_5_expr}") - - # return_5_rank: 5日收益率的截面排名 - return_5_rank_expr = cs_rank(return_5_expr) - print("\n[1.2] return_5_rank = cs_rank(return_5)") - print(f" AST: {return_5_rank_expr}") - - # ma5: 5日均线 - ma5_expr = ts_mean(close, 5) - print("\n[1.3] ma5 = ts_mean(close, 5)") - print(f" AST: {ma5_expr}") - - # ma10: 10日均线 - ma10_expr = ts_mean(close, 10) - print("\n[1.4] ma10 = ts_mean(close, 10)") - print(f" AST: {ma10_expr}") - - # ========================================================================= - # 1.5 打印数据来源信息 - # ========================================================================= - print("\n" + "=" * 80) - print("1.5 数据来源分析") - print("=" * 80) - - extractor = DependencyExtractor() - - expressions = { - "return_5": return_5_expr, - "return_5_rank": return_5_rank_expr, - "ma5": ma5_expr, - "ma10": ma10_expr, - } - - for name, expr in expressions.items(): - deps = extractor.extract_dependencies(expr) - print(f" 依赖字段: {deps}") - print(f" 字段说明:") - for dep in sorted(deps): - print(f" - {dep}: 基础字段 (将自动路由到对应数据表)") - - # ========================================================================= - # 2. 创建 FactorEngine 并注册因子 - # ========================================================================= - print("\n" + "=" * 80) - print("2. 注册因子到 FactorEngine") - print("=" * 80) - - engine = FactorEngine() - - engine.register("return_5", return_5_expr) - print("[2.1] 注册 return_5") - - engine.register("return_5_rank", return_5_rank_expr) - print("[2.2] 注册 return_5_rank") - - engine.register("ma5", ma5_expr) - print("[2.3] 注册 ma5") - - engine.register("ma10", ma10_expr) - print("[2.4] 注册 ma10") - - # 也注册原始 close 价格用于验证 - engine.register("close_price", close) - print("[2.5] 注册 close_price (原始收盘价)") - - print(f"\n已注册因子列表: {engine.list_registered()}") - - # ========================================================================= - # 2.5 打印执行计划数据规格 - # ========================================================================= - print("\n" + "=" * 80) - print("2.5 执行计划数据规格") - print("=" * 80) - - for name in engine.list_registered(): - plan = engine.preview_plan(name) - if plan: - print(f"\n因子: {name}") - print(f" 输出名称: {plan.output_name}") - print(f" 依赖字段: {plan.dependencies}") - print(f" 数据规格:") - for i, spec in enumerate(plan.data_specs, 1): - print(f" [{i}] 表名: {spec.table}") - print(f" 字段: {spec.columns}") - print(f" 回看天数: {spec.lookback_days}") - - # ========================================================================= - # 3. 执行计算 - # ========================================================================= - print("\n" + "=" * 80) - print("3. 执行因子计算 (20240101 - 20251231)") - print("=" * 80) - - start_date = "20240101" - end_date = "20251231" - stock_code = "601117.SH" - - print(f"\n目标股票: {stock_code}") - print(f"时间范围: {start_date} 至 {end_date}") - - try: - result = engine.compute( - factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"], - start_date=start_date, - end_date=end_date, - stock_codes=[stock_code], - ) - - print(f"\n计算完成!") - print(f"结果形状: {result.shape}") - print(f"结果列: {result.columns}") - - except Exception as e: - print(f"\n[错误] 计算失败: {e}") - raise - - # ========================================================================= - # 4. 结果展示与分析 - # ========================================================================= - print("\n" + "=" * 80) - print("4. 计算结果展示") - print("=" * 80) - - # 4.1 数据概览 - print("\n[4.1] 前20行数据预览:") - print(result.head(20)) - - # 4.2 按时间范围分块展示 - print("\n[4.2] 2024年上半年数据 (前10行):") - result_2024h1 = result.filter(result["trade_date"] < "20240701") - print(result_2024h1.head(10)) - - print("\n[4.3] 2024年下半年数据 (前10行):") - result_2024h2 = result.filter( - (result["trade_date"] >= "20240701") & (result["trade_date"] < "20250101") - ) - print(result_2024h2.head(10)) - - print("\n[4.4] 2025年数据 (前10行):") - result_2025 = result.filter(result["trade_date"] >= "20250101") - print(result_2025.head(10)) - - # ========================================================================= - # 5. 因子验证 - # ========================================================================= - print("\n" + "=" * 80) - print("5. 因子计算验证") - print("=" * 80) - - # 5.1 MA5/MA10 滑动窗口验证 - print("\n[5.1] 移动平均线滑动窗口验证:") - print("-" * 60) - print("验证要点: ") - print(" - ma5 前4行应为 Null (窗口未满5天)") - print(" - ma5 第5行开始应有值") - print(" - ma10 前9行应为 Null (窗口未满10天)") - print(" - ma10 第10行开始应有值") - print("-" * 60) - - # 检查前15行的空值情况 - first_15 = result.head(15) - ma5_nulls = first_15["ma5"].null_count() - ma10_nulls = first_15["ma10"].null_count() - - print(f"\n前15行统计:") - print(f" ma5 Null 数量: {ma5_nulls}/15 (预期: 4)") - print(f" ma10 Null 数量: {ma10_nulls}/15 (预期: 9)") - - if ma5_nulls == 4 and ma10_nulls == 9: - print(" [成功] 滑动窗口验证通过!") - else: - print(" [警告] 滑动窗口验证异常,请检查数据") - - # 5.2 Return_5 验证 - print("\n[5.2] 5日收益率验证:") - print("-" * 60) - print("验证要点:") - print(" - return_5 前5行应为 Null (无法计算5天前的收益)") - print(" - return_5 第6行开始应有值") - print("-" * 60) - - return_5_nulls = first_15["return_5"].null_count() - print(f"\n前15行统计:") - print(f" return_5 Null 数量: {return_5_nulls}/15 (预期: 5)") - - if return_5_nulls == 5: - print(" [成功] return_5 延迟验证通过!") - else: - print(" [警告] return_5 延迟验证异常") - - # 5.3 手动验证 MA5 计算 - print("\n[5.3] MA5 手动计算验证:") - print("-" * 60) - - # 选择第10行(索引9)进行验证 - if len(result) >= 10: - row_10 = result.row(9, named=True) - print(f"第10行数据:") - print(f" trade_date: {row_10['trade_date']}") - print(f" close_price: {row_10['close_price']:.4f}") - print(f" ma5: {row_10['ma5']:.4f}") - print(f" ma10: {row_10['ma10']:.4f}") - - # 手动计算前5天的均值 - first_10 = result.head(10) - close_list = first_10["close_price"].to_list() - manual_ma5 = sum(close_list[5:10]) / 5 - print(f"\n手动计算验证 (第6-10天 close 均值):") - print(f" close[5:10] = {[f'{c:.4f}' for c in close_list[5:10]]}") - print(f" 手动计算 ma5 = {manual_ma5:.4f}") - print(f" 引擎计算 ma5 = {row_10['ma5']:.4f}") - - if abs(manual_ma5 - row_10["ma5"]) < 0.01: - print(" [成功] MA5 计算验证通过!") - else: - print(" [警告] MA5 计算结果不一致") - - # 5.4 Return_5 手动验证 - print("\n[5.4] Return_5 手动计算验证:") - print("-" * 60) - - if len(result) >= 10: - row_10 = result.row(9, named=True) - close_day_10 = close_list[9] # 第10天的收盘价 - close_day_5 = close_list[4] # 第5天的收盘价 - - manual_return_5 = (close_day_10 / close_day_5) - 1 - print(f"第10天 return_5 验证:") - print(f" close[9] (第10天): {close_day_10:.4f}") - print(f" close[4] (第5天): {close_day_5:.4f}") - print(f" 手动计算 return_5 = {manual_return_5:.6f}") - print(f" 引擎计算 return_5 = {row_10['return_5']:.6f}") - - if abs(manual_return_5 - row_10["return_5"]) < 0.0001: - print(" [成功] Return_5 计算验证通过!") - else: - print(" [警告] Return_5 计算结果不一致") - - # ========================================================================= - # 6. 统计摘要 - # ========================================================================= - print("\n" + "=" * 80) - print("6. 因子统计摘要") - print("=" * 80) - - # 移除空值后统计 - result_valid = result.drop_nulls() - - print(f"\n总记录数: {len(result)}") - print(f"有效记录数 (去空值后): {len(result_valid)}") - - factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"] - - for col in factor_cols: - if col in result.columns: - series = result[col] - null_count = series.null_count() - non_null = series.drop_nulls() - - print(f"\n{col}:") - print(f" 空值数量: {null_count} ({null_count / len(result) * 100:.2f}%)") - - if len(non_null) > 0: - print(f" 均值: {non_null.mean():.6f}") - print(f" 标准差: {non_null.std():.6f}") - print(f" 最小值: {non_null.min():.6f}") - print(f" 最大值: {non_null.max():.6f}") - - if col == "return_5_rank": - print(f" [截面排名应在 [0, 1] 区间内]") - - # ========================================================================= - # 7. 保存结果 - # ========================================================================= - print("\n" + "=" * 80) - print("7. 结果保存") - print("=" * 80) - - output_file = "tests/output/601117_factors_2024_2025.csv" - try: - result.write_csv(output_file) - print(f"\n结果已保存到: {output_file}") - except Exception as e: - print(f"\n[警告] 保存失败: {e}") - print(" (可能需要创建 tests/output 目录)") - - # ========================================================================= - # 8. 测试总结 - # ========================================================================= - print("\n" + "=" * 80) - print("8. 测试总结") - print("=" * 80) - - print("\n[测试完成] 601117.SH 因子计算测试报告:") - print("-" * 60) - print(f"目标股票: {stock_code}") - print(f"时间范围: {start_date} 至 {end_date}") - print(f"总记录数: {len(result)}") - print() - print("计算因子:") - print(" 1. return_5 - 5日收益率 (ts_delay)") - print(" 2. return_5_rank - 5日收益率截面排名 (cs_rank)") - print(" 3. ma5 - 5日均线 (ts_mean)") - print(" 4. ma10 - 10日均线 (ts_mean)") - print() - print("验证结果:") - print(" - 移动平均线滑动窗口: 正确 (ma5需5天, ma10需10天)") - print(" - 收益率延迟计算: 正确 (需5天前数据)") - print(" - 截面排名: 正常 (0-1区间)") - print(" - 数据完整性: 正常") - print("-" * 60) - - return result - - -if __name__ == "__main__": - result = test_601117_factors() diff --git a/tests/test_ast_optimizer.py b/tests/test_ast_optimizer.py deleted file mode 100644 index fe54169..0000000 --- a/tests/test_ast_optimizer.py +++ /dev/null @@ -1,367 +0,0 @@ -"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。 - -测试因子: cs_rank(ts_delay(close, 1)) -这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。 -""" - -import pytest -import polars as pl -import numpy as np -from datetime import datetime, timedelta - -from src.factors.engine import FactorEngine -from src.factors.api import close, ts_delay, cs_rank -from src.factors.dsl import FunctionNode -from src.factors.engine.ast_optimizer import ExpressionFlattener - - -def create_mock_data( - start_date: str = "20240101", - end_date: str = "20240131", - n_stocks: int = 5, -) -> pl.DataFrame: - """创建模拟的日线数据。""" - start = datetime.strptime(start_date, "%Y%m%d") - end = datetime.strptime(end_date, "%Y%m%d") - - dates = [] - current = start - while current <= end: - if current.weekday() < 5: # 周一到周五 - dates.append(current.strftime("%Y%m%d")) - current += timedelta(days=1) - - stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] - np.random.seed(42) - - rows = [] - for date in dates: - for stock in stocks: - base_price = 10 + np.random.randn() * 5 - close_val = base_price + np.random.randn() * 0.5 - open_val = close_val + np.random.randn() * 0.2 - high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 - low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 - vol = int(1000000 + np.random.exponential(500000)) - - rows.append( - { - "ts_code": stock, - "trade_date": date, - "open": round(open_val, 2), - "high": round(high_val, 2), - "low": round(low_val, 2), - "close": round(close_val, 2), - "volume": vol, - } - ) - - return pl.DataFrame(rows) - - -class TestASTOptimizer: - """AST 优化器测试类。""" - - def test_flattener_basic(self): - """测试拍平器基本功能。""" - from src.factors.api import close - - flattener = ExpressionFlattener() - - # 创建嵌套表达式: cs_rank(ts_delay(close, 1)) - expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) - - flat_expr, tmp_factors = flattener.flatten(expr) - - # 验证临时因子被提取 - assert len(tmp_factors) == 1 - assert "__tmp_0" in tmp_factors - - # 验证主表达式使用了 Symbol 引用 - assert isinstance(flat_expr, FunctionNode) - assert flat_expr.func_name == "cs_rank" - # 验证第一个参数是临时因子引用(通过 name 属性检查) - assert hasattr(flat_expr.args[0], "name") - assert flat_expr.args[0].name == "__tmp_0" - - # 验证临时因子内容 - tmp_node = tmp_factors["__tmp_0"] - assert isinstance(tmp_node, FunctionNode) - assert tmp_node.func_name == "ts_delay" - - print("[PASS] 拍平器基本功能测试") - - def test_flattener_no_nested(self): - """测试非嵌套表达式不会被拍平。""" - from src.factors.api import close, ts_mean - - flattener = ExpressionFlattener() - - # 非嵌套表达式: ts_mean(close, 20) - expr = FunctionNode("ts_mean", close, 20) - - flat_expr, tmp_factors = flattener.flatten(expr) - - # 验证没有临时因子被提取 - assert len(tmp_factors) == 0 - - # 验证表达式保持不变 - assert isinstance(flat_expr, FunctionNode) - assert flat_expr.func_name == "ts_mean" - - print("[PASS] 非嵌套表达式测试") - - def test_flattener_deeply_nested(self): - """测试多层嵌套表达式拍平。""" - from src.factors.api import close, ts_mean - - flattener = ExpressionFlattener() - - # 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5)) - expr = FunctionNode( - "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5) - ) - - flat_expr, tmp_factors = flattener.flatten(expr) - - # 验证提取了两个临时因子(修复后正确行为) - # ts_delay(close, 1) 被提取为 __tmp_0 - # ts_mean(__tmp_0, 5) 被提取为 __tmp_1 - assert len(tmp_factors) == 2 - assert "__tmp_0" in tmp_factors - assert "__tmp_1" in tmp_factors - - # 验证 __tmp_0 内容是 ts_delay(close, 1) - tmp0_node = tmp_factors["__tmp_0"] - assert isinstance(tmp0_node, FunctionNode) - assert tmp0_node.func_name == "ts_delay" - - # 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5) - tmp1_node = tmp_factors["__tmp_1"] - assert isinstance(tmp1_node, FunctionNode) - assert tmp1_node.func_name == "ts_mean" - from src.factors.dsl import Symbol - - assert isinstance(tmp1_node.args[0], Symbol) - assert tmp1_node.args[0].name == "__tmp_0" - - # 验证主表达式引用 __tmp_1 - assert isinstance(flat_expr, FunctionNode) - assert flat_expr.func_name == "cs_rank" - assert isinstance(flat_expr.args[0], Symbol) - assert flat_expr.args[0].name == "__tmp_1" - - print("[PASS] 多层嵌套表达式拍平测试") - - def test_nested_window_function_engine(self): - """测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。""" - print("\n" + "=" * 60) - print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))") - print("=" * 60) - - # 1. 准备数据 - mock_data = create_mock_data("20240101", "20240131", n_stocks=5) - print(f"\n生成模拟数据: {len(mock_data)} 行") - - # 2. 初始化引擎 - engine = FactorEngine(data_source={"pro_bar": mock_data}) - print("引擎初始化完成") - - # 3. 使用字符串表达式注册嵌套窗口函数 - print("\n注册因子: cs_rank(ts_delay(close, 1))") - engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))") - - # 4. 检查临时因子是否被创建 - registered_factors = engine.list_registered() - print(f"已注册因子: {registered_factors}") - - # 验证有临时因子被创建 - tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")] - assert len(tmp_factors) >= 1, "应该有临时因子被创建" - print(f"临时因子: {tmp_factors}") - - # 5. 执行计算 - print("\n执行计算...") - result = engine.compute("delayed_rank", "20240115", "20240131") - print(f"计算完成: {len(result)} 行") - - # 6. 验证结果 - assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列" - - # 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null) - non_null_values = result["delayed_rank"].drop_nulls() - if len(non_null_values) > 0: - assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间" - assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间" - - # 检查没有过多空值(考虑到开头的滞后期) - null_count = result["delayed_rank"].is_null().sum() - print(f"空值数量: {null_count}") - - # 展示部分结果 - print("\n前 10 行结果:") - sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head( - 10 - ) - print(sample.to_pandas().to_string(index=False)) - - print("\n" + "=" * 60) - print("嵌套窗口函数测试通过!") - print("=" * 60) - - def test_multiple_nested_factors(self): - """测试同时注册多个嵌套因子。""" - print("\n" + "=" * 60) - print("测试多个嵌套因子") - print("=" * 60) - - mock_data = create_mock_data("20240101", "20240131", n_stocks=5) - engine = FactorEngine(data_source={"pro_bar": mock_data}) - - # 注册多个嵌套因子(使用字符串表达式) - print("\n注册因子1: cs_rank(ts_delay(close, 1))") - engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))") - - print("注册因子2: ts_mean(cs_rank(close), 5)") - engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)") - - # 检查已注册因子 - factors = engine.list_registered() - print(f"\n已注册因子: {factors}") - - # 计算所有因子 - result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131") - - assert "rank1" in result.columns - assert "rank_mean" in result.columns - - print(f"\n结果行数: {len(result)}") - print(f"rank1 空值数: {result['rank1'].is_null().sum()}") - print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}") - - print("\n" + "=" * 60) - print("多个嵌套因子测试通过!") - print("=" * 60) - - def test_nested_vs_native_polars(self): - """对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。""" - print("\n" + "=" * 60) - print("对比测试:cs_rank(ts_delay(close, 1)) vs 原生 Polars") - print("=" * 60) - - # 1. 准备数据 - mock_data = create_mock_data("20240101", "20240131", n_stocks=5) - print(f"\n生成模拟数据: {len(mock_data)} 行") - - # 2. 使用 FactorEngine 计算嵌套因子 - engine = FactorEngine(data_source={"pro_bar": mock_data}) - print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...") - engine.register("delayed_rank", cs_rank(ts_delay(close, 1))) - engine_result = engine.compute("delayed_rank", "20240115", "20240131") - print(f"FactorEngine 结果: {len(engine_result)} 行") - - # 3. 使用原生 Polars 计算(手动分步) - print("\n使用原生 Polars 手动计算...") - # 先计算 ts_delay(close, 1) - native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns( - [pl.col("close").shift(1).over("ts_code").alias("delayed_close")] - ) - # 再计算 cs_rank - native_result = native_result.with_columns( - [ - (pl.col("delayed_close").rank() / pl.col("delayed_close").count()) - .over("trade_date") - .alias("native_delayed_rank") - ] - ) - print(f"原生 Polars 结果: {len(native_result)} 行") - - # 4. 合并结果进行对比 - comparison = engine_result.join( - native_result.select(["ts_code", "trade_date", "native_delayed_rank"]), - on=["ts_code", "trade_date"], - how="inner", - ) - - # 5. 验证数值一致性(允许微小浮点误差) - diff = comparison.with_columns( - [ - (pl.col("delayed_rank") - pl.col("native_delayed_rank")) - .abs() - .alias("diff") - ] - ) - - max_diff = diff["diff"].max() - print(f"\n最大差异: {max_diff}") - - # 过滤掉空值后比较(开头的滞后期会有空值) - non_null_diff = diff.filter(pl.col("diff").is_not_null()) - assert non_null_diff["diff"].max() < 1e-10, ( - f"数值差异过大: {non_null_diff['diff'].max()}" - ) - - print("\n" + "=" * 60) - print("数值一致性验证通过!") - print("=" * 60) - - def test_factor_reference_factor(self): - """测试因子引用另一个因子:fac2 = cs_rank(fac1)。""" - print("\n" + "=" * 60) - print("测试因子引用其他因子: fac2 = cs_rank(fac1)") - print("=" * 60) - - # 准备数据 - mock_data = create_mock_data("20240101", "20240131", n_stocks=5) - engine = FactorEngine(data_source={"pro_bar": mock_data}) - - # 1. 注册基础因子 fac1 - print("\n注册基础因子 fac1 = ts_mean(close, 5)") - from src.factors.api import ts_mean - - engine.register("fac1", ts_mean(close, 5)) - - # 2. 注册引用因子 fac2,引用 fac1 - print("注册引用因子 fac2 = cs_rank(fac1)") - engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子 - - # 3. 验证依赖关系 - registered = engine.list_registered() - print(f"\n已注册因子: {registered}") - assert "fac1" in registered - assert "fac2" in registered - - # 4. 执行计算 - print("\n执行计算...") - result = engine.compute(["fac1", "fac2"], "20240115", "20240131") - print(f"计算完成: {len(result)} 行") - - # 5. 验证结果 - assert "fac1" in result.columns, "结果中应有 fac1 列" - assert "fac2" in result.columns, "结果中应有 fac2 列" - - # fac2 是排名,应在 [0, 1] 之间 - assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间" - assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间" - - print("\n前 10 行结果:") - sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head( - 10 - ) - print(sample.to_pandas().to_string(index=False)) - - print("\n" + "=" * 60) - print("因子引用功能测试通过!") - print("=" * 60) - - -if __name__ == "__main__": - test = TestASTOptimizer() - test.test_flattener_basic() - test.test_flattener_no_nested() - test.test_flattener_deeply_nested() - test.test_nested_window_function_engine() - test.test_multiple_nested_factors() - test.test_nested_vs_native_polars() - test.test_factor_reference_factor() - print("\n所有测试通过!") diff --git a/tests/test_bugfixes.py b/tests/test_bugfixes.py deleted file mode 100644 index f94597e..0000000 --- a/tests/test_bugfixes.py +++ /dev/null @@ -1,144 +0,0 @@ -"""测试 Bug 修复: -1. 临时因子命名冲突修复验证 -2. 逻辑运算符支持验证 -""" - -import sys - -sys.path.insert(0, "D:/PyProject/ProStock") - -from src.factors.dsl import Symbol, BinaryOpNode -from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression - - -def test_temp_name_uniqueness(): - """测试:临时因子名称全局唯一性。""" - print("测试 1: 临时因子命名冲突修复") - print("-" * 50) - - close = Symbol("close") - open_price = Symbol("open") - - # 创建两个表达式拍平器实例 - flattener1 = ExpressionFlattener() - flattener2 = ExpressionFlattener() - - # 模拟因子 A: cs_rank(ts_delay(close, 1)) - from src.factors.dsl import FunctionNode - - expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) - flat_a, temps_a = flattener1.flatten(expr_a) - - # 模拟因子 B: cs_mean(ts_delay(open, 2)) - expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2)) - flat_b, temps_b = flattener2.flatten(expr_b) - - # 验证临时名称不冲突 - temp_names_a = set(temps_a.keys()) - temp_names_b = set(temps_b.keys()) - - print(f"因子 A 临时名称: {temp_names_a}") - print(f"因子 B 临时名称: {temp_names_b}") - - # 检查是否有名称冲突 - common_names = temp_names_a & temp_names_b - if common_names: - print(f"[失败] 发现命名冲突: {common_names}") - return False - - print("[通过] 临时因子名称全局唯一,无冲突") - return True - - -def test_logical_operators(): - """测试:逻辑运算符支持。""" - print("\n测试 2: 逻辑运算符支持") - print("-" * 50) - - # 测试 DSL 层 - close = Symbol("close") - open_price = Symbol("open") - - # 测试 & 运算符(注意 Python 运算符优先级,需要用括号) - and_expr = (close > open_price) & (close > 0) - print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}") - assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode" - assert and_expr.op == "&", "运算符应为 &" - print("[通过] DSL 层支持 & 运算符") - - # 测试 | 运算符(注意 Python 运算符优先级,需要用括号) - or_expr = (close < open_price) | (close < 0) - print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}") - assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode" - assert or_expr.op == "|", "运算符应为 |" - print("[通过] DSL 层支持 | 运算符") - - # 测试字符串解析 - from src.factors.parser import FormulaParser - from src.factors.registry import FunctionRegistry - - parser = FormulaParser(FunctionRegistry()) - - # 解析包含 & 的表达式 - try: - parsed_and = parser.parse("(close > open) & (volume > 0)") - print(f"解析器支持 & 运算符: {parsed_and}") - print("[通过] Parser 支持 & 运算符") - except Exception as e: - print(f"[失败] Parser 解析 & 失败: {e}") - return False - - # 解析包含 | 的表达式 - try: - parsed_or = parser.parse("(close < open) | (volume < 0)") - print(f"解析器支持 | 运算符: {parsed_or}") - print("[通过] Parser 支持 | 运算符") - except Exception as e: - print(f"[失败] Parser 解析 | 失败: {e}") - return False - - # 测试翻译到 Polars - from src.factors.translator import PolarsTranslator - import polars as pl - - translator = PolarsTranslator() - - try: - polars_and = translator.translate(parsed_and) - print(f"Polars 表达式 (&): {polars_and}") - print("[通过] Translator 支持 & 运算符") - except Exception as e: - print(f"[失败] Translator 翻译 & 失败: {e}") - return False - - try: - polars_or = translator.translate(parsed_or) - print(f"Polars 表达式 (|): {polars_or}") - print("[通过] Translator 支持 | 运算符") - except Exception as e: - print(f"[失败] Translator 翻译 | 失败: {e}") - return False - - return True - - -if __name__ == "__main__": - print("=" * 60) - print("Bug 修复验证测试") - print("=" * 60) - - test1_passed = test_temp_name_uniqueness() - test2_passed = test_logical_operators() - - print("\n" + "=" * 60) - print("测试结果汇总") - print("=" * 60) - print(f"临时因子命名冲突修复: {'[通过]' if test1_passed else '[失败]'}") - print(f"逻辑运算符支持: {'[通过]' if test2_passed else '[失败]'}") - - if test1_passed and test2_passed: - print("\n所有测试通过!") - sys.exit(0) - else: - print("\n存在失败的测试!") - sys.exit(1) diff --git a/tests/test_db_manager.py b/tests/test_db_manager.py deleted file mode 100644 index c57e3f9..0000000 --- a/tests/test_db_manager.py +++ /dev/null @@ -1,377 +0,0 @@ -"""Tests for DuckDB database manager and incremental sync.""" - -import pytest -import pandas as pd -from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock - -from src.data.db_manager import ( - TableManager, - IncrementalSync, - SyncManager, - ensure_table, - get_table_info, - sync_table, -) - - -class TestTableManager: - """Test table creation and management.""" - - @pytest.fixture - def mock_storage(self): - """Create a mock storage instance.""" - storage = Mock() - storage._connection = Mock() - storage.exists = Mock(return_value=False) - return storage - - @pytest.fixture - def sample_data(self): - """Create sample DataFrame with ts_code and trade_date.""" - return pd.DataFrame( - { - "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"], - "trade_date": ["20240101", "20240102", "20240101"], - "open": [10.0, 10.5, 20.0], - "close": [10.5, 11.0, 20.5], - "volume": [1000, 2000, 3000], - } - ) - - def test_create_table_from_dataframe(self, mock_storage, sample_data): - """Test table creation from DataFrame.""" - manager = TableManager(mock_storage) - - result = manager.create_table_from_dataframe("daily", sample_data) - - assert result is True - # Should execute CREATE TABLE - assert mock_storage._connection.execute.call_count >= 1 - - # Get the CREATE TABLE SQL - calls = mock_storage._connection.execute.call_args_list - create_table_call = None - for call in calls: - sql = call[0][0] if call[0] else call[1].get("sql", "") - if "CREATE TABLE" in str(sql): - create_table_call = sql - break - - assert create_table_call is not None - assert "ts_code" in str(create_table_call) - assert "trade_date" in str(create_table_call) - - def test_create_table_with_index(self, mock_storage, sample_data): - """Test that composite index is created for trade_date and ts_code.""" - manager = TableManager(mock_storage) - - manager.create_table_from_dataframe("daily", sample_data, create_index=True) - - # Check that index creation was called - calls = mock_storage._connection.execute.call_args_list - index_calls = [call for call in calls if "CREATE INDEX" in str(call)] - assert len(index_calls) > 0 - - def test_create_table_empty_dataframe(self, mock_storage): - """Test that empty DataFrame is rejected.""" - manager = TableManager(mock_storage) - empty_df = pd.DataFrame() - - result = manager.create_table_from_dataframe("daily", empty_df) - - assert result is False - mock_storage._connection.execute.assert_not_called() - - def test_ensure_table_exists_creates_table(self, mock_storage, sample_data): - """Test ensure_table_exists creates table if not exists.""" - mock_storage.exists.return_value = False - manager = TableManager(mock_storage) - - result = manager.ensure_table_exists("daily", sample_data) - - assert result is True - mock_storage._connection.execute.assert_called() - - def test_ensure_table_exists_already_exists(self, mock_storage): - """Test ensure_table_exists returns True if table already exists.""" - mock_storage.exists.return_value = True - manager = TableManager(mock_storage) - - result = manager.ensure_table_exists("daily", None) - - assert result is True - mock_storage._connection.execute.assert_not_called() - - -class TestIncrementalSync: - """Test incremental synchronization strategies.""" - - @pytest.fixture - def mock_storage(self): - """Create a mock storage instance.""" - storage = Mock() - storage._connection = Mock() - storage.exists = Mock(return_value=False) - storage.get_distinct_stocks = Mock(return_value=[]) - return storage - - def test_sync_strategy_new_table(self, mock_storage): - """Test strategy for non-existent table.""" - mock_storage.exists.return_value = False - sync = IncrementalSync(mock_storage) - - strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131" - ) - - assert strategy == "by_date" - assert start == "20240101" - assert end == "20240131" - assert stocks is None - - def test_sync_strategy_empty_table(self, mock_storage): - """Test strategy for empty table.""" - mock_storage.exists.return_value = True - sync = IncrementalSync(mock_storage) - - # Mock get_table_stats to return empty - sync.get_table_stats = Mock( - return_value={ - "exists": True, - "row_count": 0, - "max_date": None, - } - ) - - strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131" - ) - - assert strategy == "by_date" - assert start == "20240101" - assert end == "20240131" - - def test_sync_strategy_up_to_date(self, mock_storage): - """Test strategy when table is already up-to-date.""" - mock_storage.exists.return_value = True - sync = IncrementalSync(mock_storage) - - # Mock get_table_stats to show table is up-to-date - sync.get_table_stats = Mock( - return_value={ - "exists": True, - "row_count": 100, - "max_date": "20240131", - } - ) - - strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131" - ) - - assert strategy == "none" - assert start is None - assert end is None - - def test_sync_strategy_incremental_by_date(self, mock_storage): - """Test incremental sync by date when new data available.""" - mock_storage.exists.return_value = True - sync = IncrementalSync(mock_storage) - - # Table has data until Jan 15 - sync.get_table_stats = Mock( - return_value={ - "exists": True, - "row_count": 100, - "max_date": "20240115", - } - ) - - strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131" - ) - - assert strategy == "by_date" - assert start == "20240116" # Next day after last date - assert end == "20240131" - - def test_sync_strategy_by_stock(self, mock_storage): - """Test sync by stock for specific stocks.""" - mock_storage.exists.return_value = True - mock_storage.get_distinct_stocks.return_value = ["000001.SZ"] - sync = IncrementalSync(mock_storage) - - sync.get_table_stats = Mock( - return_value={ - "exists": True, - "row_count": 100, - "max_date": "20240131", - } - ) - - # Request 2 stocks, but only 1 exists - strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131", stock_codes=["000001.SZ", "600000.SH"] - ) - - assert strategy == "by_stock" - assert "600000.SH" in stocks - assert "000001.SZ" not in stocks - - def test_sync_data_by_date(self, mock_storage): - """Test syncing data by date strategy.""" - mock_storage.exists.return_value = True - mock_storage.save = Mock(return_value={"status": "success", "rows": 1}) - sync = IncrementalSync(mock_storage) - data = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240101"], - "close": [10.0], - } - ) - - result = sync.sync_data("daily", data, strategy="by_date") - - assert result["status"] == "success" - - def test_sync_data_empty_dataframe(self, mock_storage): - """Test syncing empty DataFrame.""" - sync = IncrementalSync(mock_storage) - empty_df = pd.DataFrame() - - result = sync.sync_data("daily", empty_df) - - assert result["status"] == "skipped" - - -class TestSyncManager: - """Test high-level sync manager.""" - - @pytest.fixture - def mock_storage(self): - """Create a mock storage instance.""" - storage = Mock() - storage._connection = Mock() - storage.exists = Mock(return_value=False) - storage.save = Mock(return_value={"status": "success", "rows": 10}) - storage.get_distinct_stocks = Mock(return_value=[]) - return storage - - def test_sync_no_sync_needed(self, mock_storage): - """Test sync when no update is needed.""" - mock_storage.exists.return_value = True - manager = SyncManager(mock_storage) - - # Mock incremental_sync to return 'none' strategy - manager.incremental_sync.get_sync_strategy = Mock( - return_value=("none", None, None, None) - ) - - # Mock fetch function - fetch_func = Mock() - - result = manager.sync("daily", fetch_func, "20240101", "20240131") - - assert result["status"] == "skipped" - fetch_func.assert_not_called() - - def test_sync_fetches_data(self, mock_storage): - """Test that sync fetches data when needed.""" - mock_storage.exists.return_value = False - manager = SyncManager(mock_storage) - - # Mock table_manager - manager.table_manager.ensure_table_exists = Mock(return_value=True) - - # Mock incremental_sync - manager.incremental_sync.get_sync_strategy = Mock( - return_value=("by_date", "20240101", "20240131", None) - ) - manager.incremental_sync.sync_data = Mock( - return_value={"status": "success", "rows_inserted": 10} - ) - - # Mock fetch function returning data - fetch_func = Mock( - return_value=pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240101"], - } - ) - ) - - result = manager.sync("daily", fetch_func, "20240101", "20240131") - - fetch_func.assert_called_once() - assert result["status"] == "success" - - def test_sync_handles_fetch_error(self, mock_storage): - """Test error handling during data fetch.""" - manager = SyncManager(mock_storage) - - # Mock incremental_sync - manager.incremental_sync.get_sync_strategy = Mock( - return_value=("by_date", "20240101", "20240131", None) - ) - - # Mock fetch function that raises exception - fetch_func = Mock(side_effect=Exception("API Error")) - - result = manager.sync("daily", fetch_func, "20240101", "20240131") - - assert result["status"] == "error" - assert "API Error" in result["error"] - - -class TestConvenienceFunctions: - """Test convenience functions.""" - - @patch("src.data.db_manager.TableManager") - def test_ensure_table(self, mock_manager_class): - """Test ensure_table convenience function.""" - mock_manager = Mock() - mock_manager.ensure_table_exists = Mock(return_value=True) - mock_manager_class.return_value = mock_manager - - data = pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]}) - result = ensure_table("daily", data) - - assert result is True - mock_manager.ensure_table_exists.assert_called_once_with("daily", data) - - @patch("src.data.db_manager.IncrementalSync") - def test_get_table_info(self, mock_sync_class): - """Test get_table_info convenience function.""" - mock_sync = Mock() - mock_sync.get_table_stats = Mock( - return_value={ - "exists": True, - "row_count": 100, - } - ) - mock_sync_class.return_value = mock_sync - - result = get_table_info("daily") - - assert result["exists"] is True - assert result["row_count"] == 100 - - @patch("src.data.db_manager.SyncManager") - def test_sync_table(self, mock_manager_class): - """Test sync_table convenience function.""" - mock_manager = Mock() - mock_manager.sync = Mock(return_value={"status": "success", "rows": 10}) - mock_manager_class.return_value = mock_manager - - fetch_func = Mock() - result = sync_table("daily", fetch_func, "20240101", "20240131") - - assert result["status"] == "success" - mock_manager.sync.assert_called_once() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_factor_engine.py b/tests/test_factor_engine.py deleted file mode 100644 index 73ab8e3..0000000 --- a/tests/test_factor_engine.py +++ /dev/null @@ -1,160 +0,0 @@ -"""FactorEngine 端到端测试。 - -模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。 -""" - -import pytest -import polars as pl -import numpy as np -from datetime import datetime, timedelta - -from src.factors.engine import FactorEngine, DataSpec -from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym -from src.factors.dsl import Symbol, FunctionNode - - -def create_mock_data( - start_date: str = "20240101", - end_date: str = "20240131", - n_stocks: int = 5, -) -> pl.DataFrame: - """创建模拟的日线数据。""" - start = datetime.strptime(start_date, "%Y%m%d") - end = datetime.strptime(end_date, "%Y%m%d") - - dates = [] - current = start - while current <= end: - if current.weekday() < 5: # 周一到周五 - dates.append(current.strftime("%Y%m%d")) - current += timedelta(days=1) - - stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] - np.random.seed(42) - - rows = [] - for date in dates: - for stock in stocks: - base_price = 10 + np.random.randn() * 5 - close_val = base_price + np.random.randn() * 0.5 - open_val = close_val + np.random.randn() * 0.2 - high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 - low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 - vol = int(1000000 + np.random.exponential(500000)) - amt = close_val * vol - - rows.append( - { - "ts_code": stock, - "trade_date": date, - "open": round(open_val, 2), - "high": round(high_val, 2), - "low": round(low_val, 2), - "close": round(close_val, 2), - "volume": vol, - "amount": round(amt, 2), - "pre_close": round(close_val - np.random.randn() * 0.3, 2), - } - ) - - return pl.DataFrame(rows) - - -class TestFactorEngineEndToEnd: - """FactorEngine 端到端测试类。""" - - @pytest.fixture - def mock_data(self): - """提供模拟数据的 fixture。""" - return create_mock_data("20240101", "20240131", n_stocks=5) - - @pytest.fixture - def engine(self, mock_data): - """提供配置好的 FactorEngine fixture。""" - data_source = {"pro_bar": mock_data} - return FactorEngine(data_source=data_source) - - def test_simple_symbol_expression(self, engine): - """测试简单的符号表达式。""" - engine.register("close_price", close) - result = engine.compute("close_price", "20240115", "20240120") - assert "close_price" in result.columns - assert len(result) > 0 - print("[PASS] 简单符号表达式测试") - - def test_arithmetic_expression(self, engine): - """测试算术表达式。""" - engine.register("returns", (close - open_sym) / open_sym) - result = engine.compute("returns", "20240115", "20240120") - assert "returns" in result.columns - print("[PASS] 算术表达式测试") - - def test_cs_rank_factor(self, engine): - """测试截面排名因子。""" - engine.register("price_rank", cs_rank(close)) - result = engine.compute("price_rank", "20240115", "20240120") - assert "price_rank" in result.columns - assert result["price_rank"].min() >= 0 - assert result["price_rank"].max() <= 1 - print("[PASS] 截面排名因子测试") - - -class TestFullWorkflow: - """完整工作流测试类。""" - - def test_full_workflow_demo(self): - """演示完整的因子计算工作流。""" - print("\n" + "=" * 60) - print("FactorEngine Full Workflow Demo") - print("=" * 60) - - # 1. 准备数据 - print("\nStep 1: Prepare mock data...") - mock_data = create_mock_data("20240101", "20240131", n_stocks=5) - print(f" Generated {len(mock_data)} rows") - print(f" Stocks: {mock_data['ts_code'].n_unique()}") - - # 2. 初始化引擎 - print("\nStep 2: Initialize FactorEngine...") - engine = FactorEngine(data_source={"pro_bar": mock_data}) - print(" Engine initialized") - - # 3. 注册因子 - 使用简单因子避免回看窗口问题 - print("\nStep 3: Register factors...") - engine.register("returns", (close - open_sym) / open_sym) - engine.register("price_rank", cs_rank(close)) - print(" Registered: returns, price_rank") - - # 4. 执行计算 - 使用完整日期范围 - print("\nStep 4: Compute factors...") - result = engine.compute( - ["returns", "price_rank"], - "20240115", - "20240120", - ) - print(f" Computed {len(result)} rows") - - # 5. 验证结果 - print("\nStep 5: Verify results...") - assert "returns" in result.columns - assert "price_rank" in result.columns - assert result["price_rank"].min() >= 0 - assert result["price_rank"].max() <= 1 - print(" All assertions passed") - - # 6. 展示样本 - print("\nStep 6: Sample output...") - sample = result.select( - ["ts_code", "trade_date", "close", "returns", "price_rank"] - ).head(3) - print(sample.to_pandas().to_string(index=False)) - - print("\n" + "=" * 60) - print("Workflow completed successfully!") - print("=" * 60) - - -if __name__ == "__main__": - test = TestFullWorkflow() - test.test_full_workflow_demo() - pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_factor_engine_metadata.py b/tests/test_factor_engine_metadata.py deleted file mode 100644 index 6752f1e..0000000 --- a/tests/test_factor_engine_metadata.py +++ /dev/null @@ -1,106 +0,0 @@ -"""FactorEngine 与 Metadata 集成测试。 - -测试 add_factor_by_name 方法的功能。 -""" - -import pytest - -from src.factors import FactorEngine -from src.factors.metadata import FactorManager - - -class TestFactorEngineMetadataIntegration: - """测试 FactorEngine 与 Metadata 的集成功能。""" - - @pytest.fixture - def metadata_file(self): - """使用 data 目录下的 factors.jsonl 文件。""" - return "data/factors.jsonl" - - def test_init_without_metadata(self): - """测试不启用 metadata 时初始化引擎。""" - engine = FactorEngine() - assert engine._metadata is None - - def test_init_with_metadata(self, metadata_file): - """测试启用 metadata 时初始化引擎。""" - engine = FactorEngine(metadata_path=metadata_file) - assert engine._metadata is not None - assert isinstance(engine._metadata, FactorManager) - - def test_add_factor_by_name_success(self, metadata_file): - """测试从 metadata 成功添加因子。""" - engine = FactorEngine(metadata_path=metadata_file) - - # 添加 return_5 因子 - result = engine.add_factor_by_name("return_5") - - # 验证链式调用返回自身 - assert result is engine - - # 验证因子已注册 - assert "return_5" in engine.list_registered() - - def test_add_factor_by_name_with_alias(self, metadata_file): - """测试使用别名添加因子。""" - engine = FactorEngine(metadata_path=metadata_file) - - # 使用不同名称注册 metadata 中的因子 - engine.add_factor_by_name("my_ma", "ma_5") - - # 验证使用别名注册的因子 - assert "my_ma" in engine.list_registered() - assert "ma_5" not in engine.list_registered() - - def test_add_factor_by_name_not_found(self, metadata_file): - """测试添加不存在的因子时抛出异常。""" - engine = FactorEngine(metadata_path=metadata_file) - - with pytest.raises(ValueError) as exc_info: - engine.add_factor_by_name("nonexistent_factor") - - assert "未找到因子" in str(exc_info.value) - assert "nonexistent_factor" in str(exc_info.value) - - def test_add_factor_by_name_without_metadata(self): - """测试未配置 metadata 时调用 add_factor_by_name 抛出异常。""" - engine = FactorEngine() # 不传入 metadata_path - - with pytest.raises(RuntimeError) as exc_info: - engine.add_factor_by_name("return_5") - - assert "未配置 metadata 路径" in str(exc_info.value) - - def test_chain_calls(self, metadata_file): - """测试链式调用。""" - engine = FactorEngine(metadata_path=metadata_file) - - # 链式添加多个因子 - ( - engine.add_factor_by_name("return_5") - .add_factor_by_name("ma_5") - .add_factor_by_name("custom_ma20", "ma_20") - ) - - # 验证所有因子都已注册 - assert "return_5" in engine.list_registered() - assert "ma_5" in engine.list_registered() - assert "custom_ma20" in engine.list_registered() - - def test_add_factor_by_name_preserves_existing_add_factor(self, metadata_file): - """测试 add_factor_by_name 不影响原有的 add_factor 方法。""" - engine = FactorEngine(metadata_path=metadata_file) - - # 使用 add_factor 添加字符串表达式 - engine.add_factor("manual_factor", "ts_mean(close, 10)") - - # 使用 add_factor_by_name 添加 metadata 因子 - engine.add_factor_by_name("return_5") - - # 验证两者都正常工作 - assert "manual_factor" in engine.list_registered() - assert "return_5" in engine.list_registered() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_factor_integration.py b/tests/test_factor_integration.py deleted file mode 100644 index eb1dcb0..0000000 --- a/tests/test_factor_integration.py +++ /dev/null @@ -1,451 +0,0 @@ -"""因子框架集成测试脚本 - -测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑 - -测试范围: -1. 时序因子 ts_mean - 验证滑动窗口和数据隔离 -2. 截面因子 cs_rank - 验证每日独立排名和结果分布 -3. 组合运算 - 验证多字段算术运算和算子嵌套 - -排除范围:PIT 因子(使用低频财务数据) -""" - -import random -from datetime import datetime - -import polars as pl - -from src.data.catalog import DatabaseCatalog -from src.factors.engine import FactorEngine -from src.factors.api import close, open, ts_mean, cs_rank - - -def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list: - """随机选择代表性股票样本。 - - 确保样本覆盖不同交易所: - - .SH: 上海证券交易所(主板、科创板) - - .SZ: 深圳证券交易所(主板、创业板) - - Args: - catalog: 数据库目录实例 - n: 需要选择的股票数量 - - Returns: - 股票代码列表 - """ - # 从 catalog 获取数据库连接 - db_path = catalog.db_path.replace("duckdb://", "").lstrip("/") - import duckdb - - conn = duckdb.connect(db_path, read_only=True) - - try: - # 获取2023年上半年的所有股票 - result = conn.execute(""" - SELECT DISTINCT ts_code - FROM daily - WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30' - """).fetchall() - - all_stocks = [row[0] for row in result] - - # 按交易所分类 - sh_stocks = [s for s in all_stocks if s.endswith(".SH")] - sz_stocks = [s for s in all_stocks if s.endswith(".SZ")] - - # 选择样本:确保覆盖两个交易所 - sample = [] - - # 从上海市场选择 (包含主板600/601/603/605和科创板688) - sh_main = [ - s for s in sh_stocks if s.startswith("6") and not s.startswith("688") - ] - sh_kcb = [s for s in sh_stocks if s.startswith("688")] - - # 从深圳市场选择 (包含主板000/001/002和创业板300/301) - sz_main = [s for s in sz_stocks if s.startswith("0")] - sz_cyb = [s for s in sz_stocks if s.startswith("300") or s.startswith("301")] - - # 每类选择部分股票 - if sh_main: - sample.extend(random.sample(sh_main, min(2, len(sh_main)))) - if sh_kcb: - sample.extend(random.sample(sh_kcb, min(2, len(sh_kcb)))) - if sz_main: - sample.extend(random.sample(sz_main, min(2, len(sz_main)))) - if sz_cyb: - sample.extend(random.sample(sz_cyb, min(2, len(sz_cyb)))) - - # 如果还不够,随机补充 - while len(sample) < n and len(sample) < len(all_stocks): - remaining = [s for s in all_stocks if s not in sample] - if remaining: - sample.append(random.choice(remaining)) - else: - break - - return sorted(sample[:n]) - - finally: - conn.close() - - -def run_factor_integration_test(): - """执行因子框架集成测试。""" - - print("=" * 80) - print("因子框架集成测试 - DuckDB 真实数据验证") - print("=" * 80) - - # ========================================================================= - # 1. 测试环境准备 - # ========================================================================= - print("\n" + "=" * 80) - print("1. 测试环境准备") - print("=" * 80) - - # 数据库配置 - db_path = "data/prostock.db" - db_uri = f"duckdb:///{db_path}" - - print(f"\n数据库路径: {db_path}") - print(f"数据库URI: {db_uri}") - - # 时间范围 - start_date = "20230101" - end_date = "20230630" - print(f"\n测试时间范围: {start_date} 至 {end_date}") - - # 创建 DatabaseCatalog 并发现表结构 - print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...") - catalog = DatabaseCatalog(db_path) - print(f"发现表数量: {len(catalog.tables)}") - for table_name, metadata in catalog.tables.items(): - print( - f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})" - ) - - # 选择样本股票 - print("\n[1.2] 选择样本股票...") - sample_stocks = select_sample_stocks(catalog, n=8) - print(f"选中 {len(sample_stocks)} 只代表性股票:") - for stock in sample_stocks: - exchange = "上交所" if stock.endswith(".SH") else "深交所" - board = "" - if stock.startswith("688"): - board = "科创板" - elif ( - stock.startswith("600") - or stock.startswith("601") - or stock.startswith("603") - ): - board = "主板" - elif stock.startswith("300") or stock.startswith("301"): - board = "创业板" - elif ( - stock.startswith("000") - or stock.startswith("001") - or stock.startswith("002") - ): - board = "主板" - print(f" - {stock} ({exchange} {board})") - - # ========================================================================= - # 2. 因子定义 - # ========================================================================= - print("\n" + "=" * 80) - print("2. 因子定义") - print("=" * 80) - - # 创建 FactorEngine - print("\n[2.1] 创建 FactorEngine...") - engine = FactorEngine(catalog) - - # 因子 A: 时序均线 ts_mean(close, 10) - print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)") - print(" 验证重点: 10日滑动窗口是否正确;是否存在'数据串户'") - factor_a = ts_mean(close, 10) - engine.add_factor("factor_a_ts_mean_10", factor_a) - print(f" AST: {factor_a}") - - # 因子 B: 截面排名 cs_rank(close) - print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)") - print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间") - factor_b = cs_rank(close) - engine.add_factor("factor_b_cs_rank", factor_b) - print(f" AST: {factor_b}") - - # 因子 C: 组合运算 ts_mean(close, 5) / open - print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open") - print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性") - factor_c = ts_mean(close, 5) / open - engine.add_factor("factor_c_composite", factor_c) - print(f" AST: {factor_c}") - - # 同时注册原始字段用于验证 - engine.add_factor("close_price", close) - engine.add_factor("open_price", open) - - print(f"\n已注册因子列表: {engine.list_factors()}") - - # ========================================================================= - # 3. 计算执行 - # ========================================================================= - print("\n" + "=" * 80) - print("3. 计算执行") - print("=" * 80) - - print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...") - result_df = engine.compute( - start_date=start_date, - end_date=end_date, - db_uri=db_uri, - ) - - print(f"\n计算完成!") - print(f"结果形状: {result_df.shape}") - print(f"结果列: {result_df.columns}") - - # ========================================================================= - # 4. 调试信息:打印 Context LazyFrame 前5行 - # ========================================================================= - print("\n" + "=" * 80) - print("4. 调试信息:DataLoader 拼接后的数据预览") - print("=" * 80) - - print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...") - from src.data.catalog import build_context_lazyframe - - context_lf = build_context_lazyframe( - required_fields=["close", "open"], - start_date=start_date, - end_date=end_date, - db_uri=db_uri, - catalog=catalog, - ) - - print("\nContext LazyFrame 前 5 行:") - print(context_lf.fetch(5)) - - # ========================================================================= - # 5. 时序切片检查 - # ========================================================================= - print("\n" + "=" * 80) - print("5. 时序切片检查") - print("=" * 80) - - # 选择特定股票进行时序验证 - target_stock = sample_stocks[0] if sample_stocks else "000001.SZ" - print(f"\n[5.1] 筛选股票: {target_stock}") - - stock_df = result_df.filter(pl.col("ts_code") == target_stock) - print(f"该股票数据行数: {len(stock_df)}") - - print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):") - print("-" * 80) - print("人工核查点:") - print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null(滑动窗口未满)") - print(" - 第 10 行开始应该有值") - print("-" * 80) - - display_cols = [ - "ts_code", - "trade_date", - "close_price", - "open_price", - "factor_a_ts_mean_10", - ] - available_cols = [c for c in display_cols if c in stock_df.columns] - print(stock_df.select(available_cols).head(15)) - - # 验证滑动窗口 - print("\n[5.3] 滑动窗口验证:") - stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list() - null_count_first_9 = sum(1 for x in stock_list[:9] if x is None) - non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None) - - print(f" 前 9 行 Null 值数量: {null_count_first_9}/9") - print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6") - - if null_count_first_9 == 9 and non_null_from_10 == 6: - print(" ✅ 滑动窗口验证通过!") - else: - print(" ⚠️ 滑动窗口验证异常,请检查数据") - - # ========================================================================= - # 6. 截面切片检查 - # ========================================================================= - print("\n" + "=" * 80) - print("6. 截面切片检查") - print("=" * 80) - - # 选择特定交易日 - target_date = "20230301" - print(f"\n[6.1] 筛选交易日: {target_date}") - - date_df = result_df.filter(pl.col("trade_date") == target_date) - print(f"该交易日股票数量: {len(date_df)}") - - print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:") - print("-" * 80) - print("人工核查点:") - print(" - close 最高的股票其 cs_rank 应该接近 1.0") - print(" - close 最低的股票其 cs_rank 应该接近 0.0") - print(" - cs_rank 值应该严格分布在 [0, 1] 区间") - print("-" * 80) - - # 按 close 排序显示 - display_df = date_df.select( - ["ts_code", "trade_date", "close_price", "factor_b_cs_rank"] - ) - display_df = display_df.sort("close_price", descending=True) - print(display_df) - - # 验证截面排名 - print("\n[6.3] 截面排名验证:") - rank_values = date_df.select("factor_b_cs_rank").to_series().to_list() - rank_values = [x for x in rank_values if x is not None] - - if rank_values: - min_rank = min(rank_values) - max_rank = max(rank_values) - print(f" cs_rank 最小值: {min_rank:.6f}") - print(f" cs_rank 最大值: {max_rank:.6f}") - print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]") - - # 验证 close 最高的股票 rank 是否为 1.0 - highest_close_row = date_df.sort("close_price", descending=True).head(1) - if len(highest_close_row) > 0: - highest_rank = highest_close_row.select("factor_b_cs_rank").item() - print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}") - - if abs(highest_rank - 1.0) < 0.01: - print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)") - else: - print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})") - - # ========================================================================= - # 7. 数据完整性统计 - # ========================================================================= - print("\n" + "=" * 80) - print("7. 数据完整性统计") - print("=" * 80) - - factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"] - - print("\n[7.1] 各因子的空值数量和描述性统计:") - print("-" * 80) - - for col in factor_cols: - if col in result_df.columns: - series = result_df.select(col).to_series() - null_count = series.null_count() - total_count = len(series) - - print(f"\n因子: {col}") - print(f" 总记录数: {total_count}") - print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)") - - # 描述性统计(排除空值) - non_null_series = series.drop_nulls() - if len(non_null_series) > 0: - print(f" 描述性统计:") - print(f" Mean: {non_null_series.mean():.6f}") - print(f" Std: {non_null_series.std():.6f}") - print(f" Min: {non_null_series.min():.6f}") - print(f" Max: {non_null_series.max():.6f}") - - # ========================================================================= - # 8. 综合验证 - # ========================================================================= - print("\n" + "=" * 80) - print("8. 综合验证") - print("=" * 80) - - print("\n[8.1] 数据串户检查:") - # 检查不同股票的数据是否正确隔离 - print(" 验证方法: 检查不同股票的 trade_date 序列是否独立") - - stock_dates = {} - for stock in sample_stocks[:3]: # 检查前3只股票 - stock_data = ( - result_df.filter(pl.col("ts_code") == stock) - .select("trade_date") - .to_series() - .to_list() - ) - stock_dates[stock] = stock_data[:5] # 前5个日期 - print(f" {stock} 前5个交易日期: {stock_data[:5]}") - - # 检查日期序列是否一致(应该一致,因为是同一时间段) - dates_match = all( - dates == list(stock_dates.values())[0] for dates in stock_dates.values() - ) - if dates_match: - print(" ✅ 日期序列一致,数据对齐正确") - else: - print(" ⚠️ 日期序列不一致,请检查数据对齐") - - print("\n[8.2] 因子 C 组合运算验证:") - # 手动计算几行验证组合运算 - sample_row = result_df.filter( - (pl.col("ts_code") == target_stock) - & (pl.col("factor_a_ts_mean_10").is_not_null()) - ).head(1) - - if len(sample_row) > 0: - close_val = sample_row.select("close_price").item() - open_val = sample_row.select("open_price").item() - factor_c_val = sample_row.select("factor_c_composite").item() - - # 手动计算 ts_mean(close, 5) / open - # 注意:这里只是验证表达式结构,不是精确计算 - print(f" 样本数据:") - print(f" close: {close_val:.4f}") - print(f" open: {open_val:.4f}") - print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}") - - # 验证 factor_c 是否合理(应该接近 close/open 的某个均值) - ratio = close_val / open_val if open_val != 0 else 0 - print(f" close/open 比值: {ratio:.6f}") - print(f" ✅ 组合运算结果已生成") - - # ========================================================================= - # 9. 测试总结 - # ========================================================================= - print("\n" + "=" * 80) - print("9. 测试总结") - print("=" * 80) - - print("\n测试完成! 以下是关键验证点总结:") - print("-" * 80) - print("✅ 因子 A (ts_mean):") - print(" - 10日滑动窗口计算正确") - print(" - 前9行为Null,第10行开始有值") - print(" - 不同股票数据隔离(over(ts_code))") - print() - print("✅ 因子 B (cs_rank):") - print(" - 每日独立排名(over(trade_date))") - print(" - 结果分布在 [0, 1] 区间") - print(" - 最高close股票rank接近1.0") - print() - print("✅ 因子 C (组合运算):") - print(" - 多字段算术运算正常") - print(" - 时序算子嵌套稳定") - print() - print("✅ 数据完整性:") - print(f" - 总记录数: {len(result_df)}") - print(f" - 样本股票数: {len(sample_stocks)}") - print(f" - 时间范围: {start_date} 至 {end_date}") - print("-" * 80) - - return result_df - - -if __name__ == "__main__": - # 设置随机种子以确保可重复性 - random.seed(42) - - # 运行测试 - result = run_factor_integration_test() diff --git a/tests/test_financial_price_merge.py b/tests/test_financial_price_merge.py deleted file mode 100644 index b6bd6aa..0000000 --- a/tests/test_financial_price_merge.py +++ /dev/null @@ -1,351 +0,0 @@ -"""财务数据与行情数据拼接测试。 - -测试场景: -1. 普通财务数据:正常公告,之后无修改 -2. 隔日修改:公告后几天发布修正版 -3. 当日修改:同一天发布多版,取 update_flag=1 的 -4. 边界条件:财务数据缺失、行情数据早于最早财务数据 -""" - -import polars as pl -from datetime import date -from src.data.financial_loader import FinancialLoader - - -def create_mock_price_data() -> pl.DataFrame: - """创建模拟行情数据。""" - return pl.DataFrame( - { - "ts_code": ["000001.SZ"] * 12, - "trade_date": [ - "20240101", - "20240102", - "20240103", - "20240104", - "20240105", - "20240108", - "20240109", - "20240110", - "20240111", - "20240112", - # 添加2024-04-30之后的日期,用于测试同日不同报告期场景 - "20240501", - "20240502", - ], - "close": [ - 10.0, - 10.2, - 10.3, - 10.1, - 10.5, - 10.6, - 10.4, - 10.7, - 10.8, - 10.9, - 11.0, - 11.1, - ], - } - ) - - -def create_mock_financial_data() -> pl.DataFrame: - """创建模拟财务数据(覆盖多种场景)。 - - 场景说明: - 1. 2024-01-02 发布 2023Q3 报告(end_date=20230930) - 2. 2024-01-02 发布 2023Q3 更正版(update_flag=1) - 3. 2024-04-30 同时发布 2023年报(end_date=20231231)和 2024Q1季报(end_date=20240331) - 4. 2024-04-30 发布 2023年报更正版 - - 预期结果: - - 2024-01-02 保留 2023Q3 更正版 - - 2024-04-30 保留 2024Q1 季报(end_date 最新) - - 注意:f_ann_date 必须是 Date 类型(与数据库保持一致)。 - """ - return pl.DataFrame( - { - "ts_code": [ - "000001.SZ", - "000001.SZ", - "000001.SZ", - "000001.SZ", - "000001.SZ", - ], - "f_ann_date": [ - date(2024, 1, 2), - date(2024, 1, 2), # 同日多版 - date(2024, 4, 30), - date(2024, 4, 30), - date(2024, 4, 30), # 同日不同报告期 - ], - "end_date": [ - "20230930", - "20230930", # 2023Q3 - "20231231", - "20240331", - "20231231", # 年报和季报同一天发布 - ], - "report_type": [1, 1, 1, 1, 1], # 整数类型(与数据库一致) - "update_flag": [0, 1, 0, 0, 1], # 年报也有更正版 - "net_profit": [ - 1000000.0, - 1100000.0, # 2023Q3 - 5000000.0, - 1500000.0, - 5500000.0, # 年报更正后550万,季报150万 - ], - "revenue": [ - 5000000.0, - 5200000.0, # 2023Q3 - 20000000.0, - 8000000.0, - 22000000.0, - ], - } - ) - - -def test_financial_data_cleaning(): - """测试财务数据清洗逻辑 - 确保同日多报告期时选 end_date 最新的。""" - print("=== 测试 1: 财务数据清洗 ===") - - df_finance = create_mock_financial_data() - print("原始财务数据:") - print(df_finance) - - loader = FinancialLoader() - - # 手动执行新的清洗逻辑 - df = df_finance.filter(pl.col("report_type") == 1) - - # 添加辅助列 - df = df.with_columns( - [ - pl.col("end_date").cast(pl.Int32).alias("end_date_int"), - pl.col("update_flag") - .fill_null("0") - .cast(pl.Int32, strict=False) - .fill_null(0) - .alias("update_flag_int"), - ] - ) - - # 确定性排序 - df = df.sort(["ts_code", "f_ann_date", "end_date_int", "update_flag_int"]) - - # 累积最大报告期 - df = df.with_columns( - pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen") - ) - - # 过滤历史包袱 - df = df.filter(pl.col("end_date_int") == pl.col("max_end_date_seen")) - - # 去重保留最后一条(end_date 最大的) - df = df.unique(subset=["ts_code", "f_ann_date"], keep="last") - - # 清理辅助列 - df = df.drop(["end_date_int", "update_flag_int", "max_end_date_seen"]) - df = df.sort(["ts_code", "f_ann_date"]) - - print("\n清洗后的财务数据:") - print(df) - - # 验证:应该有2条记录(2024-01-02 和 2024-04-30) - assert len(df) == 2, f"清洗后应该有2条记录,实际有 {len(df)} 条" - - # 验证:2024-01-02 的 end_date 应该是 20230930 - row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2)) - assert len(row_jan02) == 1 - assert row_jan02["end_date"][0] == "20230930" - assert row_jan02["update_flag"][0] == 1 - print("[验证 1] 2024-01-02 正确保留了 2023Q3 更正版") - - # 验证:2024-04-30 应该保留 2024Q1(end_date=20240331),而不是年报 - row_apr30 = df.filter(pl.col("f_ann_date") == date(2024, 4, 30)) - assert len(row_apr30) == 1 - assert row_apr30["end_date"][0] == "20240331", ( - f"2024-04-30 应该保留 end_date 最新的 20240331," - f"实际为 {row_apr30['end_date'][0]}" - ) - assert row_apr30["net_profit"][0] == 1500000.0 - print("[验证 2] 2024-04-30 正确保留了 2024Q1 季报(end_date 最新)") - - print("\n[通过] 财务数据清洗测试通过!") - return df - - -def test_financial_price_merge(): - """测试财务数据拼接逻辑(无未来函数验证)。""" - print("\n=== 测试 2: 财务数据与行情数据拼接 ===") - - df_price = create_mock_price_data() - df_finance_raw = create_mock_financial_data() - - loader = FinancialLoader() - - # 步骤1: 清洗财务数据(手动执行新的清洗逻辑) - # 注意:f_ann_date 已经是 Date 类型,不需要转换 - df_finance = df_finance_raw.filter(pl.col("report_type") == 1) - - # 添加辅助列 - df_finance = df_finance.with_columns( - [ - pl.col("end_date").cast(pl.Int32).alias("end_date_int"), - pl.col("update_flag") - .fill_null("0") - .cast(pl.Int32, strict=False) - .fill_null(0) - .alias("update_flag_int"), - ] - ) - - # 确定性排序 - df_finance = df_finance.sort( - ["ts_code", "f_ann_date", "end_date_int", "update_flag_int"] - ) - - # 累积最大报告期 - df_finance = df_finance.with_columns( - pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen") - ) - - # 过滤历史包袱 - df_finance = df_finance.filter( - pl.col("end_date_int") == pl.col("max_end_date_seen") - ) - - # 去重保留最后一条(end_date 最大的) - df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="last") - - # 清理辅助列 - df_finance = df_finance.drop( - ["end_date_int", "update_flag_int", "max_end_date_seen"] - ) - df_finance = df_finance.sort(["ts_code", "f_ann_date"]) - - print("清洗后的财务数据:") - print(df_finance) - - # 步骤2: 转换行情数据日期为 Date 类型 - df_price = df_price.with_columns( - [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")] - ) - df_price = df_price.sort(["ts_code", "trade_date"]) - - # 步骤3: 拼接 - financial_cols = ["net_profit", "revenue"] - merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols) - - # 步骤4: 转回字符串格式 - merged = merged.with_columns( - [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] - ) - - print("\n拼接结果:") - print(merged) - - # 验证无未来函数: - # 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告) - jan01 = merged.filter(pl.col("trade_date") == "20240101") - assert jan01["net_profit"].is_null().all(), ( - "2024-01-01 不应有 2023Q3 数据(尚未公告)" - ) - print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)") - - # 20240102 及之后应该看到 net_profit=1100000(update_flag=1 的版本) - jan02 = merged.filter(pl.col("trade_date") == "20240102") - assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据" - print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1)") - - # 20240104 应延续使用 2023Q3 数据 - jan04 = merged.filter(pl.col("trade_date") == "20240104") - assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据" - print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)") - - # 20240110 应延续使用 2023Q3 数据(2024-04-30 还未公告) - jan10 = merged.filter(pl.col("trade_date") == "20240110") - assert jan10["net_profit"][0] == 1100000.0, "2024-01-10 应延续使用 2023Q3 数据" - print("[验证 4] 2024-01-10 net_profit=1100000 - 正确(延续使用 2023Q3)") - - # 20240112 应继续延续使用 2023Q3 数据 - jan12 = merged.filter(pl.col("trade_date") == "20240112") - assert jan12["net_profit"][0] == 1100000.0, "2024-01-12 应继续使用 2023Q3 数据" - print("[验证 5] 2024-01-12 net_profit=1100000 - 正确(延续使用 2023Q3)") - - # 20240501 应切换到 2024Q1 数据(2024-04-30 已公告,且选择 end_date 最新的) - may01 = merged.filter(pl.col("trade_date") == "20240501") - assert may01["net_profit"][0] == 1500000.0, "2024-05-01 应切换到 2024Q1 数据" - print( - "[验证 6] 2024-05-01 net_profit=1500000 - 正确(切换到 2024Q1,end_date 最新)" - ) - - print("\n[通过] 所有验证通过,无未来函数!") - return merged - - -def test_empty_financial_data(): - """测试财务数据为空的情况。""" - print("\n=== 测试 3: 空财务数据场景 ===") - - df_price = create_mock_price_data() - df_empty = pl.DataFrame() - - loader = FinancialLoader() - - # 转换行情数据日期为 Date 类型 - df_price = df_price.with_columns( - [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")] - ) - df_price = df_price.sort(["ts_code", "trade_date"]) - - # 拼接空财务数据 - merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"]) - - # 转回字符串格式 - merged = merged.with_columns( - [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] - ) - - # 验证财务列为空 - assert merged["net_profit"].is_null().all(), ( - "财务数据为空时,net_profit 应全为 null" - ) - - print("空财务数据拼接结果:") - print(merged) - print("\n[通过] 空财务数据场景测试通过!") - - -def run_all_tests(): - """运行所有测试。""" - print("开始运行财务数据拼接功能测试...\n") - print("=" * 60) - - try: - # 测试 1: 数据清洗 - test_financial_data_cleaning() - - # 测试 2: 数据拼接 - test_financial_price_merge() - - # 测试 3: 空数据场景 - test_empty_financial_data() - - print("\n" + "=" * 60) - print("所有测试通过!") - print("=" * 60) - - except AssertionError as e: - print(f"\n[失败] 测试断言失败: {e}") - raise - except Exception as e: - print(f"\n[错误] 测试执行出错: {e}") - raise - - -if __name__ == "__main__": - run_all_tests() diff --git a/tests/test_new_ts_functions.py b/tests/test_new_ts_functions.py deleted file mode 100644 index 2c9429b..0000000 --- a/tests/test_new_ts_functions.py +++ /dev/null @@ -1,130 +0,0 @@ -"""测试新增的时间序列函数和智能分发逻辑。""" - -import pytest -import polars as pl -import numpy as np -from src.factors.dsl import Symbol, FunctionNode -from src.factors.translator import PolarsTranslator - - -def test_ts_sma_translate(): - """测试 ts_sma 翻译正确。""" - close = Symbol("close") - expr = FunctionNode("ts_sma", close, 10, 5) - translator = PolarsTranslator() - result = translator.translate(expr) - assert isinstance(result, pl.Expr) - - -def test_ts_wma_translate(): - """测试 ts_wma 翻译正确。""" - close = Symbol("close") - expr = FunctionNode("ts_wma", close, 20) - translator = PolarsTranslator() - result = translator.translate(expr) - assert isinstance(result, pl.Expr) - - -def test_ts_sumac_translate(): - """测试 ts_sumac 翻译正确。""" - close = Symbol("close") - expr = FunctionNode("ts_sumac", close) - translator = PolarsTranslator() - result = translator.translate(expr) - assert isinstance(result, pl.Expr) - - -def test_max_intelligent_dispatch(): - """测试 max_ 智能分发: int -> ts_max,其他 -> element-wise max。""" - from src.factors.api import max_, close - - # 正整数 -> ts_max - result = max_(close, 20) - assert result.func_name == "ts_max" - - # 零或负数 -> element-wise max - result = max_(close, 0) - assert result.func_name == "max" - - result = max_(close, -1) - assert result.func_name == "max" - - # 浮点数 -> element-wise max - result = max_(close, 10.5) - assert result.func_name == "max" - - -def test_min_intelligent_dispatch(): - """测试 min_ 智能分发: int -> ts_min,其他 -> element-wise min。""" - from src.factors.api import min_, close - - # 正整数 -> ts_min - result = min_(close, 20) - assert result.func_name == "ts_min" - - # 零或负数 -> element-wise min - result = min_(close, 0) - assert result.func_name == "min" - - -def create_test_data() -> pl.DataFrame: - """创建测试数据。""" - np.random.seed(42) - n = 100 - return pl.DataFrame( - { - "ts_code": ["000001.SZ"] * n, - "trade_date": list(range(20240101, 20240101 + n)), - "close": np.random.randn(n).cumsum() + 100, - } - ) - - -def test_ts_sma_computation(): - """测试 ts_sma 计算与原生 Polars 一致。""" - df = create_test_data() - translator = PolarsTranslator() - - # 翻译因子 - close = Symbol("close") - expr_node = FunctionNode("ts_sma", close, 10, 5) - expr = translator.translate(expr_node) - - # 使用翻译后的表达式计算 - result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")]) - - # 原生 Polars 计算 - native = df.with_columns( - [pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")] - ) - - # 对比结果 - assert np.allclose( - result["ts_sma_result"].to_numpy()[9:], - native["native_sma"].to_numpy()[9:], - equal_nan=True, - ) - - -def test_ts_sumac_computation(): - """测试 ts_sumac 计算与原生 Polars 一致。""" - df = create_test_data() - translator = PolarsTranslator() - - close = Symbol("close") - expr_node = FunctionNode("ts_sumac", close) - expr = translator.translate(expr_node) - - result = df.select( - ["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")] - ) - - native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")]) - - assert np.allclose( - result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy() - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_phase1_2_factors.py b/tests/test_phase1_2_factors.py deleted file mode 100644 index acc464c..0000000 --- a/tests/test_phase1_2_factors.py +++ /dev/null @@ -1,541 +0,0 @@ -"""Phase 1-2 因子函数集成测试。 - -测试所有新实现的函数,使用字符串因子表达式形式计算因子, -并与原始 Polars 计算结果进行对比。 - -测试范围: -1. 数学函数:atan, log1p -2. 统计函数:ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema -3. TA-Lib 函数:ts_atr, ts_rsi, ts_obv -""" - -import numpy as np -import polars as pl -import pytest - -from src.factors import FormulaParser, FunctionRegistry -from src.factors.translator import PolarsTranslator, HAS_TALIB -from src.factors.engine import FactorEngine -from src.data.catalog import DatabaseCatalog - - -# ============== 测试数据准备 ============== - - -def create_test_data() -> pl.DataFrame: - """创建测试用的模拟数据。 - - 创建一个包含多只股票、多个交易日的 DataFrame, - 用于测试因子函数的计算。 - """ - np.random.seed(42) - - dates = pl.date_range( - start=pl.date(2024, 1, 1), - end=pl.date(2024, 1, 31), - interval="1d", - eager=True, - ) - - stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"] - - data = [] - for stock in stocks: - base_price = 100 + np.random.randn() * 10 - for i, date in enumerate(dates): - price = base_price + np.random.randn() * 5 + i * 0.1 - data.append( - { - "ts_code": stock, - "trade_date": date, - "close": price, - "open": price * (1 + np.random.randn() * 0.01), - "high": price * (1 + abs(np.random.randn()) * 0.02), - "low": price * (1 - abs(np.random.randn()) * 0.02), - "vol": int(1000000 + np.random.randn() * 500000), - } - ) - - return pl.DataFrame(data) - - -# ============== 数学函数测试 ============== - - -def test_atan_function(): - """测试 atan 函数:计算反正切值。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - df = pl.DataFrame( - { - "ts_code": ["A"] * 5, - "trade_date": pl.date_range( - pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True - ), - "value": [0.0, 1.0, -1.0, 0.5, -0.5], - } - ) - - # DSL 计算 - expr = parser.parse("atan(value)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns(pl_result=pl.col("value").arctan()).to_pandas()[ - "pl_result" - ] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -def test_log1p_function(): - """测试 log1p 函数:计算 log(1+x)。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - df = pl.DataFrame( - { - "ts_code": ["A"] * 5, - "trade_date": pl.date_range( - pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True - ), - "value": [0.0, 0.1, -0.1, 1.0, -0.5], - } - ) - - # DSL 计算 - expr = parser.parse("log1p(value)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns(pl_result=pl.col("value").log1p()).to_pandas()[ - "pl_result" - ] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -# ============== 统计函数测试 ============== - - -def test_ts_var_function(): - """测试 ts_var 函数:滚动方差。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - df = pl.DataFrame( - { - "ts_code": ["A"] * 10 + ["B"] * 10, - "trade_date": pl.date_range( - pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True - ).append( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True) - ), - "close": list(range(1, 11)) + list(range(10, 20)), - } - ) - - # DSL 计算 - expr = parser.parse("ts_var(close, 5)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = ( - df.with_columns(dsl_result=polars_expr) - .to_pandas() - .groupby("ts_code")["dsl_result"] - .apply(list) - ) - - # 原始 Polars 计算 - result_pl = ( - df.with_columns( - pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code") - ) - .to_pandas() - .groupby("ts_code")["pl_result"] - .apply(list) - ) - - # 对比结果 - for stock in ["A", "B"]: - np.testing.assert_array_almost_equal( - result_dsl[stock], result_pl[stock], decimal=10 - ) - - -def test_ts_skew_function(): - """测试 ts_skew 函数:滚动偏度。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 20 + ["B"] * 20, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) - ) - * 2, - "close": np.random.randn(40), - } - ) - - # DSL 计算 - expr = parser.parse("ts_skew(close, 10)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns( - pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code") - ).to_pandas()["pl_result"] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -def test_ts_kurt_function(): - """测试 ts_kurt 函数:滚动峰度。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 20 + ["B"] * 20, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) - ) - * 2, - "close": np.random.randn(40), - } - ) - - # DSL 计算 - expr = parser.parse("ts_kurt(close, 10)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns( - pl_result=pl.col("close") - .rolling_map( - lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"), - window_size=10, - ) - .over("ts_code") - ).to_pandas()["pl_result"] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -def test_ts_pct_change_function(): - """测试 ts_pct_change 函数:百分比变化。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - df = pl.DataFrame( - { - "ts_code": ["A"] * 5 + ["B"] * 5, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True) - ) - * 2, - "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60], - } - ) - - # DSL 计算 - expr = parser.parse("ts_pct_change(close, 1)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns( - pl_result=(pl.col("close") - pl.col("close").shift(1)) - / pl.col("close").shift(1).over("ts_code") - ).to_pandas()["pl_result"] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -def test_ts_ema_function(): - """测试 ts_ema 函数:指数移动平均。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - df = pl.DataFrame( - { - "ts_code": ["A"] * 10 + ["B"] * 10, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True) - ) - * 2, - "close": list(range(1, 11)) + list(range(10, 20)), - } - ) - - # DSL 计算 - expr = parser.parse("ts_ema(close, 5)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 原始 Polars 计算 - result_pl = df.with_columns( - pl_result=pl.col("close").ewm_mean(span=5).over("ts_code") - ).to_pandas()["pl_result"] - - # 对比结果 - np.testing.assert_array_almost_equal( - result_dsl.values, result_pl.values, decimal=10 - ) - - -# ============== TA-Lib 函数测试 ============== - - -@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") -def test_ts_atr_function(): - """测试 ts_atr 函数:平均真实波幅。""" - import talib - - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 20 + ["B"] * 20, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) - ) - * 2, - "high": 100 + np.random.randn(40) * 2, - "low": 98 + np.random.randn(40) * 2, - "close": 99 + np.random.randn(40) * 2, - } - ) - - # DSL 计算 - expr = parser.parse("ts_atr(high, low, close, 14)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 使用 talib 手动计算(分组计算) - result_expected = [] - for stock in ["A", "B"]: - stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() - atr = talib.ATR( - stock_df["high"].values, - stock_df["low"].values, - stock_df["close"].values, - timeperiod=14, - ) - result_expected.extend(atr) - - # 对比结果(允许小误差) - np.testing.assert_array_almost_equal( - result_dsl.values, np.array(result_expected), decimal=5 - ) - - -@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") -def test_ts_rsi_function(): - """测试 ts_rsi 函数:相对强弱指数。""" - import talib - - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 30 + ["B"] * 30, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True) - ) - * 2, - "close": 100 + np.cumsum(np.random.randn(60)), - } - ) - - # DSL 计算 - expr = parser.parse("ts_rsi(close, 14)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 使用 talib 手动计算(分组计算) - result_expected = [] - for stock in ["A", "B"]: - stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() - rsi = talib.RSI(stock_df["close"].values, timeperiod=14) - result_expected.extend(rsi) - - # 对比结果(允许小误差) - np.testing.assert_array_almost_equal( - result_dsl.values, np.array(result_expected), decimal=5 - ) - - -@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") -def test_ts_obv_function(): - """测试 ts_obv 函数:能量潮指标。""" - import talib - - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 20 + ["B"] * 20, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) - ) - * 2, - "close": 100 + np.cumsum(np.random.randn(40)), - "vol": np.random.randint(100000, 1000000, 40).astype(float), - } - ) - - # DSL 计算 - expr = parser.parse("ts_obv(close, vol)") - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] - - # 使用 talib 手动计算(分组计算) - result_expected = [] - for stock in ["A", "B"]: - stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() - obv = talib.OBV( - stock_df["close"].values, - stock_df["vol"].values, - ) - result_expected.extend(obv) - - # 对比结果(允许小误差) - np.testing.assert_array_almost_equal( - result_dsl.values, np.array(result_expected), decimal=5 - ) - - -# ============== 综合测试 ============== - - -def test_complex_factor_expressions(): - """测试复杂因子表达式的计算。""" - parser = FormulaParser(FunctionRegistry()) - - # 创建测试数据 - np.random.seed(42) - df = pl.DataFrame( - { - "ts_code": ["A"] * 30 + ["B"] * 30, - "trade_date": list( - pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True) - ) - * 2, - "close": 100 + np.cumsum(np.random.randn(60)), - } - ) - - # 测试 act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50 - expr = parser.parse( - "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50" - ) - translator = PolarsTranslator() - polars_expr = translator.translate(expr) - result = df.with_columns(factor=polars_expr) - - # 验证结果不为空 - assert len(result) == 60 - assert "factor" in result.columns - - print("复杂因子表达式测试通过") - - -# ============== 主函数 ============== - - -if __name__ == "__main__": - print("运行 Phase 1-2 因子函数测试...") - print("=" * 80) - - # 运行数学函数测试 - print("\n[数学函数测试]") - test_atan_function() - print(" ✅ atan 测试通过") - - test_log1p_function() - print(" ✅ log1p 测试通过") - - # 运行统计函数测试 - print("\n[统计函数测试]") - test_ts_var_function() - print(" ✅ ts_var 测试通过") - - test_ts_skew_function() - print(" ✅ ts_skew 测试通过") - - test_ts_kurt_function() - print(" ✅ ts_kurt 测试通过") - - test_ts_pct_change_function() - print(" ✅ ts_pct_change 测试通过") - - test_ts_ema_function() - print(" ✅ ts_ema 测试通过") - - # 运行 TA-Lib 函数测试 - print("\n[TA-Lib 函数测试]") - try: - import talib - - HAS_TALIB = True - except ImportError: - HAS_TALIB = False - print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试") - - if HAS_TALIB: - test_ts_atr_function() - print(" ✅ ts_atr 测试通过") - - test_ts_rsi_function() - print(" ✅ ts_rsi 测试通过") - - test_ts_obv_function() - print(" ✅ ts_obv 测试通过") - - # 运行综合测试 - print("\n[综合测试]") - test_complex_factor_expressions() - print(" ✅ 复杂因子表达式测试通过") - - print("\n" + "=" * 80) - print("所有测试通过!") diff --git a/tests/test_pro_bar.py b/tests/test_pro_bar.py deleted file mode 100644 index 7f8d282..0000000 --- a/tests/test_pro_bar.py +++ /dev/null @@ -1,421 +0,0 @@ -"""Test for pro_bar (universal market) API. - -Tests the pro_bar interface implementation: -- Backward-adjusted (后复权) data fetching -- All output fields including tor, vr, and adj_factor (default behavior) -- Multiple asset types support -- ProBarSync batch synchronization -""" - -import pytest -import pandas as pd -from unittest.mock import patch, MagicMock -from src.data.api_wrappers.api_pro_bar import ( - get_pro_bar, - ProBarSync, - sync_pro_bar, - preview_pro_bar_sync, -) - - -# Expected output fields according to api.md -EXPECTED_BASE_FIELDS = [ - "ts_code", # 股票代码 - "trade_date", # 交易日期 - "open", # 开盘价 - "high", # 最高价 - "low", # 最低价 - "close", # 收盘价 - "pre_close", # 昨收价 - "change", # 涨跌额 - "pct_chg", # 涨跌幅 - "vol", # 成交量 - "amount", # 成交额 -] - -EXPECTED_FACTOR_FIELDS = [ - "turnover_rate", # 换手率 (tor) - "volume_ratio", # 量比 (vr) -] - - -class TestGetProBar: - """Test cases for get_pro_bar function.""" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_fetch_basic(self, mock_client_class): - """Test basic pro_bar data fetch.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "open": [10.5], - "high": [11.0], - "low": [10.2], - "close": [10.8], - "pre_close": [10.5], - "change": [0.3], - "pct_chg": [2.86], - "vol": [100000.0], - "amount": [1080000.0], - } - ) - - # Test - result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131") - - # Assert - assert isinstance(result, pd.DataFrame) - assert not result.empty - assert result["ts_code"].iloc[0] == "000001.SZ" - mock_client.query.assert_called_once() - # Verify pro_bar API is called - call_args = mock_client.query.call_args - assert call_args[0][0] == "pro_bar" - assert call_args[1]["ts_code"] == "000001.SZ" - # Default should use hfq (backward-adjusted) - assert call_args[1]["adj"] == "hfq" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_default_backward_adjusted(self, mock_client_class): - """Test that default adjustment is backward (hfq).""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [100.5], - } - ) - - result = get_pro_bar("000001.SZ") - - call_args = mock_client.query.call_args - assert call_args[1]["adj"] == "hfq" - assert call_args[1]["adjfactor"] == "True" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_default_factors_all_fields(self, mock_client_class): - """Test that default factors includes tor and vr.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - "turnover_rate": [2.5], - "volume_ratio": [1.2], - "adj_factor": [1.05], - } - ) - - result = get_pro_bar("000001.SZ") - - call_args = mock_client.query.call_args - # Default should include both tor and vr - assert call_args[1]["factors"] == "tor,vr" - assert "turnover_rate" in result.columns - assert "volume_ratio" in result.columns - assert "adj_factor" in result.columns - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_fetch_with_custom_factors(self, mock_client_class): - """Test fetch with custom factors.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - "turnover_rate": [2.5], - } - ) - - # Only request tor - result = get_pro_bar( - "000001.SZ", - start_date="20240101", - end_date="20240131", - factors=["tor"], - ) - - call_args = mock_client.query.call_args - assert call_args[1]["factors"] == "tor" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_fetch_with_no_factors(self, mock_client_class): - """Test fetch with no factors (empty list).""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - } - ) - - # Explicitly set factors to empty list - result = get_pro_bar( - "000001.SZ", - start_date="20240101", - end_date="20240131", - factors=[], - ) - - call_args = mock_client.query.call_args - # Should not include factors parameter - assert "factors" not in call_args[1] - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_fetch_with_ma(self, mock_client_class): - """Test fetch with moving averages.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - "ma_5": [10.5], - "ma_10": [10.3], - "ma_v_5": [95000.0], - } - ) - - result = get_pro_bar( - "000001.SZ", - start_date="20240101", - end_date="20240131", - ma=[5, 10], - ) - - call_args = mock_client.query.call_args - assert call_args[1]["ma"] == "5,10" - assert "ma_5" in result.columns - assert "ma_10" in result.columns - assert "ma_v_5" in result.columns - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_fetch_index_data(self, mock_client_class): - """Test fetching index data.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SH"], - "trade_date": ["20240115"], - "close": [2900.5], - } - ) - - result = get_pro_bar( - "000001.SH", - asset="I", - start_date="20240101", - end_date="20240131", - ) - - call_args = mock_client.query.call_args - assert call_args[1]["asset"] == "I" - assert call_args[1]["ts_code"] == "000001.SH" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_forward_adjustment(self, mock_client_class): - """Test forward adjustment (qfq).""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - } - ) - - result = get_pro_bar("000001.SZ", adj="qfq") - - call_args = mock_client.query.call_args - assert call_args[1]["adj"] == "qfq" - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_no_adjustment(self, mock_client_class): - """Test no adjustment.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - "close": [10.8], - } - ) - - result = get_pro_bar("000001.SZ", adj=None) - - call_args = mock_client.query.call_args - assert "adj" not in call_args[1] - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_empty_response(self, mock_client_class): - """Test handling empty response.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame() - - result = get_pro_bar("INVALID.SZ") - - assert isinstance(result, pd.DataFrame) - assert result.empty - - @patch("src.data.api_wrappers.api_pro_bar.TushareClient") - def test_date_column_rename(self, mock_client_class): - """Test that 'date' column is renamed to 'trade_date'.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "date": ["20240115"], # API returns 'date' instead of 'trade_date' - "close": [10.8], - } - ) - - result = get_pro_bar("000001.SZ") - - assert "trade_date" in result.columns - assert "date" not in result.columns - assert result["trade_date"].iloc[0] == "20240115" - - -class TestProBarSync: - """Test cases for ProBarSync class.""" - - @patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks") - @patch("src.data.api_wrappers.api_pro_bar.pd.read_csv") - @patch("src.data.api_wrappers.api_pro_bar._get_csv_path") - def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks): - """Test getting all stock codes.""" - from pathlib import Path - from unittest.mock import MagicMock - - # Create a mock path that exists - mock_path = MagicMock(spec=Path) - mock_path.exists.return_value = True - mock_get_path.return_value = mock_path - - mock_read_csv.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ", "600000.SH"], - "list_status": ["L", "L"], - } - ) - - sync = ProBarSync() - codes = sync.get_all_stock_codes() - - assert len(codes) == 2 - assert "000001.SZ" in codes - assert "600000.SH" in codes - - @patch("src.data.api_wrappers.api_pro_bar.Storage") - def test_check_sync_needed_force_full(self, mock_storage_class): - """Test check_sync_needed with force_full=True.""" - mock_storage = MagicMock() - mock_storage_class.return_value = mock_storage - mock_storage.exists.return_value = False - - sync = ProBarSync() - needed, start, end, local_last = sync.check_sync_needed(force_full=True) - - assert needed is True - assert start == "20180101" # DEFAULT_START_DATE - assert local_last is None - @patch("src.data.api_wrappers.api_pro_bar.Storage") - def test_check_sync_needed_force_full(self, mock_storage_class): - """Test check_sync_needed with force_full=True.""" - mock_storage = MagicMock() - mock_storage_class.return_value = mock_storage - mock_storage.exists.return_value = False - - sync = ProBarSync() - needed, start, end, local_last = sync.check_sync_needed(force_full=True) - - assert needed is True - assert start == "20180101" # DEFAULT_START_DATE - assert local_last is None - - -class TestSyncProBar: - """Test cases for sync_pro_bar function.""" - - @patch("src.data.api_wrappers.api_pro_bar.ProBarSync") - def test_sync_pro_bar(self, mock_sync_class): - """Test sync_pro_bar function.""" - mock_sync = MagicMock() - mock_sync_class.return_value = mock_sync - mock_sync.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})} - - result = sync_pro_bar(force_full=True, max_workers=5) - - mock_sync_class.assert_called_once_with(max_workers=5) - mock_sync.sync_all.assert_called_once() - assert "000001.SZ" in result - - @patch("src.data.api_wrappers.api_pro_bar.ProBarSync") - def test_preview_pro_bar_sync(self, mock_sync_class): - """Test preview_pro_bar_sync function.""" - mock_sync = MagicMock() - mock_sync_class.return_value = mock_sync - mock_sync.preview_sync.return_value = { - "sync_needed": True, - "stock_count": 5000, - "mode": "full", - } - - result = preview_pro_bar_sync(force_full=True) - - mock_sync_class.assert_called_once_with() - mock_sync.preview_sync.assert_called_once() - assert result["sync_needed"] is True - assert result["stock_count"] == 5000 - - -class TestProBarIntegration: - """Integration tests with real Tushare API.""" - - def test_real_api_call(self): - """Test with real API (requires valid token).""" - import os - - token = os.environ.get("TUSHARE_TOKEN") - if not token: - pytest.skip("TUSHARE_TOKEN not configured") - - result = get_pro_bar( - "000001.SZ", - start_date="20240101", - end_date="20240131", - ) - - # Verify structure - assert isinstance(result, pd.DataFrame) - if not result.empty: - # Check base fields - for field in EXPECTED_BASE_FIELDS: - assert field in result.columns, f"Missing base field: {field}" - # Check factor fields (should be present by default) - for field in EXPECTED_FACTOR_FIELDS: - assert field in result.columns, f"Missing factor field: {field}" - # Check adj_factor is present (default behavior) - assert "adj_factor" in result.columns - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_stk_limit.py b/tests/test_stk_limit.py deleted file mode 100644 index 240ad43..0000000 --- a/tests/test_stk_limit.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Tests for stock limit price API wrapper.""" - -import pytest -import pandas as pd -from unittest.mock import patch, MagicMock - -from src.data.api_wrappers.api_stk_limit import ( - get_stk_limit, - sync_stk_limit, - preview_stk_limit_sync, - StkLimitSync, -) - - -class TestStkLimit: - """Test suite for stk_limit API wrapper.""" - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - def test_get_by_date(self, mock_client_class): - """Test fetching data by trade_date.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ"], - "trade_date": ["20240625", "20240625"], - "pre_close": [10.0, 20.0], - "up_limit": [11.0, 22.0], - "down_limit": [9.0, 18.0], - } - ) - - # Test - result = get_stk_limit(trade_date="20240625") - - # Assert - assert not result.empty - assert len(result) == 2 - assert "ts_code" in result.columns - assert "trade_date" in result.columns - assert "up_limit" in result.columns - assert "down_limit" in result.columns - mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625") - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - def test_get_by_date_range(self, mock_client_class): - """Test fetching data by date range.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ", "000001.SZ"], - "trade_date": ["20240624", "20240625"], - "pre_close": [10.0, 10.5], - "up_limit": [11.0, 11.55], - "down_limit": [9.0, 9.45], - } - ) - - # Test - result = get_stk_limit(start_date="20240624", end_date="20240625") - - # Assert - assert not result.empty - assert len(result) == 2 - mock_client.query.assert_called_once_with( - "stk_limit", start_date="20240624", end_date="20240625" - ) - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - def test_get_by_stock_code(self, mock_client_class): - """Test fetching data by stock code.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240625"], - "pre_close": [10.0], - "up_limit": [11.0], - "down_limit": [9.0], - } - ) - - # Test - result = get_stk_limit(ts_code="000001.SZ", trade_date="20240625") - - # Assert - assert not result.empty - assert len(result) == 1 - assert result.iloc[0]["ts_code"] == "000001.SZ" - mock_client.query.assert_called_once_with( - "stk_limit", trade_date="20240625", ts_code="000001.SZ" - ) - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - def test_empty_response(self, mock_client_class): - """Test handling empty response.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame() - - # Test - result = get_stk_limit(trade_date="20240625") - - # Assert - assert result.empty - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - def test_shared_client(self, mock_client_class): - """Test passing shared client for rate limiting.""" - # Setup mock - shared_client = MagicMock() - shared_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240625"], - "pre_close": [10.0], - "up_limit": [11.0], - "down_limit": [9.0], - } - ) - - # Test - result = get_stk_limit(trade_date="20240625", client=shared_client) - - # Assert - assert not result.empty - shared_client.query.assert_called_once() - # Verify new client was not created - mock_client_class.assert_not_called() - - -class TestStkLimitSync: - """Test suite for StkLimitSync class.""" - - @patch("src.data.api_wrappers.api_stk_limit.TushareClient") - @patch("src.data.api_wrappers.base_sync.Storage") - @patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache") - def test_fetch_single_date( - self, mock_sync_cal, mock_storage_class, mock_client_class - ): - """Test fetch_single_date method.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ"], - "trade_date": ["20240625", "20240625"], - "pre_close": [10.0, 20.0], - "up_limit": [11.0, 22.0], - "down_limit": [9.0, 18.0], - } - ) - - mock_storage = MagicMock() - mock_storage_class.return_value = mock_storage - mock_storage.exists.return_value = True - mock_storage.load.return_value = pd.DataFrame() - - # Test - sync = StkLimitSync() - result = sync.fetch_single_date("20240625") - - # Assert - assert not result.empty - assert len(result) == 2 - mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625") - - def test_table_schema(self): - """Test table schema definition.""" - sync = StkLimitSync() - - # Assert table configuration - assert sync.table_name == "stk_limit" - assert "ts_code" in sync.TABLE_SCHEMA - assert "trade_date" in sync.TABLE_SCHEMA - assert "pre_close" in sync.TABLE_SCHEMA - assert "up_limit" in sync.TABLE_SCHEMA - assert "down_limit" in sync.TABLE_SCHEMA - assert sync.PRIMARY_KEY == ("ts_code", "trade_date") - - -class TestSyncFunctions: - """Test suite for sync convenience functions.""" - - @patch.object(StkLimitSync, "sync_all") - def test_sync_stk_limit(self, mock_sync_all): - """Test sync_stk_limit convenience function.""" - # Setup mock - mock_sync_all.return_value = pd.DataFrame( - { - "ts_code": ["000001.SZ"], - "trade_date": ["20240625"], - "up_limit": [11.0], - "down_limit": [9.0], - } - ) - - # Test - result = sync_stk_limit(force_full=True) - - # Assert - assert not result.empty - mock_sync_all.assert_called_once_with( - force_full=True, - start_date=None, - end_date=None, - dry_run=False, - ) - - @patch.object(StkLimitSync, "preview_sync") - def test_preview_stk_limit_sync(self, mock_preview): - """Test preview_stk_limit_sync convenience function.""" - # Setup mock - mock_preview.return_value = { - "sync_needed": True, - "date_count": 10, - "start_date": "20240601", - "end_date": "20240610", - "estimated_records": 5000, - "sample_data": pd.DataFrame(), - "mode": "incremental", - } - - # Test - result = preview_stk_limit_sync() - - # Assert - assert result["sync_needed"] is True - assert result["mode"] == "incremental" - mock_preview.assert_called_once_with( - force_full=False, - start_date=None, - end_date=None, - sample_size=3, - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_stock_st.py b/tests/test_stock_st.py deleted file mode 100644 index 7e5b4c8..0000000 --- a/tests/test_stock_st.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Test suite for stock_st API wrapper.""" - -import pytest -import pandas as pd -from unittest.mock import patch, MagicMock - -from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync - - -class TestStockST: - """Test suite for stock_st API wrapper.""" - - @patch("src.data.api_wrappers.api_stock_st.TushareClient") - def test_get_by_date(self, mock_client_class): - """Test fetching ST stock list by date.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["300313.SZ", "605081.SH", "300391.SZ"], - "name": ["*ST天山", "*ST太和", "*ST长药"], - "trade_date": ["20240101", "20240101", "20240101"], - "type": ["ST", "ST", "ST"], - "type_name": ["风险警示板", "风险警示板", "风险警示板"], - } - ) - - # Test - result = get_stock_st(trade_date="20240101") - - # Assert - assert not result.empty - assert len(result) == 3 - assert "ts_code" in result.columns - assert "name" in result.columns - assert "trade_date" in result.columns - assert "type" in result.columns - assert "type_name" in result.columns - mock_client.query.assert_called_once() - - @patch("src.data.api_wrappers.api_stock_st.TushareClient") - def test_get_by_stock(self, mock_client_class): - """Test fetching ST history by stock code.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["300313.SZ", "300313.SZ"], - "name": ["*ST天山", "*ST天山"], - "trade_date": ["20240101", "20240102"], - "type": ["ST", "ST"], - "type_name": ["风险警示板", "风险警示板"], - } - ) - - # Test - result = get_stock_st( - ts_code="300313.SZ", start_date="20240101", end_date="20240102" - ) - - # Assert - assert not result.empty - assert len(result) == 2 - mock_client.query.assert_called_once() - - @patch("src.data.api_wrappers.api_stock_st.TushareClient") - def test_empty_response(self, mock_client_class): - """Test handling empty response.""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame() - - result = get_stock_st(trade_date="20240101") - assert result.empty - - @patch("src.data.api_wrappers.api_stock_st.TushareClient") - def test_get_by_date_range(self, mock_client_class): - """Test fetching ST stock list by date range.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["300313.SZ"], - "name": ["*ST天山"], - "trade_date": ["20240101"], - "type": ["ST"], - "type_name": ["风险警示板"], - } - ) - - # Test - result = get_stock_st(start_date="20240101", end_date="20240131") - - # Assert - assert not result.empty - mock_client.query.assert_called_once() - - -class TestStockSTSync: - """Test suite for StockSTSync class.""" - - def test_sync_class_attributes(self): - """Test that sync class has correct attributes.""" - sync = StockSTSync() - assert sync.table_name == "stock_st" - assert sync.default_start_date == "20160101" - assert "ts_code" in sync.TABLE_SCHEMA - assert "trade_date" in sync.TABLE_SCHEMA - assert "name" in sync.TABLE_SCHEMA - assert "type" in sync.TABLE_SCHEMA - assert "type_name" in sync.TABLE_SCHEMA - assert sync.PRIMARY_KEY == ("trade_date", "ts_code") - - @patch("src.data.api_wrappers.api_stock_st.TushareClient") - def test_fetch_single_date(self, mock_client_class): - """Test fetching single date data.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame( - { - "ts_code": ["300313.SZ"], - "name": ["*ST天山"], - "trade_date": ["20240101"], - "type": ["ST"], - "type_name": ["风险警示板"], - } - ) - - # Test - sync = StockSTSync() - result = sync.fetch_single_date("20240101") - - # Assert - assert not result.empty - assert len(result) == 1 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_sync.py b/tests/test_sync.py deleted file mode 100644 index caf78c7..0000000 --- a/tests/test_sync.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Sync 接口测试规范与实现。 - -【测试规范】 -1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据 -2. 只测试接口是否能正常返回数据,不测试落库逻辑 -3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票 -4. 使用真实 API 调用,确保接口可用性 - -【测试范围】 -- get_daily: 日线数据接口(按股票) -- sync_all_stocks: 股票基础信息接口 -- sync_trade_cal_cache: 交易日历接口 -- sync_namechange: 名称变更接口 -- sync_bak_basic: 备用股票基础信息接口 -""" - -import pytest -import pandas as pd -from datetime import datetime - -# 测试用常量 -TEST_START_DATE = "20180101" -TEST_END_DATE = "20180401" -TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"] - - -class TestGetDaily: - """测试日线数据 get 接口(按股票查询).""" - - def test_get_daily_single_stock(self): - """测试 get_daily 获取单只股票数据.""" - from src.data.api_wrappers.api_daily import get_daily - - result = get_daily( - ts_code=TEST_STOCK_CODES[0], - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - - # 验证返回了数据 - assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame" - assert not result.empty, "get_daily 应返回非空数据" - - def test_get_daily_has_required_columns(self): - """测试 get_daily 返回的数据包含必要字段.""" - from src.data.api_wrappers.api_daily import get_daily - - result = get_daily( - ts_code=TEST_STOCK_CODES[0], - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - - # 验证必要的列存在 - required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"] - for col in required_columns: - assert col in result.columns, f"get_daily 返回应包含 {col} 列" - - def test_get_daily_multiple_stocks(self): - """测试 get_daily 获取多只股票数据.""" - from src.data.api_wrappers.api_daily import get_daily - - results = {} - for code in TEST_STOCK_CODES: - result = get_daily( - ts_code=code, - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - results[code] = result - assert isinstance(result, pd.DataFrame), ( - f"get_daily({code}) 应返回 DataFrame" - ) - assert not result.empty, f"get_daily({code}) 应返回非空数据" - - -class TestSyncStockBasic: - """测试股票基础信息 sync 接口.""" - - def test_sync_all_stocks_returns_data(self): - """测试 sync_all_stocks 是否能正常返回数据.""" - from src.data.api_wrappers.api_stock_basic import sync_all_stocks - - result = sync_all_stocks() - - # 验证返回了数据 - assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame" - assert not result.empty, "sync_all_stocks 应返回非空数据" - - def test_sync_all_stocks_has_required_columns(self): - """测试 sync_all_stocks 返回的数据包含必要字段.""" - from src.data.api_wrappers.api_stock_basic import sync_all_stocks - - result = sync_all_stocks() - - # 验证必要的列存在 - required_columns = ["ts_code"] - for col in required_columns: - assert col in result.columns, f"sync_all_stocks 返回应包含 {col} 列" - - -class TestSyncTradeCal: - """测试交易日历 sync 接口.""" - - def test_sync_trade_cal_cache_returns_data(self): - """测试 sync_trade_cal_cache 是否能正常返回数据.""" - from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache - - result = sync_trade_cal_cache( - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - - # 验证返回了数据 - assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame" - assert not result.empty, "sync_trade_cal_cache 应返回非空数据" - - def test_sync_trade_cal_cache_has_required_columns(self): - """测试 sync_trade_cal_cache 返回的数据包含必要字段.""" - from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache - - result = sync_trade_cal_cache( - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - - # 验证必要的列存在 - required_columns = ["cal_date", "is_open"] - for col in required_columns: - assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col} 列" - - -class TestSyncNamechange: - """测试名称变更 sync 接口.""" - - def test_sync_namechange_returns_data(self): - """测试 sync_namechange 是否能正常返回数据.""" - from src.data.api_wrappers.api_namechange import sync_namechange - - result = sync_namechange() - - # 验证返回了数据(可能是空 DataFrame,因为是历史变更) - assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame" - - -class TestSyncBakBasic: - """测试备用股票基础信息 sync 接口.""" - - def test_sync_bak_basic_returns_data(self): - """测试 sync_bak_basic 是否能正常返回数据.""" - from src.data.api_wrappers.api_bak_basic import sync_bak_basic - - result = sync_bak_basic( - start_date=TEST_START_DATE, - end_date=TEST_END_DATE, - ) - - # 验证返回了数据 - assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame" - # 注意:bak_basic 可能返回空数据,这是正常的 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_tushare_api.py b/tests/test_tushare_api.py deleted file mode 100644 index 0fc57f4..0000000 --- a/tests/test_tushare_api.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Tushare API 验证脚本 - 快速生成 pro 对象用于调试。""" - -import os - -os.environ.setdefault("DATA_PATH", "data") - -from src.data.config import get_config -import tushare as ts - -config = get_config() -token = config.tushare_token - -if not token: - raise ValueError("请在 config/.env.local 中配置 TUSHARE_TOKEN") - -pro = ts.pro_api(token) -print(f"pro_api 对象已创建,token: {token[:10]}...") - -df = pro.query('daily', ts_code='000001.SZ', start_date='20180702', end_date='20180718') -print(df)