factor优化(暂存版)
This commit is contained in:
153
test_operator_optimization.py
Normal file
153
test_operator_optimization.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
测试operator框架优化后的功能
|
||||
验证因子列名设置和返回数据结构
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from main.factor.operator_framework import StockWiseOperator, DateWiseOperator, OperatorConfig
|
||||
|
||||
|
||||
class TestStockFactorOperator(StockWiseOperator):
|
||||
"""测试股票因子算子"""
|
||||
|
||||
def __init__(self, window: int = 5):
|
||||
config = OperatorConfig(
|
||||
name=f"test_stock_factor_{window}",
|
||||
description=f"测试股票因子{window}日",
|
||||
required_columns=['close'],
|
||||
output_columns=[f'test_stock_factor_{window}'],
|
||||
parameters={'window': window}
|
||||
)
|
||||
super().__init__(config)
|
||||
self.window = window
|
||||
|
||||
def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""计算测试因子 - 简单移动平均"""
|
||||
# 计算移动平均线
|
||||
ma = stock_df['close'].rolling_mean(window_size=self.window)
|
||||
|
||||
# 返回原始df加上新列
|
||||
return stock_df.with_columns(
|
||||
ma.alias(f'test_stock_factor_{self.window}')
|
||||
)
|
||||
|
||||
|
||||
class TestDateFactorOperator(DateWiseOperator):
|
||||
"""测试日期因子算子"""
|
||||
|
||||
def __init__(self):
|
||||
config = OperatorConfig(
|
||||
name="test_date_factor",
|
||||
description="测试日期因子",
|
||||
required_columns=['close'],
|
||||
output_columns=['test_date_factor'],
|
||||
parameters={}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""计算测试因子 - 当日涨跌幅排名"""
|
||||
# 计算当日涨跌幅
|
||||
pct_chg = date_df['close'].pct_change()
|
||||
|
||||
# 返回原始df加上新列
|
||||
return date_df.with_columns(
|
||||
pct_chg.alias('test_date_factor')
|
||||
)
|
||||
|
||||
|
||||
def create_test_data():
|
||||
"""创建测试数据"""
|
||||
# 创建模拟股票数据
|
||||
dates = pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 12, 31), "1d", eager=True).to_list()
|
||||
stocks = ['000001.SZ', '000002.SZ', '000003.SZ']
|
||||
|
||||
data = []
|
||||
for stock in stocks:
|
||||
for date in dates:
|
||||
# 模拟价格数据
|
||||
base_price = 10 + np.random.randn() * 2
|
||||
data.append({
|
||||
'ts_code': stock,
|
||||
'trade_date': date,
|
||||
'close': base_price + np.random.randn() * 0.5,
|
||||
'vol': np.random.randint(100000, 1000000)
|
||||
})
|
||||
|
||||
return pl.DataFrame(data)
|
||||
|
||||
|
||||
def test_stock_wise_operator():
|
||||
"""测试股票切面算子"""
|
||||
print("=== 测试股票切面算子 ===")
|
||||
|
||||
# 创建测试数据
|
||||
df = create_test_data()
|
||||
print(f"原始数据形状: {df.shape}")
|
||||
print(f"原始数据列: {df.columns}")
|
||||
|
||||
# 创建算子
|
||||
operator = TestStockFactorOperator(window=10)
|
||||
|
||||
# 应用算子
|
||||
result_df = operator(df)
|
||||
|
||||
print(f"结果数据形状: {result_df.shape}")
|
||||
print(f"结果数据列: {result_df.columns}")
|
||||
|
||||
# 检查新列是否存在
|
||||
new_column = 'test_stock_factor_10'
|
||||
if new_column in result_df.columns:
|
||||
print(f"✓ 新列 '{new_column}' 成功添加")
|
||||
|
||||
# 检查数据完整性
|
||||
original_cols = ['ts_code', 'trade_date', 'close', 'vol']
|
||||
for col in original_cols:
|
||||
if col in result_df.columns:
|
||||
print(f"✓ 原始列 '{col}' 保留")
|
||||
else:
|
||||
print(f"✗ 原始列 '{col}' 丢失")
|
||||
else:
|
||||
print(f"✗ 新列 '{new_column}' 未找到")
|
||||
|
||||
# 检查数据排序
|
||||
sample = result_df.filter(pl.col('ts_code') == '000001.SZ').select(['trade_date', 'close', new_column]).head(15)
|
||||
print("\n样本数据:")
|
||||
print(sample)
|
||||
|
||||
|
||||
def test_date_wise_operator():
|
||||
"""测试日期切面算子"""
|
||||
print("\n=== 测试日期切面算子 ===")
|
||||
|
||||
# 创建测试数据
|
||||
df = create_test_data()
|
||||
print(f"原始数据形状: {df.shape}")
|
||||
print(f"原始数据列: {df.columns}")
|
||||
|
||||
# 创建算子
|
||||
operator = TestDateFactorOperator()
|
||||
|
||||
# 应用算子
|
||||
result_df = operator(df)
|
||||
|
||||
print(f"结果数据形状: {result_df.shape}")
|
||||
print(f"结果数据列: {result_df.columns}")
|
||||
|
||||
# 检查新列是否存在
|
||||
new_column = 'test_date_factor'
|
||||
if new_column in result_df.columns:
|
||||
print(f"✓ 新列 '{new_column}' 成功添加")
|
||||
else:
|
||||
print(f"✗ 新列 '{new_column}' 未找到")
|
||||
|
||||
# 检查数据排序
|
||||
sample = result_df.filter(pl.col('trade_date') == pl.date(2023, 1, 10)).select(['ts_code', 'close', new_column])
|
||||
print("\n样本数据 (2023-01-10):")
|
||||
print(sample)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_stock_wise_operator()
|
||||
test_date_wise_operator()
|
||||
Reference in New Issue
Block a user