154 lines
4.7 KiB
Python
154 lines
4.7 KiB
Python
"""
|
|
测试operator框架优化后的功能
|
|
验证因子列名设置和返回数据结构
|
|
"""
|
|
|
|
import polars as pl
|
|
import numpy as np
|
|
from main.factor.operator_framework import StockWiseOperator, DateWiseOperator, OperatorConfig
|
|
|
|
|
|
class TestStockFactorOperator(StockWiseOperator):
    """Stock-wise test operator: N-row simple moving average of ``close``.

    Used as a fixture for exercising the operator framework. It is not a
    pytest test class despite its ``Test`` prefix (see ``__test__`` below).
    """

    # The class name starts with "Test" and it defines __init__, so pytest
    # would try to collect it and emit a PytestCollectionWarning. Opt out.
    __test__ = False

    def __init__(self, window: int = 5):
        """Configure the operator.

        Args:
            window: Rolling-window length in rows. It is also embedded in the
                operator name and in the output column name
                ``test_stock_factor_{window}``.
        """
        config = OperatorConfig(
            name=f"test_stock_factor_{window}",
            description=f"测试股票因子{window}日",
            required_columns=['close'],
            output_columns=[f'test_stock_factor_{window}'],
            parameters={'window': window},
        )
        super().__init__(config)
        self.window = window

    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Append the rolling mean of ``close`` to one stock's rows.

        The mean is taken over consecutive rows, so the result is only a
        time-series moving average if ``stock_df`` is already in trade-date
        order. NOTE(review): confirm the framework sorts before calling.

        Args:
            stock_df: Rows for a single stock; must contain ``close``.
            **kwargs: Accepted for framework compatibility; unused.

        Returns:
            ``stock_df`` with an extra ``test_stock_factor_{window}`` column.
            The first ``window - 1`` values are null (polars' default minimum
            sample count equals the window size).
        """
        return stock_df.with_columns(
            pl.col('close')
            .rolling_mean(window_size=self.window)
            .alias(f'test_stock_factor_{self.window}')
        )
|
|
|
|
|
|
class TestDateFactorOperator(DateWiseOperator):
    """Date-wise test operator: cross-sectional rank of ``close``.

    Used as a fixture for exercising the operator framework. It is not a
    pytest test class despite its ``Test`` prefix (see ``__test__`` below).
    """

    # The class name starts with "Test" and it defines __init__, so pytest
    # would try to collect it and emit a PytestCollectionWarning. Opt out.
    __test__ = False

    def __init__(self):
        """Build the config (input ``close``, output ``test_date_factor``,
        no parameters) and pass it to the base class."""
        config = OperatorConfig(
            name="test_date_factor",
            description="测试日期因子",
            required_columns=['close'],
            output_columns=['test_date_factor'],
            parameters={},
        )
        super().__init__(config)

    def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Append the rank of ``close`` across the rows of ``date_df``.

        ``date_df`` is presumably the slice of all stocks for one trade date.
        NOTE(review): confirm against DateWiseOperator.

        Such a slice has no previous close, so a true daily return cannot be
        computed from it. The former ``pct_change()`` compared *different
        stocks* row-to-row, which gave an order-dependent value (always null
        in the first row). A cross-sectional rank is well defined and does
        not depend on row order.

        Args:
            date_df: Rows to rank; must contain ``close``.
            **kwargs: Accepted for framework compatibility; unused.

        Returns:
            ``date_df`` with an extra ``test_date_factor`` column. It holds
            the ascending rank of ``close`` (1 = lowest); ties share their
            average rank.
        """
        return date_df.with_columns(
            pl.col('close').rank(method='average').alias('test_date_factor')
        )
|
|
|
|
|
|
def create_test_data(seed=None):
    """Build a synthetic daily panel for three stocks over calendar 2023.

    Each stock gets one row per calendar day (weekends included). Its
    ``close`` is a per-stock base level (about 10 ± 2) plus small daily
    noise (std 0.5). ``vol`` is a uniform integer in [100000, 1000000).
    Rows are ordered by stock, then by date.

    Args:
        seed: Optional integer seed for reproducible data. When ``None``
            (the default), numpy's global RNG is used, so
            ``np.random.seed(...)`` set by a caller still applies.

    Returns:
        A ``pl.DataFrame`` with columns ``ts_code``, ``trade_date``,
        ``close`` and ``vol``.
    """
    # Keep the global-RNG behaviour by default; use a private generator only
    # when a seed is requested so results are reproducible.
    rng = np.random if seed is None else np.random.RandomState(seed)

    dates = pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 12, 31), "1d", eager=True).to_list()
    n_dates = len(dates)
    stocks = ['000001.SZ', '000002.SZ', '000003.SZ']

    frames = []
    for stock in stocks:
        # Draw the price level once per stock. Re-drawing it per row made
        # every close pure noise (std ~2) and defeated the separate 0.5 term.
        base_price = 10 + rng.randn() * 2
        frames.append(pl.DataFrame({
            'ts_code': [stock] * n_dates,
            'trade_date': dates,
            'close': base_price + rng.randn(n_dates) * 0.5,
            # Fixed dtype so the column is Int64 on every platform.
            'vol': rng.randint(100000, 1000000, size=n_dates, dtype=np.int64),
        }))

    return pl.concat(frames)
|
|
|
|
|
|
def test_stock_wise_operator():
    """Check that a stock-wise operator appends its factor column.

    It prints a short report, so the module still works as a script. It
    also asserts the contract (new column added, original columns kept, one
    output row per input row), so a regression fails under pytest instead
    of only printing a cross mark.
    """
    print("=== 测试股票切面算子 ===")

    # Build the synthetic input panel.
    df = create_test_data()
    print(f"原始数据形状: {df.shape}")
    print(f"原始数据列: {df.columns}")

    # 10-row moving average of close, computed per stock by the framework.
    operator = TestStockFactorOperator(window=10)
    result_df = operator(df)

    print(f"结果数据形状: {result_df.shape}")
    print(f"结果数据列: {result_df.columns}")

    new_column = 'test_stock_factor_10'
    original_cols = ['ts_code', 'trade_date', 'close', 'vol']
    missing = []

    if new_column in result_df.columns:
        print(f"✓ 新列 '{new_column}' 成功添加")

        # The input columns must survive alongside the new factor.
        for col in original_cols:
            if col in result_df.columns:
                print(f"✓ 原始列 '{col}' 保留")
            else:
                print(f"✗ 原始列 '{col}' 丢失")
                missing.append(col)
    else:
        print(f"✗ 新列 '{new_column}' 未找到")
        missing.append(new_column)

    # Fail with a clear message here, before the sample query below would
    # fail with a less helpful column-not-found error.
    assert not missing, f"missing columns in operator output: {missing}"
    assert result_df.height == df.height, (
        f"row count changed: {df.height} -> {result_df.height}"
    )

    # Show the first rows for one stock so the rolling window is visible.
    sample = result_df.filter(pl.col('ts_code') == '000001.SZ').select(['trade_date', 'close', new_column]).head(15)
    print("\n样本数据:")
    print(sample)
|
|
|
|
|
|
def test_date_wise_operator():
    """Check that a date-wise operator appends its factor column.

    It prints a short report, so the module still works as a script. It
    also asserts the contract (new column added, original columns kept, one
    output row per input row), so a regression fails under pytest instead
    of only printing a cross mark.
    """
    print("\n=== 测试日期切面算子 ===")

    # Build the synthetic input panel.
    df = create_test_data()
    print(f"原始数据形状: {df.shape}")
    print(f"原始数据列: {df.columns}")

    # Cross-sectional factor, computed per trade date by the framework.
    operator = TestDateFactorOperator()
    result_df = operator(df)

    print(f"结果数据形状: {result_df.shape}")
    print(f"结果数据列: {result_df.columns}")

    new_column = 'test_date_factor'
    if new_column in result_df.columns:
        print(f"✓ 新列 '{new_column}' 成功添加")
    else:
        print(f"✗ 新列 '{new_column}' 未找到")

    missing = [c for c in ('ts_code', 'trade_date', 'close', 'vol', new_column)
               if c not in result_df.columns]
    assert not missing, f"missing columns in operator output: {missing}"
    assert result_df.height == df.height, (
        f"row count changed: {df.height} -> {result_df.height}"
    )

    # Show the full cross-section for a single date.
    sample = result_df.filter(pl.col('trade_date') == pl.date(2023, 1, 10)).select(['ts_code', 'close', new_column])
    print("\n样本数据 (2023-01-10):")
    print(sample)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_stock_wise_operator()
|
|
test_date_wise_operator()
|