RollingRank赚钱 - Sharpe 1.43
This commit is contained in:
0
main/factor/__init__.py
Normal file
0
main/factor/__init__.py
Normal file
BIN
main/factor/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
main/factor/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
main/factor/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
main/factor/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
main/factor/__pycache__/factor.cpython-310.pyc
Normal file
BIN
main/factor/__pycache__/factor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
main/factor/__pycache__/factor.cpython-311.pyc
Normal file
BIN
main/factor/__pycache__/factor.cpython-311.pyc
Normal file
Binary file not shown.
BIN
main/factor/__pycache__/operator.cpython-311.pyc
Normal file
BIN
main/factor/__pycache__/operator.cpython-311.pyc
Normal file
Binary file not shown.
1532
main/factor/factor.py
Normal file
1532
main/factor/factor.py
Normal file
File diff suppressed because it is too large
Load Diff
1028
main/factor/generate_factor.ipynb
Normal file
1028
main/factor/generate_factor.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
7
main/factor/operators.py
Normal file
7
main/factor/operators.py
Normal file
@@ -0,0 +1,7 @@
from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data

import sys

# NOTE(review): debug leftover — this prints sys.path at import time,
# presumably added to diagnose module-resolution issues; consider removing.
print(sys.path)
222
main/factor/save_factor.py
Normal file
222
main/factor/save_factor.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from tqdm import tqdm
|
||||
|
||||
from main.factor.factor import get_rolling_factor, get_simple_factor
|
||||
from main.utils.utils import read_and_merge_h5_data
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def create_factor_table_clickhouse(clickhouse_host: str, clickhouse_port: int,
                                   clickhouse_user: str, clickhouse_password: str,
                                   clickhouse_database: str, table_name: str = 'factor_data'):
    """
    Create the long-format factor table in ClickHouse, laid out for read speed.

    The table stores one row per (date, asset, factor) triple: MergeTree
    engine, partitioned by month, ordered by (date, asset_id, factor_name)
    so range scans over dates/assets are fast.

    Args:
        clickhouse_host: ClickHouse server host.
        clickhouse_port: Native-protocol port (typically 9000).
        clickhouse_user: User name for the connection.
        clickhouse_password: Password for the connection.
        clickhouse_database: Database in which to create the table.
        table_name: Name of the table to create (default 'factor_data').

    Notes:
        - Errors are caught and printed rather than raised (best-effort style,
          consistent with the rest of this module).
        - NOTE(review): `Client` is not imported anywhere visible in this file —
          presumably `from clickhouse_driver import Client` is missing; verify.
    """
    try:
        print('create factor table')
        # Open a native-protocol connection to ClickHouse.
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)

        # IF NOT EXISTS makes the call idempotent; re-running is safe.
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name}
        (
            date Date,
            asset_id String,
            factor_name String,
            factor_value Float64
        )
        ENGINE = MergeTree()
        PARTITION BY toYYYYMM(date)
        ORDER BY (date, asset_id, factor_name)
        """

        client.execute(create_table_query)
        print(f"成功在 ClickHouse 数据库 '{clickhouse_database}' 中创建表 '{table_name}'!")

    except Exception as e:
        # Broad catch: report the failure but let the caller continue.
        print(f"创建 ClickHouse 表发生错误: {e}")
    finally:
        # `client` only exists in locals() if the constructor succeeded.
        if 'client' in locals() and client.connection:
            client.disconnect()
|
||||
|
||||
def write_features_to_clickhouse(df: pd.DataFrame, feature_columns: list,
                                 clickhouse_host: str, clickhouse_port: int,
                                 clickhouse_user: str, clickhouse_password: str,
                                 clickhouse_database: str, table_name: str = 'stock_factor',
                                 batch_size: int = 5000):  # rows per INSERT batch
    """
    Write the selected feature columns of *df* into a ClickHouse wide table,
    adding any missing table columns on the fly, then inserting the rows in
    batches of *batch_size*.

    Args:
        df: Must contain 'ts_code' and 'trade_date' plus the feature columns.
        feature_columns: Names of columns in *df* to persist.
        clickhouse_host/clickhouse_port/clickhouse_user/clickhouse_password/
        clickhouse_database: Connection parameters.
        table_name: Target wide table (default 'stock_factor').
        batch_size: Number of rows per INSERT (default 5000).

    Notes:
        - Features missing from *df* or with object dtype are skipped with a
          warning; other non-float/non-int dtypes fall back to Float64.
        - NOTE(review): rows are inserted under table columns 'date' and
          'asset_id' while the values come from 'trade_date'/'ts_code' — looks
          like a deliberate schema mapping, but verify against the table DDL.
        - NOTE(review): `Client` has no visible import in this file; presumably
          `from clickhouse_driver import Client` is missing — confirm.
        - Errors are printed, not raised; the connection is always closed.
    """
    try:
        client = Client(host=clickhouse_host, port=clickhouse_port, user=clickhouse_user,
                        password=clickhouse_password, database=clickhouse_database)

        if 'ts_code' not in df.columns or 'trade_date' not in df.columns:
            raise ValueError("DataFrame 必须包含 'ts_code' 和 'trade_date' 列。")

        # Discover which columns the target table already has.
        existing_columns = set()
        columns_query = f"DESCRIBE TABLE {table_name}"
        columns_result = client.execute(columns_query)
        for col in columns_result:
            existing_columns.add(col[0])

        # Add a table column for every feature not present yet, mapping the
        # pandas dtype to a ClickHouse type.
        for factor_name in feature_columns:
            if factor_name not in existing_columns:
                if factor_name not in df.columns:
                    print(f"警告: 特征 '{factor_name}' 不存在于 DataFrame 中,将跳过添加列。")
                    continue

                factor_series = df[factor_name]
                factor_dtype = factor_series.dtype

                clickhouse_dtype = None
                if pd.api.types.is_float_dtype(factor_dtype):
                    clickhouse_dtype = 'Float64'
                elif pd.api.types.is_integer_dtype(factor_dtype):
                    clickhouse_dtype = 'Int64'
                elif factor_dtype == 'object':
                    # Strings / mixed object columns are not supported here.
                    print(f"警告: 特征 '{factor_name}' 的数据类型为 object,将跳过添加列。")
                    continue
                else:
                    # Any other dtype (e.g. bool) falls back to Float64.
                    clickhouse_dtype = 'Float64'

                if clickhouse_dtype:
                    # IF NOT EXISTS guards against races with concurrent writers.
                    add_column_query = f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS {factor_name} {clickhouse_dtype}"
                    client.execute(add_column_query)
                    print(f"在表 '{table_name}' 中添加了新列: {factor_name} ({clickhouse_dtype})")
                    existing_columns.add(factor_name)

        # INSERT column order: key columns first, then every feature present
        # both in the table and in the DataFrame.
        insert_columns_order = ['date', 'asset_id'] + [col for col in feature_columns if
                                                       col in existing_columns and col in df.columns]

        # Process the DataFrame in batches to bound memory per INSERT.
        num_rows = len(df)
        for i in tqdm(range(0, num_rows, batch_size), desc="写入批次"):
            batch_df = df[i:i + batch_size]
            data_to_insert_batch = []
            for row in batch_df.itertuples(index=False):
                # Key values ('trade_date', 'ts_code') feed the 'date'/'asset_id'
                # columns of insert_columns_order above.
                insert_row = [getattr(row, 'trade_date'), getattr(row, 'ts_code')]
                for factor in feature_columns:
                    if factor in existing_columns and factor in df.columns:
                        try:
                            insert_row.append(getattr(row, factor))
                        except AttributeError:
                            # itertuples renames invalid identifiers; store NULL.
                            insert_row.append(None)
                data_to_insert_batch.append(tuple(insert_row))
            write_batch_to_clickhouse(client, table_name, data_to_insert_batch, insert_columns_order)

    except Exception as e:
        print(f"写入 ClickHouse 发生错误: {e}")
    finally:
        # `client` only exists in locals() if the constructor succeeded.
        if 'client' in locals() and client.connection:
            client.disconnect()
|
||||
|
||||
def write_batch_to_clickhouse(client, table_name, data_to_insert, columns_order):
    """Insert one prepared batch of rows into a ClickHouse table.

    Args:
        client: Connected ClickHouse client exposing ``execute(query, data)``.
        table_name: Target table name.
        data_to_insert: List of row tuples, ordered to match *columns_order*.
        columns_order: Column names for the INSERT statement.

    An empty batch is a no-op; insert failures are printed, not raised.
    """
    if not data_to_insert:
        return  # nothing to write

    column_list = ', '.join(columns_order)
    insert_query_final = f"INSERT INTO {table_name} ({column_list}) VALUES"
    try:
        client.execute(insert_query_final, data_to_insert)
        print(f"成功写入 {len(data_to_insert)} 条数据到 ClickHouse 表 '{table_name}'!")
    except Exception as e:
        print(f"写入 ClickHouse 批次数据发生错误: {e}")
|
||||
|
||||
# -------------------- Usage example --------------------
if __name__ == "__main__":
    # Build the daily feature DataFrame from the local HDF5 stores, compute
    # the rolling/simple factors, then persist them to ClickHouse.

    print('daily data')
    df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',
                                columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],
                                df=None)

    print('daily basic')
    df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',
                                columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',
                                         'is_st'], df=df, join='inner')
    # NOTE(review): this compares against 'YYYY-MM-DD' while filter_data below
    # uses '20180101' — the two date formats are inconsistent; verify which
    # format trade_date actually uses.
    df = df[df['trade_date'] >= '2021-01-01']

    print('stk limit')
    df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',
                                columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],
                                df=df)
    print('money flow')
    df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',
                                columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol',
                                         'sell_lg_vol',
                                         'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],
                                df=df)
    print('cyq perf')
    df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',
                                columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',
                                         'cost_50pct',
                                         'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],
                                df=df)
    print(df.info())

    # Columns present before factor generation; used further down to exclude
    # raw inputs from the persisted feature list ('cyq' columns are kept).
    origin_columns = df.columns.tolist()
    origin_columns = [col for col in origin_columns if 'cyq' not in col]
    print(origin_columns)


    def filter_data(df):
        """Drop ST stocks, non-main-board codes and early history; reset index."""
        # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))
        df = df[~df['is_st']]
        df = df[~df['ts_code'].str.endswith('BJ')]    # Beijing exchange suffix
        df = df[~df['ts_code'].str.startswith('30')]  # presumably ChiNext codes — confirm
        df = df[~df['ts_code'].str.startswith('68')]  # presumably STAR market codes — confirm
        df = df[~df['ts_code'].str.startswith('8')]
        df = df[df['trade_date'] >= '20180101']
        if 'in_date' in df.columns:
            df = df.drop(columns=['in_date'])
        df = df.reset_index(drop=True)
        return df


    df = filter_data(df)
    df, _ = get_rolling_factor(df)
    df, _ = get_simple_factor(df)
    # df['test'] = 1
    # df['test2'] = 2
    # df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')
    df = df.rename(columns={'l2_code': 'cat_l2_code'})
    # df = df.merge(index_data, on='trade_date', how='left')

    print(df.info())

    # Select the generated factor columns to persist: drop key columns,
    # label/leakage-prone columns and the raw inputs collected above.
    # NOTE(review): the first filter (`col in df.columns`) is a no-op.
    feature_columns = [col for col in df.columns if col in df.columns]
    feature_columns = [col for col in feature_columns if col not in ['trade_date',
                                                                     'ts_code',
                                                                     'label']]
    feature_columns = [col for col in feature_columns if 'future' not in col]
    feature_columns = [col for col in feature_columns if 'label' not in col]
    feature_columns = [col for col in feature_columns if 'score' not in col]
    feature_columns = [col for col in feature_columns if 'gen' not in col]
    feature_columns = [col for col in feature_columns if 'is_st' not in col]
    # feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]
    # feature_columns = [col for col in feature_columns if 'volatility' not in col]
    # feature_columns = [col for col in feature_columns if 'circ_mv' not in col]
    feature_columns = [col for col in feature_columns if 'cat_l2_code' not in col]
    feature_columns = [col for col in feature_columns if col not in origin_columns]
    feature_columns = [col for col in feature_columns if not col.startswith('_')]

    print(feature_columns)

    # ClickHouse connection settings.
    # NOTE(review): credentials are hard-coded in source — move them to a
    # config file or environment variables.
    clickhouse_host = '127.0.0.1'
    clickhouse_port = 9000
    clickhouse_user = 'default'
    clickhouse_password = 'clickhouse520102'
    clickhouse_database = 'stock_data'

    # create_factor_table_clickhouse(clickhouse_host, clickhouse_port,
    #                                clickhouse_user, clickhouse_password,
    #                                clickhouse_database)

    write_features_to_clickhouse(
        df[[col for col in df.columns if col in ['ts_code', 'trade_date'] or col in feature_columns]], feature_columns,
        clickhouse_host, clickhouse_port,
        clickhouse_user, clickhouse_password,
        clickhouse_database)
||||
Reference in New Issue
Block a user