from typing import List

import numpy as np
import pandas as pd
import polars as pl
import qlib
from qlib.constant import REG_CN
from qlib.contrib.report import analysis_model, analysis_position
from qlib.data.dataset.handler import DataHandlerLP


def prepare_data(
    polars_df: pl.DataFrame,
    open_col: str = "open",
    date_col: str = "trade_date",
    code_col: str = "ts_code",
) -> pd.DataFrame:
    """Convert a polars price table into a pandas frame indexed for Qlib.

    Args:
        polars_df: Source table. It must contain ``date_col``, ``code_col`` and
            ``open_col``.
        open_col: Name of the open-price column. It is only checked for presence
            here; its values are passed through unchanged.
        date_col: Date column. It is renamed to ``"datetime"`` and parsed to
            datetime64. Integer values are parsed as ``YYYYMMDD``
            (e.g. ``20230103``).
        code_col: Instrument-code column. It is renamed to ``"instrument"``.

    Returns:
        A pandas DataFrame indexed by a sorted ``(datetime, instrument)``
        MultiIndex. All other columns are kept. No label column is computed
        here; the caller must add one before training.

    Raises:
        ValueError: If any of the required columns is missing.
    """
    required = [date_col, code_col, open_col]
    missing = [col for col in required if col not in polars_df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    # No pre-sort in polars: the final sort_index() fully determines row order.
    df = polars_df.to_pandas().rename(
        columns={date_col: "datetime", code_col: "instrument"}
    )

    dates = df["datetime"]
    if pd.api.types.is_integer_dtype(dates):
        # pd.to_datetime(20230103) would be read as nanoseconds since the epoch
        # (1970-01-01), so integer trade dates are parsed as YYYYMMDD strings.
        dates = pd.to_datetime(dates.astype(str), format="%Y%m%d")
    else:
        dates = pd.to_datetime(dates)
    df["datetime"] = dates

    return df.set_index(["datetime", "instrument"]).sort_index()
# 2. Qlib initialization
def initialize_qlib(provider_uri: str = "/mnt/d/PyProject/NewStock/data/qlib"):
    """Initialize Qlib for the China A-share region at daily frequency.

    Args:
        provider_uri: Directory of the Qlib data provider. The default is the
            path this project previously hard-coded, so existing calls behave
            the same.

    NOTE(review): an earlier version of this docstring described an
    "in-memory mode" with a virtual/empty provider path. A concrete data
    directory is passed, however, and whether later steps (e.g. calendar
    lookups during backtesting) need real data there is not visible here.
    Confirm before pointing this at a dummy path.
    """
    # REG_CN selects Qlib's China A-share regional settings.
    qlib.init(provider_uri=provider_uri, region=REG_CN, freq="day")
    print("Qlib has been initialized in memory mode.")


import pandas as pd
import lightgbm as lgb
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord  # SignalRecord is not used directly in this scenario


def train_and_backtest_from_df(
    df: pd.DataFrame,
    all_features: list,
    label_col: str = "label",
    topk: int = 50,
    start_train: str = "2019-01-01",
    end_train: str = "2021-12-31",
    start_valid: str = "2022-01-01",
    end_valid: str = "2022-12-31",
    start_test: str = "2023-01-01",
    end_test: str = "2023-12-31",
):
    """Train a model and run a backtest directly from a prepared pandas DataFrame.

    Only the data-preparation step is implemented in the visible code. The
    function normalizes the index, wraps the frame in a ``DataHandlerLP`` and
    prints the handler's data. The caller's DataFrame is not modified.

    Args:
        df: Frame indexed by a two-level MultiIndex whose levels are
            (datetime, instrument). Level 0 is converted with ``pd.to_datetime``.
        all_features, label_col, topk, start_*/end_*: TODO(review): not used
            yet. The training and backtest steps appear to be unimplemented.

    Raises:
        ValueError: If ``df`` is not indexed by a two-level MultiIndex.
    """
    # === 1. Prepare the data manually ===
    if not isinstance(df.index, pd.MultiIndex) or df.index.nlevels != 2:
        raise ValueError("df 必须是 MultiIndex (datetime, instrument)")

    # Rebuild the index from per-row level values rather than set_levels().
    # set_levels() fails when two raw level values parse to the same timestamp.
    # set_axis()/sort_index() return new objects, so the caller's frame is left
    # untouched.
    normalized_index = pd.MultiIndex.from_arrays(
        [
            pd.to_datetime(df.index.get_level_values(0)),
            df.index.get_level_values(1),
        ],
        names=["datetime", "instrument"],
    )
    df = df.set_axis(normalized_index, axis=0).sort_index()

    dh = DataHandlerLP.from_df(df)
    # NOTE(review): debug output; _infer and _learn are private DataHandlerLP attributes.
    print(dh.fetch())
    print(dh._infer)
    print(dh._learn)