""" 测试operator框架优化后的功能 验证因子列名设置和返回数据结构 """ import polars as pl import numpy as np from main.factor.operator_framework import StockWiseOperator, DateWiseOperator, OperatorConfig class TestStockFactorOperator(StockWiseOperator): """测试股票因子算子""" def __init__(self, window: int = 5): config = OperatorConfig( name=f"test_stock_factor_{window}", description=f"测试股票因子{window}日", required_columns=['close'], output_columns=[f'test_stock_factor_{window}'], parameters={'window': window} ) super().__init__(config) self.window = window def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame: """计算测试因子 - 简单移动平均""" # 计算移动平均线 ma = stock_df['close'].rolling_mean(window_size=self.window) # 返回原始df加上新列 return stock_df.with_columns( ma.alias(f'test_stock_factor_{self.window}') ) class TestDateFactorOperator(DateWiseOperator): """测试日期因子算子""" def __init__(self): config = OperatorConfig( name="test_date_factor", description="测试日期因子", required_columns=['close'], output_columns=['test_date_factor'], parameters={} ) super().__init__(config) def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame: """计算测试因子 - 当日涨跌幅排名""" # 计算当日涨跌幅 pct_chg = date_df['close'].pct_change() # 返回原始df加上新列 return date_df.with_columns( pct_chg.alias('test_date_factor') ) def create_test_data(): """创建测试数据""" # 创建模拟股票数据 dates = pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 12, 31), "1d", eager=True).to_list() stocks = ['000001.SZ', '000002.SZ', '000003.SZ'] data = [] for stock in stocks: for date in dates: # 模拟价格数据 base_price = 10 + np.random.randn() * 2 data.append({ 'ts_code': stock, 'trade_date': date, 'close': base_price + np.random.randn() * 0.5, 'vol': np.random.randint(100000, 1000000) }) return pl.DataFrame(data) def test_stock_wise_operator(): """测试股票切面算子""" print("=== 测试股票切面算子 ===") # 创建测试数据 df = create_test_data() print(f"原始数据形状: {df.shape}") print(f"原始数据列: {df.columns}") # 创建算子 operator = TestStockFactorOperator(window=10) # 应用算子 result_df = operator(df) print(f"结果数据形状: {result_df.shape}") print(f"结果数据列: {result_df.columns}") # 检查新列是否存在 new_column = 'test_stock_factor_10' if new_column in result_df.columns: print(f"✓ 新列 '{new_column}' 成功添加") # 检查数据完整性 original_cols = ['ts_code', 'trade_date', 'close', 'vol'] for col in original_cols: if col in result_df.columns: print(f"✓ 原始列 '{col}' 保留") else: print(f"✗ 原始列 '{col}' 丢失") else: print(f"✗ 新列 '{new_column}' 未找到") # 检查数据排序 sample = result_df.filter(pl.col('ts_code') == '000001.SZ').select(['trade_date', 'close', new_column]).head(15) print("\n样本数据:") print(sample) def test_date_wise_operator(): """测试日期切面算子""" print("\n=== 测试日期切面算子 ===") # 创建测试数据 df = create_test_data() print(f"原始数据形状: {df.shape}") print(f"原始数据列: {df.columns}") # 创建算子 operator = TestDateFactorOperator() # 应用算子 result_df = operator(df) print(f"结果数据形状: {result_df.shape}") print(f"结果数据列: {result_df.columns}") # 检查新列是否存在 new_column = 'test_date_factor' if new_column in result_df.columns: print(f"✓ 新列 '{new_column}' 成功添加") else: print(f"✗ 新列 '{new_column}' 未找到") # 检查数据排序 sample = result_df.filter(pl.col('trade_date') == pl.date(2023, 1, 10)).select(['ts_code', 'close', new_column]) print("\n样本数据 (2023-01-10):") print(sample) if __name__ == "__main__": test_stock_wise_operator() test_date_wise_operator()