RollingRank赚钱- Sharp-1.43

This commit is contained in:
liaozhaorun
2025-04-28 11:02:52 +08:00
parent 94cd9aa6c8
commit 9e598d4ed0
93 changed files with 18134 additions and 4342 deletions

0
main/factor/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

1532
main/factor/factor.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

7
main/factor/operators.py Normal file
View File

@@ -0,0 +1,7 @@
# Module import smoke check: prints sys.path so a developer can verify the
# project root is importable (runs at import time, so it fires on every load).
# NOTE(review): `read_and_merge_h5_data` and `merge_with_industry_data` are
# imported but unused in this file — presumably re-exported for other modules;
# confirm callers before removing.
from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data
import sys
print(sys.path)

222
main/factor/save_factor.py Normal file
View File

@@ -0,0 +1,222 @@
from tqdm import tqdm
from main.factor.factor import get_rolling_factor, get_simple_factor
from main.utils.utils import read_and_merge_h5_data
import pandas as pd
def create_factor_table_clickhouse(clickhouse_host: str, clickhouse_port: int,
                                   clickhouse_user: str, clickhouse_password: str,
                                   clickhouse_database: str, table_name: str = 'factor_data'):
    """Create the long-format factor table in ClickHouse if it does not exist.

    The table stores one row per (date, asset, factor), partitioned by month
    with an ORDER BY chosen for fast (date, asset, factor) range reads.

    Args:
        clickhouse_host: ClickHouse server host.
        clickhouse_port: native-protocol port (typically 9000).
        clickhouse_user: login user.
        clickhouse_password: login password.
        clickhouse_database: target database name.
        table_name: table to create (default 'factor_data').

    Errors are printed rather than raised (best-effort, matching the rest of
    this module); the connection is always closed.
    """
    # BUGFIX: `Client` was referenced but never imported anywhere in this
    # module, so every call failed with a NameError that the broad `except`
    # silently printed. Import locally so the module stays importable even
    # where clickhouse-driver is absent.
    from clickhouse_driver import Client
    client = None
    try:
        print('create factor table')
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name}
        (
            date Date,
            asset_id String,
            factor_name String,
            factor_value Float64
        )
        ENGINE = MergeTree()
        PARTITION BY toYYYYMM(date)
        ORDER BY (date, asset_id, factor_name)
        """
        client.execute(create_table_query)
        print(f"成功在 ClickHouse 数据库 '{clickhouse_database}' 中创建表 '{table_name}'!")
    except Exception as e:
        print(f"创建 ClickHouse 表发生错误: {e}")
    finally:
        # BUGFIX: the old `'client' in locals()` check still raised if
        # `Client()` itself failed before assignment; a plain None guard
        # is both simpler and safe.
        if client is not None and client.connection:
            client.disconnect()
def write_features_to_clickhouse(df: pd.DataFrame, feature_columns: list,
                                 clickhouse_host: str, clickhouse_port: int,
                                 clickhouse_user: str, clickhouse_password: str,
                                 clickhouse_database: str, table_name: str = 'stock_factor',
                                 batch_size: int = 5000):
    """Write selected feature columns of ``df`` into a wide ClickHouse table.

    Columns missing from the table are added on the fly with ``ALTER TABLE``
    (Float64 for floats and as a fallback, Int64 for ints; object-dtype
    columns are skipped with a warning). Rows are inserted in batches of
    ``batch_size`` via ``write_batch_to_clickhouse``.

    Args:
        df: must contain 'ts_code' and 'trade_date'; they are written into
            the table's ``asset_id`` and ``date`` columns respectively.
        feature_columns: candidate feature column names to persist.
        clickhouse_host/port/user/password/database: connection parameters.
        table_name: destination wide table (default 'stock_factor').
        batch_size: rows per INSERT batch.

    NOTE(review): assumes ``table_name`` already exists with ``date`` and
    ``asset_id`` columns — confirm, since ``create_factor_table_clickhouse``
    creates a differently-shaped (long-format) table.
    Errors are printed, not raised; the connection is always closed.
    """
    # BUGFIX: `Client` was used without ever being imported in this module,
    # so the call failed with a NameError swallowed by the broad `except`.
    from clickhouse_driver import Client
    client = None
    try:
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)
        if 'ts_code' not in df.columns or 'trade_date' not in df.columns:
            # BUGFIX: message previously ran the two names together
            # ("'ts_code''trade_date'"); restored the missing conjunction.
            raise ValueError("DataFrame 必须包含 'ts_code' 和 'trade_date' 列。")
        # Discover the table's current columns once up front.
        existing_columns = {col[0] for col in client.execute(f"DESCRIBE TABLE {table_name}")}
        for factor_name in feature_columns:
            if factor_name in existing_columns:
                continue
            if factor_name not in df.columns:
                print(f"警告: 特征 '{factor_name}' 不存在于 DataFrame 中,将跳过添加列。")
                continue
            factor_dtype = df[factor_name].dtype
            if pd.api.types.is_float_dtype(factor_dtype):
                clickhouse_dtype = 'Float64'
            elif pd.api.types.is_integer_dtype(factor_dtype):
                clickhouse_dtype = 'Int64'
            elif factor_dtype == 'object':
                print(f"警告: 特征 '{factor_name}' 的数据类型为 object将跳过添加列。")
                continue
            else:
                # Any other dtype (e.g. bool) is stored as Float64.
                clickhouse_dtype = 'Float64'
            client.execute(f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS {factor_name} {clickhouse_dtype}")
            print(f"在表 '{table_name}' 中添加了新列: {factor_name} ({clickhouse_dtype})")
            existing_columns.add(factor_name)
        # PERF: the writable set is constant for the whole DataFrame — hoist
        # the old per-row `in existing_columns and in df.columns` tests here.
        writable_factors = [col for col in feature_columns
                            if col in existing_columns and col in df.columns]
        insert_columns_order = ['date', 'asset_id'] + writable_factors
        # Insert in batches to bound memory use and round-trip size.
        num_rows = len(df)
        for i in tqdm(range(0, num_rows, batch_size), desc="写入批次"):
            batch_df = df.iloc[i:i + batch_size]
            data_to_insert_batch = []
            for row in batch_df.itertuples(index=False):
                insert_row = [row.trade_date, row.ts_code]
                for factor in writable_factors:
                    # itertuples renames invalid identifiers, so getattr can
                    # miss; fall back to NULL in that case (as before).
                    insert_row.append(getattr(row, factor, None))
                data_to_insert_batch.append(tuple(insert_row))
            write_batch_to_clickhouse(client, table_name, data_to_insert_batch, insert_columns_order)
    except Exception as e:
        print(f"写入 ClickHouse 发生错误: {e}")
    finally:
        # BUGFIX: `client` may be None if Client() itself raised; the old
        # locals() check did not cover a failed assignment.
        if client is not None and client.connection:
            client.disconnect()
def write_batch_to_clickhouse(client, table_name, data_to_insert, columns_order):
    """Insert one batch of rows into a ClickHouse table.

    Args:
        client: connected clickhouse-driver Client.
        table_name: destination table name.
        data_to_insert: list of value tuples, ordered per ``columns_order``.
        columns_order: column names matching each tuple's positions.

    Empty batches are skipped; insert errors are printed, not raised.
    """
    # Guard clause: nothing to do for an empty batch.
    if not data_to_insert:
        return
    column_list = ', '.join(columns_order)
    insert_query_final = f"INSERT INTO {table_name} ({column_list}) VALUES"
    try:
        client.execute(insert_query_final, data_to_insert)
        print(f"成功写入 {len(data_to_insert)} 条数据到 ClickHouse 表 '{table_name}'!")
    except Exception as e:
        print(f"写入 ClickHouse 批次数据发生错误: {e}")
# -------------------- Usage example --------------------
if __name__ == "__main__":
    # Build the input panel by successively merging several HDF5 datasets
    # on (ts_code, trade_date).
    print('daily data')
    df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',
                                columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],
                                df=None)
    print('daily basic')
    # Inner join: keep only rows present in both daily_data and daily_basic.
    df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',
                                columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',
                                         'is_st'], df=df, join='inner')
    df = df[df['trade_date'] >= '2021-01-01']
    print('stk limit')
    df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',
                                columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],
                                df=df)
    print('money flow')
    df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',
                                columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol',
                                         'sell_lg_vol',
                                         'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],
                                df=df)
    print('cyq perf')
    df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',
                                columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',
                                         'cost_50pct',
                                         'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],
                                df=df)
    print(df.info())
    # Snapshot of the pre-factor ("origin") columns; cyq_* columns are removed
    # from the snapshot so they survive the origin_columns exclusion below and
    # are kept as features.
    origin_columns = df.columns.tolist()
    origin_columns = [col for col in origin_columns if 'cyq' not in col]
    print(origin_columns)
    def filter_data(df):
        # Universe filter: drop ST stocks, Beijing exchange ('*.BJ'),
        # '30*' (ChiNext), '68*' (STAR Market) and '8*' codes.
        # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))
        df = df[~df['is_st']]
        df = df[~df['ts_code'].str.endswith('BJ')]
        df = df[~df['ts_code'].str.startswith('30')]
        df = df[~df['ts_code'].str.startswith('68')]
        df = df[~df['ts_code'].str.startswith('8')]
        # NOTE(review): the earlier filter compared against '2021-01-01'
        # (dashed) while this one uses '20180101' — confirm trade_date string
        # format is consistent upstream, otherwise one of the two is a no-op.
        df = df[df['trade_date'] >= '20180101']
        if 'in_date' in df.columns:
            df = df.drop(columns=['in_date'])
        df = df.reset_index(drop=True)
        return df
    df = filter_data(df)
    # Compute rolling-window factors, then simple per-row factors.
    df, _ = get_rolling_factor(df)
    df, _ = get_simple_factor(df)
    # df['test'] = 1
    # df['test2'] = 2
    # df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')
    df = df.rename(columns={'l2_code': 'cat_l2_code'})
    # df = df.merge(index_data, on='trade_date', how='left')
    print(df.info())
    # NOTE(review): this first filter is a no-op (`col in df.columns` is always
    # true when iterating df.columns) — presumably a leftover allow-list.
    feature_columns = [col for col in df.columns if col in df.columns]
    # Drop identifiers/labels, leakage-prone names ('future', 'label', 'score',
    # 'gen'), flags, the category column, anything that existed before factor
    # generation, and private ('_'-prefixed) columns.
    feature_columns = [col for col in feature_columns if col not in ['trade_date',
                                                                     'ts_code',
                                                                     'label']]
    feature_columns = [col for col in feature_columns if 'future' not in col]
    feature_columns = [col for col in feature_columns if 'label' not in col]
    feature_columns = [col for col in feature_columns if 'score' not in col]
    feature_columns = [col for col in feature_columns if 'gen' not in col]
    feature_columns = [col for col in feature_columns if 'is_st' not in col]
    # feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]
    # feature_columns = [col for col in feature_columns if 'volatility' not in col]
    # feature_columns = [col for col in feature_columns if 'circ_mv' not in col]
    feature_columns = [col for col in feature_columns if 'cat_l2_code' not in col]
    feature_columns = [col for col in feature_columns if col not in origin_columns]
    feature_columns = [col for col in feature_columns if not col.startswith('_')]
    print(feature_columns)
    # Replace with your ClickHouse connection info.
    # SECURITY(review): credentials are hard-coded in source — move them to
    # environment variables or a config file before committing.
    clickhouse_host = '127.0.0.1'
    clickhouse_port = 9000
    clickhouse_user = 'default'
    clickhouse_password = 'clickhouse520102'
    clickhouse_database = 'stock_data'
    # create_factor_table_clickhouse(clickhouse_host, clickhouse_port,
    #                                clickhouse_user, clickhouse_password,
    #                                clickhouse_database)
    # Persist only the id columns plus the selected features.
    write_features_to_clickhouse(
        df[[col for col in df.columns if col in ['ts_code', 'trade_date'] or col in feature_columns]], feature_columns,
        clickhouse_host, clickhouse_port,
        clickhouse_user, clickhouse_password,
        clickhouse_database)