Files
NewStock/main/factor/qlib_utils.py
2025-11-29 00:23:12 +08:00

86 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import polars as pl
import pandas as pd
import numpy as np
import qlib
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.report import analysis_model, analysis_position
from qlib.constant import REG_CN
from typing import List
import polars as pl
import pandas as pd
def prepare_data(
    polars_df: pl.DataFrame,
    label_horizon: int = 5,
    open_col: str = "open",
    date_col: str = "trade_date",
    code_col: str = "ts_code",
) -> pd.DataFrame:
    """Attach a forward open-to-open return label and convert to pandas.

    For every row (instrument, day T) the label is::

        open[T + 1 + label_horizon] / open[T + 1] - 1

    computed per instrument after sorting by ``(code_col, date_col)``, i.e.
    the position is entered at the next row's open and exited
    ``label_horizon`` rows later, also at the open. Offsets are counted in
    rows of each instrument's series, not in calendar days. Rows too close
    to the end of an instrument's series have no future price and therefore
    get a missing label.

    Args:
        polars_df: Input frame; must contain ``date_col``, ``code_col`` and
            ``open_col``. All other columns are passed through unchanged.
        label_horizon: Number of rows between the buy and the sell price.
            Must be at least 1.
        open_col: Name of the open-price column.
        date_col: Name of the date column; renamed to ``"datetime"``.
        code_col: Name of the instrument column; renamed to ``"instrument"``.

    Returns:
        A pandas DataFrame indexed by ``(datetime, instrument)``, sorted by
        that index, with an added ``"label"`` column.

    Raises:
        ValueError: If a required column is missing or ``label_horizon`` is
            smaller than 1.
    """
    required = [date_col, code_col, open_col]
    missing = [col for col in required if col not in polars_df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    # A horizon of 0 makes buy and sell price identical (label always 0) and
    # a negative one would build the label from past prices; reject both
    # instead of silently returning a meaningless target.
    if label_horizon < 1:
        raise ValueError(f"label_horizon must be >= 1, got {label_horizon}")

    # Sorting by instrument then date makes the per-instrument shifts below
    # step through time in order.
    df = polars_df.sort([code_col, date_col])

    # Open price of the next row (T+1) is the buy price; the open price
    # `label_horizon` rows after that is the sell price.
    buy_price = pl.col(open_col).shift(-1).over(code_col).alias("__buy_price")
    sell_price = (
        pl.col(open_col)
        .shift(-(1 + label_horizon))
        .over(code_col)
        .alias("__sell_price")
    )
    df = (
        df.with_columns([buy_price, sell_price])
        .with_columns(
            [(pl.col("__sell_price") / pl.col("__buy_price") - 1).alias("label")]
        )
        .drop(["__buy_price", "__sell_price"])
    )

    # Convert to pandas and switch to the (datetime, instrument) index.
    df = df.to_pandas()
    df.rename(columns={date_col: "datetime", code_col: "instrument"}, inplace=True)
    df["datetime"] = pd.to_datetime(df["datetime"])
    df.set_index(["datetime", "instrument"], inplace=True)
    df.sort_index(inplace=True)
    return df
# 2. Qlib initialization
def initialize_qlib(provider_uri: str = "/mnt/d/PyProject/NewStock/data/qlib"):
    """Initialize Qlib with the China region settings.

    Calls ``qlib.init`` with ``region=REG_CN`` and ``freq="day"`` and prints a
    confirmation message.

    Args:
        provider_uri: Data location passed to ``qlib.init``. Defaults to the
            path that used to be hard-coded here, so existing callers are
            unaffected. The original author's note says the data itself is
            loaded from in-memory DataFrames, so this path is only meant as
            a placeholder -- NOTE(review): confirm that nothing is actually
            read from it.
    """
    # REG_CN selects Qlib's China-market region configuration.
    qlib.init(provider_uri=provider_uri, region=REG_CN, freq="day")
    print("Qlib has been initialized in memory mode.")
import pandas as pd
import lightgbm as lgb
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord # SignalRecord 在此场景下未被直接使用
def train_and_backtest_from_df(
    df: pd.DataFrame,
    all_features: list,
    label_col: str = "label",
    topk: int = 50,
    start_train: str = "2019-01-01",
    end_train: str = "2021-12-31",
    start_valid: str = "2022-01-01",
    end_valid: str = "2022-12-31",
    start_test: str = "2023-01-01",
    end_test: str = "2023-12-31",
):
    """Wrap a prepared pandas DataFrame in a Qlib ``DataHandlerLP``.

    The function is intended to train a model and run a backtest directly
    from a prepared DataFrame. The body as it stands only normalises the
    index, builds the data handler and prints its contents; ``all_features``,
    ``label_col``, ``topk`` and the train/valid/test date bounds are accepted
    but not used yet.

    Args:
        df: Frame with a two-level ``MultiIndex``; the levels are renamed to
            ``("datetime", "instrument")`` and the first level is converted
            to datetimes. The caller's frame is left unmodified.
        all_features: Feature column names (currently unused).
        label_col: Label column name (currently unused).
        topk: Number of instruments to hold (currently unused).
        start_train: Start of the training segment (currently unused).
        end_train: End of the training segment (currently unused).
        start_valid: Start of the validation segment (currently unused).
        end_valid: End of the validation segment (currently unused).
        start_test: Start of the test segment (currently unused).
        end_test: End of the test segment (currently unused).

    Raises:
        ValueError: If ``df.index`` is not a ``pandas.MultiIndex``.
    """
    # === 1. Prepare the data manually ===
    if not isinstance(df.index, pd.MultiIndex):
        # Message reads: "df must be a MultiIndex (datetime, instrument)".
        raise ValueError("df 必须是 MultiIndex (datetime, instrument)")

    # Work on a shallow copy so that renaming/converting the index and
    # sorting do not leak back into the caller's DataFrame; the column data
    # itself is not duplicated by the shallow copy.
    frame = df.copy(deep=False)
    frame.index = frame.index.set_names(["datetime", "instrument"])
    frame.index = frame.index.set_levels(
        pd.to_datetime(frame.index.levels[0]), level="datetime"
    )
    frame = frame.sort_index()

    dh = DataHandlerLP.from_df(frame)

    # Debug output: the fetched data plus the handler's internal
    # inference/learning frames (private attributes of the handler).
    print(dh.fetch())
    print(dh._infer)
    print(dh._learn)