1、策略更新
2、新增qmt
This commit is contained in:
86
main/factor/qlib_utils.py
Normal file
86
main/factor/qlib_utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import polars as pl
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import qlib
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.report import analysis_model, analysis_position
|
||||
from qlib.constant import REG_CN
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
import pandas as pd
|
||||
|
||||
def prepare_data(
    polars_df: pl.DataFrame,
    label_horizon: int = 5,
    open_col: str = "open",
    date_col: str = "trade_date",
    code_col: str = "ts_code",
) -> pd.DataFrame:
    """Convert a polars frame into a Qlib-style pandas frame with a return label.

    Rows are ordered by ``(code_col, date_col)``. Within each instrument the
    label on row T is ``open[T + 1 + label_horizon] / open[T + 1] - 1``: buy at
    the next row's open and sell at the open ``label_horizon`` rows after that.
    Rows too close to the end of an instrument's history have no label value.

    Args:
        polars_df: Input data; must contain ``date_col``, ``code_col`` and
            ``open_col``. Every other column is carried through unchanged.
        label_horizon: Holding period, counted in rows per instrument
            (presumably trading days -- confirm the input has one row per
            instrument per day). Must be at least 1.
        open_col: Name of the opening-price column.
        date_col: Name of the date column; becomes the ``datetime`` index level.
        code_col: Name of the instrument-code column; becomes the
            ``instrument`` index level.

    Returns:
        A pandas DataFrame indexed by ``(datetime, instrument)``, sorted by
        that index, with a ``label`` column added.

    Raises:
        ValueError: If a required column is missing, or if ``label_horizon``
            is smaller than 1.
    """
    required = (date_col, code_col, open_col)
    missing = [col for col in required if col not in polars_df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    # A horizon of 0 makes every label exactly 0 and a negative horizon reads
    # prices from before the buy date, so reject both instead of silently
    # producing a meaningless training target.
    if label_horizon < 1:
        raise ValueError(f"label_horizon must be >= 1, got {label_horizon}")

    open_price = pl.col(open_col)
    # Open of day T+1 (the entry price), taken within each instrument.
    buy_price = open_price.shift(-1).over(code_col)
    # Open `label_horizon` rows after the entry, taken within each instrument.
    sell_price = open_price.shift(-(1 + label_horizon)).over(code_col)

    labelled = polars_df.sort([code_col, date_col]).with_columns(
        [(sell_price / buy_price - 1).alias("label")]
    )

    # Switch to pandas and adopt the (datetime, instrument) index naming.
    frame = labelled.to_pandas().rename(
        columns={date_col: "datetime", code_col: "instrument"}
    )
    frame["datetime"] = pd.to_datetime(frame["datetime"])
    return frame.set_index(["datetime", "instrument"]).sort_index()
|
||||
|
||||
# 2. Qlib初始化
|
||||
def initialize_qlib(provider_uri: str = "/mnt/d/PyProject/NewStock/data/qlib") -> None:
    """Initialise Qlib with the China region and daily frequency.

    Per the original author's note, data is loaded directly from in-memory
    DataFrames, so ``provider_uri`` may point at a placeholder or empty
    location.

    Args:
        provider_uri: Data directory handed to ``qlib.init``. Defaults to the
            path that was previously hard-coded, so existing callers are
            unaffected; pass another path to run on a different machine.
    """
    # REG_CN selects the China A-share trading calendar and trading-cost
    # settings (per the original author's note).
    qlib.init(provider_uri=provider_uri, region=REG_CN, freq="day")
    print("Qlib has been initialized in memory mode.")
|
||||
|
||||
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import PortAnaRecord # SignalRecord 在此场景下未被直接使用
|
||||
|
||||
def train_and_backtest_from_df(
    df: pd.DataFrame,
    all_features: list,
    label_col: str = "label",
    topk: int = 50,
    start_train: str = "2019-01-01",
    end_train: str = "2021-12-31",
    start_valid: str = "2022-01-01",
    end_valid: str = "2022-12-31",
    start_test: str = "2023-01-01",
    end_test: str = "2023-12-31",
) -> None:
    """Build a Qlib data handler from a prepared DataFrame and print its contents.

    The function is intended to train a model and run a backtest directly from
    a prepared pandas DataFrame. The code present here only normalises the
    index, wraps the frame in a ``DataHandlerLP`` and prints the handler's data
    for inspection. ``all_features``, ``label_col``, ``topk`` and the six date
    boundaries are accepted but not used yet.

    The caller's ``df`` is left untouched: index normalisation and sorting are
    done on a copy.

    Args:
        df: Frame with a two-level MultiIndex, treated as
            ``(datetime, instrument)``.
        all_features: Feature column names (currently unused).
        label_col: Label column name (currently unused).
        topk: Portfolio size for the backtest (currently unused).
        start_train: Training period start (currently unused).
        end_train: Training period end (currently unused).
        start_valid: Validation period start (currently unused).
        end_valid: Validation period end (currently unused).
        start_test: Test period start (currently unused).
        end_test: Test period end (currently unused).

    Raises:
        ValueError: If ``df`` is not indexed by a MultiIndex.
    """
    # === 1. Prepare the data by hand ===
    if not isinstance(df.index, pd.MultiIndex):
        raise ValueError("df 必须是 MultiIndex (datetime, instrument)")

    # Name the levels and make sure the first one holds real timestamps.
    index = df.index.set_names(["datetime", "instrument"])
    index = index.set_levels(pd.to_datetime(index.levels[0]), level="datetime")

    # Attach the normalised index to a shallow copy so the caller's frame keeps
    # its own index and row order; sort_index() then returns a new frame.
    prepared = df.copy(deep=False)
    prepared.index = index
    prepared = prepared.sort_index()

    dh = DataHandlerLP.from_df(prepared)
    # Debug output: the fetched data, then the handler's private _infer and
    # _learn attributes.
    print(dh.fetch())
    print(dh._infer)
    print(dh._learn)
|
||||
Reference in New Issue
Block a user