2025-11-29 00:23:12 +08:00
|
|
|
|
import polars as pl
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import qlib
|
|
|
|
|
|
from qlib.data.dataset.handler import DataHandlerLP
|
|
|
|
|
|
from qlib.contrib.report import analysis_model, analysis_position
|
|
|
|
|
|
from qlib.constant import REG_CN
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_data(
    polars_df: pl.DataFrame,
    open_col: str = "open",
    date_col: str = "trade_date",
    code_col: str = "ts_code",
) -> pd.DataFrame:
    """Convert a polars frame to a pandas frame indexed by (datetime, instrument).

    Args:
        polars_df: Input frame; must contain ``date_col``, ``code_col`` and
            ``open_col``.
        open_col: Name of the open-price column. It is only checked for
            presence here.
        date_col: Name of the date column. It is renamed to ``"datetime"`` and
            parsed with ``pd.to_datetime``.
        code_col: Name of the instrument-code column. It is renamed to
            ``"instrument"``.

    Returns:
        A pandas DataFrame whose index is the sorted
        ``("datetime", "instrument")`` MultiIndex; the remaining columns are
        kept as data columns.

    Raises:
        ValueError: If any of the three named columns is absent from
            ``polars_df``.
    """
    absent = [
        name
        for name in (date_col, code_col, open_col)
        if name not in polars_df.columns
    ]
    if absent:
        raise ValueError(f"Missing columns: {absent}")

    # Order rows by instrument, then by date, while still in polars.
    # NOTE(review): a now-disabled step used to build a forward-return label
    # here (open price shifted by 1 as buy price, shifted by 1 + label_horizon
    # as sell price, label = sell / buy - 1); it presumably relied on this
    # per-instrument ordering.
    ordered = polars_df.sort([code_col, date_col])

    # Switch to pandas and adopt the "datetime" / "instrument" naming.
    frame = ordered.to_pandas().rename(
        columns={date_col: "datetime", code_col: "instrument"}
    )
    frame["datetime"] = pd.to_datetime(frame["datetime"])

    return frame.set_index(["datetime", "instrument"]).sort_index()
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Qlib初始化
|
|
|
|
|
|
def initialize_qlib(provider_uri: str = "/mnt/d/PyProject/NewStock/data/qlib") -> None:
    """Initialize Qlib with the China region setting and daily frequency.

    Args:
        provider_uri: Data directory passed to ``qlib.init``. Defaults to the
            path that used to be hard-coded, so calling this function without
            arguments behaves exactly as before.

    NOTE(review): the original notes state that all data is loaded from
    in-memory DataFrames, so ``provider_uri`` may point at a placeholder or
    empty directory -- confirm against how the data is actually consumed.
    """
    # NOTE(review): per the original comment, REG_CN selects the China A-share
    # trading calendar and trading-cost settings -- confirm in the Qlib docs.
    qlib.init(provider_uri=provider_uri, region=REG_CN, freq="day")
    print("Qlib has been initialized in memory mode.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
import lightgbm as lgb
|
|
|
|
|
|
from qlib.workflow import R
|
|
|
|
|
|
from qlib.workflow.record_temp import PortAnaRecord # SignalRecord 在此场景下未被直接使用
|
|
|
|
|
|
|
|
|
|
|
|
def train_and_backtest_from_df(
|
|
|
|
|
|
df: pd.DataFrame,
|
|
|
|
|
|
all_features: list,
|
|
|
|
|
|
label_col: str = "label",
|
|
|
|
|
|
topk: int = 50,
|
|
|
|
|
|
start_train: str = "2019-01-01",
|
|
|
|
|
|
end_train: str = "2021-12-31",
|
|
|
|
|
|
start_valid: str = "2022-01-01",
|
|
|
|
|
|
end_valid: str = "2022-12-31",
|
|
|
|
|
|
start_test: str = "2023-01-01",
|
|
|
|
|
|
end_test: str = "2023-12-31",
|
|
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
直接从准备好的 pandas DataFrame 训练模型并运行回测。
|
|
|
|
|
|
"""
|
|
|
|
|
|
# === 1. 手动准备数据 ===
|
|
|
|
|
|
if not isinstance(df.index, pd.MultiIndex):
|
|
|
|
|
|
raise ValueError("df 必须是 MultiIndex (datetime, instrument)")
|
|
|
|
|
|
df.index = df.index.set_names(["datetime", "instrument"])
|
|
|
|
|
|
df.index = df.index.set_levels(pd.to_datetime(df.index.levels[0]), level='datetime')
|
|
|
|
|
|
df.sort_index(inplace=True)
|
|
|
|
|
|
|
|
|
|
|
|
dh = DataHandlerLP.from_df(df)
|
|
|
|
|
|
print(dh.fetch())
|
|
|
|
|
|
print(dh._infer)
|
|
|
|
|
|
print(dh._learn)
|