Files
NewStock/code/train/utils/utils.py
2025-04-01 00:26:15 +08:00

121 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import pandas as pd
def read_and_merge_h5_data(h5_filename, key, columns, df=None, join='left', on=['ts_code', 'trade_date'], prefix=None):
processed_columns = []
for col in columns:
if col.startswith('_'):
processed_columns.append(col[1:]) # 去掉下划线
else:
processed_columns.append(col)
# 从 HDF5 文件读取数据,选择需要的列
data = pd.read_hdf(h5_filename, key=key, columns=processed_columns)
# 修改列名,如果列名以前有 _加上 _
for col in data.columns:
if col not in columns: # 只有不在 columns 中的列才需要加下划线
new_col = f'_{col}'
data.rename(columns={col: new_col}, inplace=True)
if prefix is not None:
for col in data.columns:
if col not in ['ts_code', 'trade_date']: # 只有不在 columns 中的列才需要加下划线
new_col = f'{prefix}_{col}'
data.rename(columns={col: new_col}, inplace=True)
# 如果传入的 df 不为空,则进行合并
if df is not None and not df.empty:
print(f'{join} merge on {on}')
if 'trade_date' in on:
# 确保两个 DataFrame 都有 ts_code 和 trade_date 列
df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')
# 根据 ts_code 和 trade_date 合并
merged_df = pd.merge(df, data, on=on, how=join)
else:
# 如果 df 为空,则直接返回读取的数据
merged_df = data
return merged_df
def calculate_risk_adjusted_return(df, days=1, method='ratio', lambda_=0.5, eps=1e-8):
"""
计算单只股票的风险调整收益。
参数:
- df: DataFrame包含 'ts_code''close' 列,按日期排序(假设 'trade_date' 已排序)。
- days: 预测未来多少天的收益默认1天。
- method: 'ratio'(收益/波动率) 或 'difference'(收益 - λ * 波动率)。
- lambda_: 风险惩罚系数,仅当 method='difference' 时有效。
- eps: 防止除零的小常数。
返回:
- df添加 'risk_adj_return' 列的 DataFrame表示风险调整后的收益。
"""
# 确保数据按 ts_code 和 trade_date 排序
df = df.sort_values(by=['ts_code', 'trade_date'])
# 计算未来的对数收益率
df['future_return'] = np.log(df.groupby('ts_code')['close'].shift(-days) / df['close'])
# 计算历史收益(对数收益率)
df['historical_return'] = np.log(df.groupby('ts_code')['close'].shift(1) / df['close'])
# 计算波动率(历史收益的标准差)
df['volatility'] = df.groupby('ts_code')['historical_return'].rolling(window=days).std().reset_index(level=0,
drop=True)
# 根据选择的 method 计算风险调整收益
if method == 'ratio':
# 收益/波动率(防止除零)
df['risk_adj_return'] = df['future_return'] / (df['volatility'] + eps)
elif method == 'difference':
# 收益 - λ * 波动率
df['risk_adj_return'] = df['future_return'] - lambda_ * df['volatility']
else:
raise ValueError("Invalid method. Use 'ratio' or 'difference'.")
return df
# import polars as pl
#
# def read_and_merge_h5_data_polars(h5_filename, key, columns, df=None, join='left', on=['ts_code', 'trade_date']):
# processed_columns = []
# for col in columns:
# if col.startswith('_'):
# processed_columns.append(col[1:]) # 去掉下划线
# else:
# processed_columns.append(col)
#
# # 从 HDF5 文件读取数据,选择需要的列
# pd_df = pd.read_hdf(h5_filename, key=key, columns=processed_columns)
#
# # 将 Pandas DataFrame 转换为 Polars DataFrame
# data = pl.from_pandas(pd_df)
#
# # 修改列名,如果列名以前有 _加上 _
# data = data.rename({col: f'_{col}' for col in data.columns if col not in columns})
#
# # 如果传入的 df 不为空,则进行合并
# if df is not None and not df.is_empty():
# print(f'{join} merge on {on}')
#
# # 确保两个 DataFrame 都有 ts_code 和 trade_date 列
# # df = df.with_columns(pl.col('trade_date').str.strptime(pl.Datetime, format='%Y%m%d'))
# # data = data.with_columns(pl.col('trade_date').str.strptime(pl.Datetime, format='%Y%m%d'))
#
# # 根据 ts_code 和 trade_date 合并
# merged_df = df.join(data, on=on, how=join)
# else:
# # 如果 df 为空,则直接返回读取的数据
# merged_df = data
#
# return merged_df