import os

import pandas as pd
from tqdm import tqdm

from main.factor.factor import get_rolling_factor, get_simple_factor
from main.utils.utils import read_and_merge_h5_data

# NOTE(review): `Client` is used throughout this module but is not imported in
# this file. Its usage (host/port/user/password/database keyword arguments,
# `.execute(...)`, `.disconnect()`) matches `clickhouse_driver.Client`; add
# `from clickhouse_driver import Client` here unless it is provided some other
# way. Without it every function below fails with NameError, which their broad
# `except Exception` handlers only print.


def _quote_identifier(name) -> str:
    """Return *name* as a back-quoted ClickHouse identifier.

    Quoting lets column names that are not valid bare identifiers (e.g. names
    containing '/' or '-', or starting with a digit) be used in ALTER TABLE and
    INSERT statements. Backslashes and back-quotes inside the name are escaped.
    """
    escaped = str(name).replace('\\', '\\\\').replace('`', '\\`')
    return f'`{escaped}`'


def _clickhouse_type_for(dtype):
    """Return the ClickHouse type for a new column holding pandas *dtype*, or None.

    float -> 'Float64', integer -> 'Int64', any other numeric dtype (e.g. bool)
    -> 'Float64'. None means the dtype (object, string, category, datetime, ...)
    cannot be stored in a numeric column and the feature should be skipped.
    """
    if pd.api.types.is_float_dtype(dtype):
        return 'Float64'
    if pd.api.types.is_integer_dtype(dtype):
        return 'Int64'
    if pd.api.types.is_numeric_dtype(dtype):
        return 'Float64'
    return None


def create_factor_table_clickhouse(clickhouse_host: str, clickhouse_port: int, clickhouse_user: str,
                                   clickhouse_password: str, clickhouse_database: str,
                                   table_name: str = 'factor_data'):
    """Create the long-format factor table in ClickHouse if it does not exist yet.

    The table holds one row per (date, asset_id, factor_name) with a Float64
    ``factor_value``. It is a MergeTree partitioned by month (``toYYYYMM(date)``)
    and ordered by ``(date, asset_id, factor_name)``, chosen with read speed in
    mind.

    Args:
        clickhouse_host / clickhouse_port / clickhouse_user / clickhouse_password /
        clickhouse_database: connection settings passed to ``Client``.
        table_name: name of the table to create.

    Errors are printed, not raised (best-effort). The client connection is
    always closed.
    """
    client = None
    try:
        print('create factor table')
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            date Date,
            asset_id String,
            factor_name String,
            factor_value Float64
        ) ENGINE = MergeTree()
        PARTITION BY toYYYYMM(date)
        ORDER BY (date, asset_id, factor_name)
        """
        client.execute(create_table_query)
        print(f"成功在 ClickHouse 数据库 '{clickhouse_database}' 中创建表 '{table_name}'!")
    except Exception as e:
        print(f"创建 ClickHouse 表发生错误: {e}")
    finally:
        if client is not None:
            client.disconnect()


def write_features_to_clickhouse(df: pd.DataFrame, feature_columns: list,
                                 clickhouse_host: str, clickhouse_port: int, clickhouse_user: str,
                                 clickhouse_password: str, clickhouse_database: str,
                                 table_name: str = 'stock_factor', batch_size: int = 5000):
    """Write the given feature columns of *df* into a wide ClickHouse table, in batches.

    Feature columns that the table does not have yet are added with
    ``ALTER TABLE ... ADD COLUMN IF NOT EXISTS``. The type is derived from the
    column's pandas dtype; non-numeric columns are skipped with a warning.
    Each row is inserted as ``(date, asset_id, <features...>)``, taken from
    ``trade_date``, ``ts_code`` and the feature columns present in both *df* and
    the table.

    Args:
        df: source frame; must contain ``ts_code`` and ``trade_date``.
        feature_columns: candidate feature column names to write.
        clickhouse_*: connection settings passed to ``Client``.
        table_name: target wide table (it must already exist; it is DESCRIBEd first).
        batch_size: number of rows per INSERT.

    Errors are printed, not raised (best-effort). The client connection is
    always closed.
    """
    client = None
    try:
        # Validate before opening a connection we would not use.
        if 'ts_code' not in df.columns or 'trade_date' not in df.columns:
            raise ValueError("DataFrame 必须包含 'ts_code' 和 'trade_date' 列。")

        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)

        # DESCRIBE returns one row per column; the first field is the column name.
        existing_columns = {col[0] for col in client.execute(f"DESCRIBE TABLE {table_name}")}

        for factor_name in feature_columns:
            if factor_name in existing_columns:
                continue
            if factor_name not in df.columns:
                print(f"警告: 特征 '{factor_name}' 不存在于 DataFrame 中,将跳过添加列。")
                continue
            factor_dtype = df[factor_name].dtype
            clickhouse_dtype = _clickhouse_type_for(factor_dtype)
            if clickhouse_dtype is None:
                # A numeric column cannot hold these values; adding one would
                # make every subsequent INSERT fail.
                print(f"警告: 特征 '{factor_name}' 的数据类型为 {factor_dtype},将跳过添加列。")
                continue
            client.execute(f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS "
                           f"{_quote_identifier(factor_name)} {clickhouse_dtype}")
            print(f"在表 '{table_name}' 中添加了新列: {factor_name} ({clickhouse_dtype})")
            existing_columns.add(factor_name)

        insert_factors = [col for col in feature_columns
                          if col in existing_columns and col in df.columns]
        insert_columns_order = ['date', 'asset_id'] + insert_factors

        # Project once, in insert order. Plain positional tuples
        # (name=None) avoid namedtuple field renaming, which previously turned
        # non-identifier column names into silent None values.
        payload = df[['trade_date', 'ts_code'] + insert_factors]

        failed_batches = 0
        for start in tqdm(range(0, len(payload), batch_size), desc="写入批次"):
            batch = payload.iloc[start:start + batch_size]
            rows = list(batch.itertuples(index=False, name=None))
            if not write_batch_to_clickhouse(client, table_name, rows, insert_columns_order):
                failed_batches += 1
        if failed_batches:
            print(f"警告: {failed_batches} 个批次写入失败,对应数据未写入表 '{table_name}'。")
    except Exception as e:
        print(f"写入 ClickHouse 发生错误: {e}")
    finally:
        if client is not None:
            client.disconnect()


def write_batch_to_clickhouse(client, table_name, data_to_insert, columns_order):
    """Insert one batch of rows into *table_name* using an already-open client.

    Args:
        client: ClickHouse client exposing ``execute(query, rows)``.
        table_name: target table.
        data_to_insert: sequence of row tuples, each ordered like *columns_order*.
        columns_order: column names for the INSERT column list (quoted here).

    Returns:
        True if the batch was written or was empty, False if the insert failed.
        Failures are printed rather than raised, so the caller can continue
        with the remaining batches.
    """
    if not data_to_insert:
        return True
    column_list = ', '.join(_quote_identifier(col) for col in columns_order)
    insert_query_final = f"INSERT INTO {table_name} ({column_list}) VALUES"
    try:
        client.execute(insert_query_final, data_to_insert)
        print(f"成功写入 {len(data_to_insert)} 条数据到 ClickHouse 表 '{table_name}'!")
        return True
    except Exception as e:
        print(f"写入 ClickHouse 批次数据发生错误: {e}")
        return False


# -------------------- Usage example --------------------
if __name__ == "__main__":
    # Load and merge the raw data sources.
    print('daily data')
    df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',
                                columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low',
                                         'vol', 'pct_chg'],
                                df=None)
    print('daily basic')
    df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',
                                columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm',
                                         'circ_mv', 'volume_ratio', 'is_st'],
                                df=df, join='inner')
    df = df[df['trade_date'] >= '2021-01-01']
    print('stk limit')
    df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',
                                columns=['ts_code', 'trade_date', 'pre_close', 'up_limit',
                                         'down_limit'],
                                df=df)
    print('money flow')
    df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',
                                columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol',
                                         'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol',
                                         'sell_elg_vol', 'net_mf_vol'],
                                df=df)
    print('cyq perf')
    df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',
                                columns=['ts_code', 'trade_date', 'his_low', 'his_high',
                                         'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct',
                                         'cost_95pct', 'weight_avg', 'winner_rate'],
                                df=df)
    df.info()  # DataFrame.info() prints by itself and returns None

    # Columns present before factor computation are excluded from the features
    # written below. Names containing 'cyq' are dropped from this list, so they
    # are not excluded on this basis.
    origin_columns = [col for col in df.columns.tolist() if 'cyq' not in col]
    print(origin_columns)

    def filter_data(df):
        """Drop is_st rows and ts_codes ending in 'BJ' or starting with '30', '68' or '8'."""
        df = df[~df['is_st']]
        df = df[~df['ts_code'].str.endswith('BJ')]
        df = df[~df['ts_code'].str.startswith('30')]
        df = df[~df['ts_code'].str.startswith('68')]
        df = df[~df['ts_code'].str.startswith('8')]
        # NOTE: redundant after the '2021-01-01' filter above; harmless.
        df = df[df['trade_date'] >= '20180101']
        if 'in_date' in df.columns:
            df = df.drop(columns=['in_date'])
        df = df.reset_index(drop=True)
        return df

    df = filter_data(df)
    df, _ = get_rolling_factor(df)
    df, _ = get_simple_factor(df)
    df = df.rename(columns={'l2_code': 'cat_l2_code'})
    df.info()

    # Features to persist: every column except the keys, columns whose name
    # contains one of the excluded substrings, original input columns and
    # '_'-prefixed columns. Order follows df.columns.
    excluded_columns = {'trade_date', 'ts_code', 'label'}
    excluded_substrings = (
        'future', 'label', 'score', 'gen', 'is_st', 'cat_l2_code',
        # optional extra exclusions: 'pe_ttm', 'volatility', 'circ_mv',
    )
    origin_column_set = set(origin_columns)
    feature_columns = [
        col for col in df.columns
        if col not in excluded_columns
        and not any(part in col for part in excluded_substrings)
        and col not in origin_column_set
        and not col.startswith('_')
    ]
    print(feature_columns)

    # ClickHouse connection settings; environment variables override the defaults.
    # NOTE(review): a plaintext password is committed as the fallback below;
    # prefer setting CLICKHOUSE_PASSWORD and removing the literal.
    clickhouse_host = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
    clickhouse_port = int(os.environ.get('CLICKHOUSE_PORT', '9000'))
    clickhouse_user = os.environ.get('CLICKHOUSE_USER', 'default')
    clickhouse_password = os.environ.get('CLICKHOUSE_PASSWORD', 'clickhouse520102')
    clickhouse_database = os.environ.get('CLICKHOUSE_DATABASE', 'stock_data')

    # Run once to create the long-format 'factor_data' table:
    # create_factor_table_clickhouse(clickhouse_host, clickhouse_port,
    #                                clickhouse_user, clickhouse_password,
    #                                clickhouse_database)

    feature_set = set(feature_columns)
    write_columns = [col for col in df.columns
                     if col in ('ts_code', 'trade_date') or col in feature_set]
    write_features_to_clickhouse(df[write_columns], feature_columns,
                                 clickhouse_host, clickhouse_port, clickhouse_user,
                                 clickhouse_password, clickhouse_database)