factor优化(暂存版)

This commit is contained in:
2025-10-14 09:45:03 +08:00
parent 7862b9739a
commit 0a942f92d1
2 changed files with 154 additions and 0 deletions

1
.env Normal file
View File

@@ -0,0 +1 @@
PYTHONPATH=${PYTHONPATH}:${workspaceFolder}

View File

@@ -0,0 +1,153 @@
"""
测试operator框架优化后的功能
验证因子列名设置和返回数据结构
"""
import polars as pl
import numpy as np
from main.factor.operator_framework import StockWiseOperator, DateWiseOperator, OperatorConfig
class TestStockFactorOperator(StockWiseOperator):
    """Stock-wise fixture operator that adds a rolling mean of ``close``.

    The output column is named ``test_stock_factor_{window}``.
    """

    # The name starts with "Test" but this is a fixture, not a test case.
    # Opt out of pytest collection; otherwise pytest warns that it cannot
    # collect a test class that defines __init__.
    __test__ = False

    def __init__(self, window: int = 5):
        """Create the operator and register its configuration.

        Args:
            window: Rolling window size; also embedded in the operator name
                and in the output column name.
        """
        config = OperatorConfig(
            name=f"test_stock_factor_{window}",
            description=f"测试股票因子{window}",
            required_columns=['close'],
            output_columns=[f'test_stock_factor_{window}'],
            parameters={'window': window}
        )
        super().__init__(config)
        self.window = window

    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Append the rolling mean of ``close`` to ``stock_df``.

        Args:
            stock_df: Frame containing a ``close`` column (presumably the
                rows of a single stock, as supplied by the framework).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``stock_df`` with one extra column
            ``test_stock_factor_{window}`` holding the moving average.
        """
        # Simple moving average over the configured window.
        ma = stock_df['close'].rolling_mean(window_size=self.window)
        # Keep every input column and add the factor column.
        return stock_df.with_columns(
            ma.alias(f'test_stock_factor_{self.window}')
        )
class TestDateFactorOperator(DateWiseOperator):
    """Date-wise fixture operator that adds a ``test_date_factor`` column."""

    # The name starts with "Test" but this is a fixture, not a test case.
    # Opt out of pytest collection; otherwise pytest warns that it cannot
    # collect a test class that defines __init__.
    __test__ = False

    def __init__(self):
        """Create the operator and register its fixed configuration."""
        config = OperatorConfig(
            name="test_date_factor",
            description="测试日期因子",
            required_columns=['close'],
            output_columns=['test_date_factor'],
            parameters={}
        )
        super().__init__(config)

    def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Append the row-over-row percent change of ``close`` to ``date_df``.

        Args:
            date_df: Frame containing a ``close`` column (presumably the
                rows of a single trading date, as supplied by the framework).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``date_df`` with one extra column ``test_date_factor``.
        """
        # NOTE(review): the original docstring described this as a
        # "pct-change rank", but no ranking is performed. pct_change()
        # compares each row with the previous row of date_df, so the value
        # depends on row order within the slice. Good enough for checking
        # the framework's plumbing; confirm before reusing as a real factor.
        pct_chg = date_df['close'].pct_change()
        # Keep every input column and add the factor column.
        return date_df.with_columns(
            pct_chg.alias('test_date_factor')
        )
def create_test_data(seed=None) -> pl.DataFrame:
    """Build a synthetic daily price table for three stocks over 2023.

    Args:
        seed: Optional integer seed. When given, values are drawn from a
            private ``numpy.random.RandomState(seed)`` so the data is
            reproducible and NumPy's global random state is left untouched.
            When ``None`` (the default), values come from the global
            ``np.random`` state exactly as before.

    Returns:
        A DataFrame with one row per (stock, calendar day of 2023) and the
        columns ``ts_code``, ``trade_date``, ``close`` and ``vol``.
    """
    # Both the np.random module and RandomState expose randn()/randint(),
    # so the default path keeps using the global generator.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Every calendar day of 2023 (no trading-calendar filtering).
    dates = pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 12, 31), "1d", eager=True).to_list()
    stocks = ['000001.SZ', '000002.SZ', '000003.SZ']

    data = []
    for stock in stocks:
        for date in dates:
            # Each row gets an independent base level around 10; prices are
            # noise, not a random walk.
            base_price = 10 + rng.randn() * 2
            data.append({
                'ts_code': stock,
                'trade_date': date,
                'close': base_price + rng.randn() * 0.5,
                'vol': rng.randint(100000, 1000000)
            })
    return pl.DataFrame(data)
def test_stock_wise_operator():
    """Check that the stock-wise operator adds its column and keeps the inputs.

    Progress is printed for manual runs. The assertions make the check fail
    under a test runner when the factor column is missing or an original
    column was dropped.
    """
    print("=== 测试股票切面算子 ===")

    # Build the input fixture.
    df = create_test_data()
    print(f"原始数据形状: {df.shape}")
    print(f"原始数据列: {df.columns}")

    # Create and apply the operator.
    operator = TestStockFactorOperator(window=10)
    result_df = operator(df)
    print(f"结果数据形状: {result_df.shape}")
    print(f"结果数据列: {result_df.columns}")

    # The factor column must be present under its configured name.
    new_column = 'test_stock_factor_10'
    has_new_column = new_column in result_df.columns
    if has_new_column:
        print(f"✓ 新列 '{new_column}' 成功添加")
    else:
        print(f"✗ 新列 '{new_column}' 未找到")
    # Fail here with a clear message rather than later inside select().
    assert has_new_column, (
        f"output column {new_column!r} not found in {result_df.columns}"
    )

    # Every input column must survive the operator.
    original_cols = ['ts_code', 'trade_date', 'close', 'vol']
    missing_cols = []
    for col in original_cols:
        if col in result_df.columns:
            print(f"✓ 原始列 '{col}' 保留")
        else:
            print(f"✗ 原始列 '{col}' 丢失")
            missing_cols.append(col)
    assert not missing_cols, f"original columns lost: {missing_cols}"

    # Print the first rows of one stock for visual inspection.
    sample = result_df.filter(pl.col('ts_code') == '000001.SZ').select(['trade_date', 'close', new_column]).head(15)
    print("\n样本数据:")
    print(sample)
def test_date_wise_operator():
    """Check that the date-wise operator adds its ``test_date_factor`` column.

    Progress is printed for manual runs. The assertion makes the check fail
    under a test runner when the factor column is missing.
    """
    print("\n=== 测试日期切面算子 ===")

    # Build the input fixture.
    df = create_test_data()
    print(f"原始数据形状: {df.shape}")
    print(f"原始数据列: {df.columns}")

    # Create and apply the operator.
    operator = TestDateFactorOperator()
    result_df = operator(df)
    print(f"结果数据形状: {result_df.shape}")
    print(f"结果数据列: {result_df.columns}")

    # The factor column must be present under its configured name.
    new_column = 'test_date_factor'
    has_new_column = new_column in result_df.columns
    if has_new_column:
        print(f"✓ 新列 '{new_column}' 成功添加")
    else:
        print(f"✗ 新列 '{new_column}' 未找到")
    # Fail here with a clear message rather than later inside select().
    assert has_new_column, (
        f"output column {new_column!r} not found in {result_df.columns}"
    )

    # Print one date's cross-section for visual inspection.
    sample = result_df.filter(pl.col('trade_date') == pl.date(2023, 1, 10)).select(['ts_code', 'close', new_column])
    print("\n样本数据 (2023-01-10):")
    print(sample)
if __name__ == "__main__":
    # Script entry point: run the stock-wise check, then the date-wise one.
    for run_check in (test_stock_wise_operator, test_date_wise_operator):
        run_check()