# NewStock/main/factor/save_factor.py
# Last modified: 2025-04-28 11:02:52 +08:00
import pandas as pd
from clickhouse_driver import Client
from tqdm import tqdm

from main.factor.factor import get_rolling_factor, get_simple_factor
from main.utils.utils import read_and_merge_h5_data
def create_factor_table_clickhouse(clickhouse_host: str, clickhouse_port: int,
                                   clickhouse_user: str, clickhouse_password: str,
                                   clickhouse_database: str, table_name: str = 'factor_data'):
    """Create the narrow factor table in ClickHouse, optimized for read speed.

    The table is partitioned by month (``toYYYYMM(date)``) and ordered by
    ``(date, asset_id, factor_name)`` so date-range scans and per-asset
    lookups stay fast.  Uses ``CREATE TABLE IF NOT EXISTS``, so calling it
    repeatedly is safe.

    Args:
        clickhouse_host: ClickHouse server host.
        clickhouse_port: ClickHouse native-protocol port.
        clickhouse_user: Login user name.
        clickhouse_password: Login password.
        clickhouse_database: Database in which to create the table.
        table_name: Name of the table to create (default ``'factor_data'``).

    Errors are caught and printed rather than raised.
    """
    client = None  # explicit sentinel so the finally clause is unambiguous
    try:
        print('create factor table')
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)
        # NOTE: table_name is interpolated directly into DDL; it must come
        # from trusted configuration, never from user input.
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name}
        (
            date Date,
            asset_id String,
            factor_name String,
            factor_value Float64
        )
        ENGINE = MergeTree()
        PARTITION BY toYYYYMM(date)
        ORDER BY (date, asset_id, factor_name)
        """
        client.execute(create_table_query)
        print(f"成功在 ClickHouse 数据库 '{clickhouse_database}' 中创建表 '{table_name}'!")
    except Exception as e:
        print(f"创建 ClickHouse 表发生错误: {e}")
    finally:
        # Only disconnect when the connection was actually established.
        if client is not None and client.connection:
            client.disconnect()
def write_features_to_clickhouse(df: pd.DataFrame, feature_columns: list,
                                 clickhouse_host: str, clickhouse_port: int,
                                 clickhouse_user: str, clickhouse_password: str,
                                 clickhouse_database: str, table_name: str = 'stock_factor',
                                 batch_size: int = 5000):
    """Write the given feature columns of *df* into a wide ClickHouse table.

    Missing columns are added to the table on the fly (``ALTER TABLE ... ADD
    COLUMN IF NOT EXISTS``) with a type inferred from the pandas dtype, then
    rows are inserted in batches of *batch_size*.  ``trade_date``/``ts_code``
    in the frame map to the ``date``/``asset_id`` table columns.

    Args:
        df: Frame that must contain ``ts_code`` and ``trade_date`` columns.
        feature_columns: Names of the feature columns to persist.
        clickhouse_host: ClickHouse server host.
        clickhouse_port: ClickHouse native-protocol port.
        clickhouse_user: Login user name.
        clickhouse_password: Login password.
        clickhouse_database: Target database.
        table_name: Target wide table (default ``'stock_factor'``).
        batch_size: Number of rows per INSERT batch (default 5000).

    Errors are caught and printed rather than raised.
    """
    client = None  # explicit sentinel so the finally clause is unambiguous
    try:
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)
        if 'ts_code' not in df.columns or 'trade_date' not in df.columns:
            raise ValueError("DataFrame 必须包含 'ts_code' 和 'trade_date' 列。")
        # Discover which columns the target table already has.
        existing_columns = set()
        columns_result = client.execute(f"DESCRIBE TABLE {table_name}")
        for col in columns_result:
            existing_columns.add(col[0])
        # Add any missing feature columns, mapping pandas dtype -> ClickHouse
        # type.  object columns are skipped (no safe mapping).
        for factor_name in feature_columns:
            if factor_name in existing_columns:
                continue
            if factor_name not in df.columns:
                print(f"警告: 特征 '{factor_name}' 不存在于 DataFrame 中,将跳过添加列。")
                continue
            factor_dtype = df[factor_name].dtype
            if pd.api.types.is_float_dtype(factor_dtype):
                clickhouse_dtype = 'Float64'
            elif pd.api.types.is_integer_dtype(factor_dtype):
                clickhouse_dtype = 'Int64'
            elif factor_dtype == 'object':
                print(f"警告: 特征 '{factor_name}' 的数据类型为 object将跳过添加列。")
                continue
            else:
                # Anything else (e.g. bool) is stored as Float64, matching
                # the original fallback behavior.
                clickhouse_dtype = 'Float64'
            client.execute(
                f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS {factor_name} {clickhouse_dtype}")
            print(f"在表 '{table_name}' 中添加了新列: {factor_name} ({clickhouse_dtype})")
            existing_columns.add(factor_name)
        # The set of writable factors is fixed for the whole insert, so
        # compute it once instead of re-checking inside the per-row loop.
        valid_factors = [col for col in feature_columns
                         if col in existing_columns and col in df.columns]
        insert_columns_order = ['date', 'asset_id'] + valid_factors
        # Insert the frame in batches to bound memory per round-trip.
        num_rows = len(df)
        for i in tqdm(range(0, num_rows, batch_size), desc="写入批次"):
            # iloc makes the positional row slicing explicit regardless of
            # what index the frame carries.
            batch_df = df.iloc[i:i + batch_size]
            data_to_insert_batch = []
            for row in batch_df.itertuples(index=False):
                insert_row = [getattr(row, 'trade_date'), getattr(row, 'ts_code')]
                for factor in valid_factors:
                    # itertuples renames invalid identifiers, so the factor
                    # attribute may be absent; insert NULL in that case.
                    insert_row.append(getattr(row, factor, None))
                data_to_insert_batch.append(tuple(insert_row))
            write_batch_to_clickhouse(client, table_name, data_to_insert_batch, insert_columns_order)
    except Exception as e:
        print(f"写入 ClickHouse 发生错误: {e}")
    finally:
        # Only disconnect when the connection was actually established.
        if client is not None and client.connection:
            client.disconnect()
def write_batch_to_clickhouse(client, table_name, data_to_insert, columns_order):
    """Insert one batch of row tuples into ClickHouse.

    An empty batch is a no-op.  Insert errors are caught and printed rather
    than raised, so one failed batch does not abort the caller's loop.
    """
    if not data_to_insert:
        return
    column_list = ', '.join(columns_order)
    insert_query_final = f"INSERT INTO {table_name} ({column_list}) VALUES"
    try:
        client.execute(insert_query_final, data_to_insert)
        print(f"成功写入 {len(data_to_insert)} 条数据到 ClickHouse 表 '{table_name}'!")
    except Exception as e:
        print(f"写入 ClickHouse 批次数据发生错误: {e}")
# -------------------- 使用示例 --------------------
if __name__ == "__main__":
    # Assemble the raw feature frame by merging several HDF5 data sets
    # (daily bars, daily basics, price limits, money flow, chip distribution).
    print('daily data')
    df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',
                                columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],
                                df=None)
    print('daily basic')
    df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',
                                columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',
                                         'is_st'], df=df, join='inner')
    # NOTE(review): this filter uses 'YYYY-MM-DD' while filter_data below
    # compares against 'YYYYMMDD' — confirm which string format trade_date
    # actually holds; one of the two comparisons is likely ineffective.
    df = df[df['trade_date'] >= '2021-01-01']
    print('stk limit')
    df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',
                                columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],
                                df=df)
    print('money flow')
    df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',
                                columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol',
                                         'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],
                                df=df)
    print('cyq perf')
    df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',
                                columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',
                                         'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],
                                df=df)
    print(df.info())
    # Columns present before factor generation; used below to exclude raw
    # inputs from the factor list.  cyq_* columns are deliberately kept.
    origin_columns = [col for col in df.columns.tolist() if 'cyq' not in col]
    print(origin_columns)

    def filter_data(df):
        """Drop ST stocks, non-main-board codes and pre-2018 history."""
        df = df[~df['is_st']]
        df = df[~df['ts_code'].str.endswith('BJ')]  # Beijing exchange
        # 30 / 68 / 8 prefixes: ChiNext, STAR Market, NEEQ codes.
        df = df[~df['ts_code'].str.startswith(('30', '68', '8'))]
        df = df[df['trade_date'] >= '20180101']
        if 'in_date' in df.columns:
            df = df.drop(columns=['in_date'])
        return df.reset_index(drop=True)

    df = filter_data(df)
    df, _ = get_rolling_factor(df)
    df, _ = get_simple_factor(df)
    # Prefix marks the column as categorical for downstream consumers.
    df = df.rename(columns={'l2_code': 'cat_l2_code'})
    print(df.info())

    # Select generated factor columns: drop identifiers/labels, raw input
    # columns, leakage-prone names and private (underscore) columns.
    excluded_substrings = ('future', 'label', 'score', 'gen', 'is_st', 'cat_l2_code')
    feature_columns = [
        col for col in df.columns
        if col not in ('trade_date', 'ts_code', 'label')
        and not col.startswith('_')
        and col not in origin_columns
        and not any(token in col for token in excluded_substrings)
    ]
    print(feature_columns)

    # 替换为您的 ClickHouse 连接信息
    # SECURITY: credentials are hard-coded; move them to environment
    # variables or a config file before sharing or deploying this script.
    clickhouse_host = '127.0.0.1'
    clickhouse_port = 9000
    clickhouse_user = 'default'
    clickhouse_password = 'clickhouse520102'
    clickhouse_database = 'stock_data'
    # create_factor_table_clickhouse(clickhouse_host, clickhouse_port,
    #                                clickhouse_user, clickhouse_password,
    #                                clickhouse_database)
    keep_columns = [col for col in df.columns
                    if col in ('ts_code', 'trade_date') or col in feature_columns]
    write_features_to_clickhouse(df[keep_columns], feature_columns,
                                 clickhouse_host, clickhouse_port,
                                 clickhouse_user, clickhouse_password,
                                 clickhouse_database)