From 0a942f92d14b57651bdc966caab9b84ea850f3b6 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Tue, 14 Oct 2025 09:45:03 +0800 Subject: [PATCH] =?UTF-8?q?factor=E4=BC=98=E5=8C=96=EF=BC=88=E6=9A=82?= =?UTF-8?q?=E5=AD=98=E7=89=88=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env | 1 + test_operator_optimization.py | 153 ++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 .env create mode 100644 test_operator_optimization.py diff --git a/.env b/.env new file mode 100644 index 0000000..5727e9e --- /dev/null +++ b/.env @@ -0,0 +1 @@ +PYTHONPATH=${PYTHONPATH}:${workspaceFolder} diff --git a/test_operator_optimization.py b/test_operator_optimization.py new file mode 100644 index 0000000..dae8331 --- /dev/null +++ b/test_operator_optimization.py @@ -0,0 +1,153 @@ +""" +测试operator框架优化后的功能 +验证因子列名设置和返回数据结构 +""" + +import polars as pl +import numpy as np +from main.factor.operator_framework import StockWiseOperator, DateWiseOperator, OperatorConfig + + +class TestStockFactorOperator(StockWiseOperator): + """测试股票因子算子""" + + def __init__(self, window: int = 5): + config = OperatorConfig( + name=f"test_stock_factor_{window}", + description=f"测试股票因子{window}日", + required_columns=['close'], + output_columns=[f'test_stock_factor_{window}'], + parameters={'window': window} + ) + super().__init__(config) + self.window = window + + def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame: + """计算测试因子 - 简单移动平均""" + # 计算移动平均线 + ma = stock_df['close'].rolling_mean(window_size=self.window) + + # 返回原始df加上新列 + return stock_df.with_columns( + ma.alias(f'test_stock_factor_{self.window}') + ) + + +class TestDateFactorOperator(DateWiseOperator): + """测试日期因子算子""" + + def __init__(self): + config = OperatorConfig( + name="test_date_factor", + description="测试日期因子", + required_columns=['close'], + output_columns=['test_date_factor'], + parameters={} + ) + super().__init__(config) + + def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame: + """计算测试因子 - 当日涨跌幅排名""" + # 计算当日涨跌幅 + pct_chg = date_df['close'].pct_change() + + # 返回原始df加上新列 + return date_df.with_columns( + pct_chg.alias('test_date_factor') + ) + + +def create_test_data(): + """创建测试数据""" + # 创建模拟股票数据 + dates = pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 12, 31), "1d", eager=True).to_list() + stocks = ['000001.SZ', '000002.SZ', '000003.SZ'] + + data = [] + for stock in stocks: + for date in dates: + # 模拟价格数据 + base_price = 10 + np.random.randn() * 2 + data.append({ + 'ts_code': stock, + 'trade_date': date, + 'close': base_price + np.random.randn() * 0.5, + 'vol': np.random.randint(100000, 1000000) + }) + + return pl.DataFrame(data) + + +def test_stock_wise_operator(): + """测试股票切面算子""" + print("=== 测试股票切面算子 ===") + + # 创建测试数据 + df = create_test_data() + print(f"原始数据形状: {df.shape}") + print(f"原始数据列: {df.columns}") + + # 创建算子 + operator = TestStockFactorOperator(window=10) + + # 应用算子 + result_df = operator(df) + + print(f"结果数据形状: {result_df.shape}") + print(f"结果数据列: {result_df.columns}") + + # 检查新列是否存在 + new_column = 'test_stock_factor_10' + if new_column in result_df.columns: + print(f"✓ 新列 '{new_column}' 成功添加") + + # 检查数据完整性 + original_cols = ['ts_code', 'trade_date', 'close', 'vol'] + for col in original_cols: + if col in result_df.columns: + print(f"✓ 原始列 '{col}' 保留") + else: + print(f"✗ 原始列 '{col}' 丢失") + else: + print(f"✗ 新列 '{new_column}' 未找到") + + # 检查数据排序 + sample = result_df.filter(pl.col('ts_code') == '000001.SZ').select(['trade_date', 'close', new_column]).head(15) + print("\n样本数据:") + print(sample) + + +def test_date_wise_operator(): + """测试日期切面算子""" + print("\n=== 测试日期切面算子 ===") + + # 创建测试数据 + df = create_test_data() + print(f"原始数据形状: {df.shape}") + print(f"原始数据列: {df.columns}") + + # 创建算子 + operator = TestDateFactorOperator() + + # 应用算子 + result_df = operator(df) + + print(f"结果数据形状: {result_df.shape}") + print(f"结果数据列: {result_df.columns}") + + # 检查新列是否存在 + new_column = 'test_date_factor' + if new_column in result_df.columns: + print(f"✓ 新列 '{new_column}' 成功添加") + else: + print(f"✗ 新列 '{new_column}' 未找到") + + # 检查数据排序 + sample = result_df.filter(pl.col('trade_date') == pl.date(2023, 1, 10)).select(['ts_code', 'close', new_column]) + print("\n样本数据 (2023-01-10):") + print(sample) + + +if __name__ == "__main__": + test_stock_wise_operator() + test_date_wise_operator()