Train XGBoost with Dask RAPIDS in Databricks#

This notebook shows how to deploy a Dask RAPIDS workflow in Databricks. We will focus on the HIGGS dataset, a moderately sized classification problem from the UCI Machine Learning repository.

In the following sections, we will begin by loading the dataset from Delta Lake and preprocessing it with Dask, then train an XGBoost model with various configurations and explore techniques for optimizing inference.

Launch multi-node Dask cluster#

This workflow example can run on GPUs, and you do not even need a GPU locally since Databricks can provide one for you. Dask, in turn, allows users to easily distribute or scale up compute tasks within a single GPU or across multiple GPUs.

Dask recently introduced dask-databricks (available via conda and pip). With this CLI tool, the dask databricks run --cuda command will launch a Dask scheduler on the driver node and CUDA workers on the remaining nodes.

At a high level, we can break this section down into the following steps:

  • Create a new init script that installs RAPIDS and runs dask-databricks (see the sketch after this list)

  • Create a new multi-node cluster that uses this init script

  • Once the cluster is running, upload this notebook to Databricks and continue running these cells there
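For reference, the sketch below shows one way to create such an init script from a Databricks notebook before launching the cluster. The package list and the DBFS path are illustrative assumptions (consult the documentation linked below for the exact RAPIDS install line for your CUDA version); the dask databricks run --cuda command itself comes from dask-databricks.

init_script = """#!/bin/bash
set -e

# Hypothetical package selection; check the RAPIDS docs for the exact,
# up-to-date install command for your CUDA version
pip install --extra-index-url=https://pypi.nvidia.com \\
    "dask[complete]" dask-databricks dask-cuda cudf-cu12 dask-cudf-cu12

# Launch a Dask scheduler on the driver node and CUDA workers on the rest
dask databricks run --cuda
"""

# Write the script to a (hypothetical) DBFS path, then select it under
# "Advanced options -> Init scripts" when creating the multi-node cluster
dbutils.fs.put(
    "dbfs:/databricks/init-scripts/dask-rapids.sh", init_script, overwrite=True
)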

See Documentation

For more detailed information on launching Dask-RAPIDS within Databricks, please refer to the documentation.

Import packages#

Once the cluster is up and running, first import all necessary libraries and dependencies.

import os
from typing import Tuple

import dask_cudf
import dask_databricks
import dask_deltatable as ddt
import numpy as np
import xgboost as xgb
from dask_ml.model_selection import train_test_split
from distributed import wait
from xgboost import dask as dxgb

Connect to Dask Client#

Connect to the client (and optionally the Dashboard) to submit tasks.

client = dask_databricks.get_client()
client

Client

Client-23114b4f-b7aa-11ee-87d9-9a67d50005f3

Connection method: Cluster object    Cluster type: dask_databricks.DatabricksCluster
Dashboard: https://dbc-dp-8721196619973675.cloud.databricks.com/driver-proxy/o/8721196619973675/1031-230718-l2ubf858/8087/status

Cluster Info

Download dataset#

First we download the dataset to the Databricks File System (DBFS). Alternatively, you can also use cloud storage (S3, Google Cloud, Azure Data Lake); refer to the documentation for more information.

import subprocess

# Define the directory and file paths
directory_path = "/dbfs/databricks/rapids"
file_path = f"{directory_path}/HIGGS.csv.gz"
csv_path = f"{directory_path}/HIGGS.csv"

# Create the directory if it does not already exist
if not os.path.exists(directory_path):
    os.makedirs(directory_path)

# Check if the decompressed file already exists
if not os.path.exists(csv_path):
    # If not, download the dataset to the directory
    data_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz"
    )
    download_command = f"curl {data_url} --output {file_path}"
    subprocess.run(download_command, shell=True)

    # Decompress the csv file (gunzip replaces HIGGS.csv.gz with HIGGS.csv)
    decompress_command = f"gunzip {file_path}"
    subprocess.run(decompress_command, shell=True)

Next we load the data onto the GPU. Since the data will be loaded multiple times during parameter tuning, we convert the raw CSV file into Parquet format for better performance. This can be done easily using Delta Lake, as shown in the subsequent steps.

Integrating Dask and Delta Lake#

Delta Lake is an optimized storage layer within the Databricks lakehouse that provides a foundational platform for storing data and tables. This open-source software extends Parquet data files with a file-based transaction log to support ACID transactions and scalable metadata handling.

Delta Lake is the default storage format for all operations on Databricks; that is, unless otherwise specified, all tables on Databricks are Delta tables. Check out the tutorial for examples of basic Delta Lake operations.

Let's walk through, step by step, how we can leverage Delta Lake tables with Dask to accelerate data preprocessing with RAPIDS.

Reading from a Delta table with Dask#

With Dask's dask-deltatable, we can write the .csv file into a Delta table using Spark, then read and parallelize it with Dask.

delta_table_name = "higgs_delta_table"

# Check if the Delta table already exists
if spark.catalog.tableExists(delta_table_name):
    # If it exists, print a message
    print(f"The Delta table '{delta_table_name}' already exists.")
else:
    # If not, Load csv file into a Spark dataframe then
    # Write the spark dataframe into delta table
    data = spark.read.csv(csv_path, header=True, inferSchema=True)
    data.write.saveAsTable(delta_table_name)
    print(f"The Delta table '{delta_table_name}' has been created.")
The Delta table 'higgs_delta_table' already exists.
display(spark.sql("DESCRIBE DETAIL higgs_delta_table"))
format: delta
id: 90cdac79-5500-4a20-914b-47f86b616275
name: spark_catalog.default.higgs_delta_table
description: null
location: dbfs:/user/hive/warehouse/higgs_delta_table
createdAt: 2024-01-09T15:01:35.629+0000
lastModified: 2024-01-09T15:04:37.000+0000
partitionColumns: List()
numFiles: 60
sizeInBytes: 906326187
properties: Map()
minReaderVersion: 1
minWriterVersion: 2
tableFeatures: List(appendOnly, invariants)
statistics: Map()
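Since this is a Delta table, its file-based transaction log can also be inspected directly. As a minimal sketch of the basic Delta Lake operations mentioned earlier, DESCRIBE HISTORY lists the versioned operations that produced the table:

display(spark.sql("DESCRIBE HISTORY higgs_delta_table"))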

Calling dask_deltatable.read_deltalake() returns a Dask dataframe. However, our goal is to use GPU acceleration for the entire ML pipeline, including data processing, model training and inference. For that reason, we will read the Dask dataframe into a cuDF-backed Dask dataframe using dask_cudf.from_dask_dataframe().

Note that these operations automatically leverage the Dask client we created, ensuring optimal performance gains through Dask parallelism.

# Read the Delta Lake into a Dask DataFrame using `dask-deltatable`
df = ddt.read_deltalake("/dbfs/user/hive/warehouse/higgs_delta_table")

# Convert Dask DataFrame to Dask cuDF for GPU acceleration
ddf = dask_cudf.from_dask_dataframe(df)

ddf.head()
1.000000000000000000e+00 8.692932128906250000e-01 -6.350818276405334473e-01 2.256902605295181274e-01 3.274700641632080078e-01 -6.899932026863098145e-01 7.542022466659545898e-01 -2.485731393098831177e-01 -1.092063903808593750e+00 0.000000000000000000e+009 ... -1.045456994324922562e-02 -4.576716944575309753e-02 3.101961374282836914e+00 1.353760004043579102e+00 9.795631170272827148e-01 9.780761599540710449e-01 9.200048446655273438e-01 7.216574549674987793e-01 9.887509346008300781e-01 8.766783475875854492e-01
0 1.0 0.907542 0.329147 0.359412 1.497970 -0.313010 1.095531 -0.557525 -1.588230 2.173076 ... -1.138930 -0.000819 0.000000 0.302220 0.833048 0.985700 0.978098 0.779732 0.992356 0.798343
1 1.0 0.798835 1.470639 -1.635975 0.453773 0.425629 1.104875 1.282322 1.381664 0.000000 ... 1.128848 0.900461 0.000000 0.909753 1.108330 0.985692 0.951331 0.803252 0.865924 0.780118
2 0.0 1.344385 -0.876626 0.935913 1.992050 0.882454 1.786066 -1.646778 -0.942383 0.000000 ... -0.678379 -1.360356 0.000000 0.946652 1.028704 0.998656 0.728281 0.869200 1.026736 0.957904
3 1.0 1.105009 0.321356 1.522401 0.882808 -1.205349 0.681466 -1.070464 -0.921871 0.000000 ... -0.373566 0.113041 0.000000 0.755856 1.361057 0.986610 0.838085 1.133295 0.872245 0.808487
4 0.0 1.595839 -0.607811 0.007075 1.818450 -0.111906 0.847550 -0.566437 1.581239 2.173076 ... -0.654227 -1.274345 3.101961 0.823761 0.938191 0.971758 0.789176 0.430553 0.961357 0.957818

5 rows × 29 columns

colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)]
ddf.columns = colnames
ddf.head()
label feature-01 feature-02 feature-03 feature-04 feature-05 feature-06 feature-07 feature-08 feature-09 ... feature-19 feature-20 feature-21 feature-22 feature-23 feature-24 feature-25 feature-26 feature-27 feature-28
0 1.0 0.907542 0.329147 0.359412 1.497970 -0.313010 1.095531 -0.557525 -1.588230 2.173076 ... -1.138930 -0.000819 0.000000 0.302220 0.833048 0.985700 0.978098 0.779732 0.992356 0.798343
1 1.0 0.798835 1.470639 -1.635975 0.453773 0.425629 1.104875 1.282322 1.381664 0.000000 ... 1.128848 0.900461 0.000000 0.909753 1.108330 0.985692 0.951331 0.803252 0.865924 0.780118
2 0.0 1.344385 -0.876626 0.935913 1.992050 0.882454 1.786066 -1.646778 -0.942383 0.000000 ... -0.678379 -1.360356 0.000000 0.946652 1.028704 0.998656 0.728281 0.869200 1.026736 0.957904
3 1.0 1.105009 0.321356 1.522401 0.882808 -1.205349 0.681466 -1.070464 -0.921871 0.000000 ... -0.373566 0.113041 0.000000 0.755856 1.361057 0.986610 0.838085 1.133295 0.872245 0.808487
4 0.0 1.595839 -0.607811 0.007075 1.818450 -0.111906 0.847550 -0.566437 1.581239 2.173076 ... -0.654227 -1.274345 3.101961 0.823761 0.938191 0.971758 0.789176 0.430553 0.961357 0.957818

5 rows × 29 columns

Splitting data#

In the previous steps, we used dask-cudf to load data from the Delta table; now we split the dataset with the train_test_split() function from dask-ml.

Most of the time, Dask's GPU backend works seamlessly with utilities in dask-ml, letting us accelerate the entire ML pipeline as follows:

def load_higgs(
    ddf,
) -> Tuple[
    dask_cudf.core.DataFrame,
    dask_cudf.core.Series,
    dask_cudf.core.DataFrame,
    dask_cudf.core.Series,
]:
    y = ddf["label"]
    X = ddf[ddf.columns.difference(["label"])]

    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y, test_size=0.33, random_state=42
    )
    X_train, X_valid, y_train, y_valid = client.persist(
        [X_train, X_valid, y_train, y_valid]
    )
    wait([X_train, X_valid, y_train, y_valid])

    return X_train, X_valid, y_train, y_valid
X_train, X_valid, y_train, y_valid = load_higgs(ddf)
/databricks/python/lib/python3.10/site-packages/dask_ml/model_selection/_split.py:462: FutureWarning: The default value for 'shuffle' must be specified when splitting DataFrames. In the future DataFrames will automatically be shuffled within blocks prior to splitting. Specify 'shuffle=True' to adopt the future behavior now, or 'shuffle=False' to retain the previous behavior.
  warnings.warn(
X_train.head()
feature-01 feature-02 feature-03 feature-04 feature-05 feature-06 feature-07 feature-08 feature-09 feature-10 ... feature-19 feature-20 feature-21 feature-22 feature-23 feature-24 feature-25 feature-26 feature-27 feature-28
0 0.907542 0.329147 0.359412 1.497970 -0.313010 1.095531 -0.557525 -1.588230 2.173076 0.812581 ... -1.138930 -0.000819 0.000000 0.302220 0.833048 0.985700 0.978098 0.779732 0.992356 0.798343
1 0.798835 1.470639 -1.635975 0.453773 0.425629 1.104875 1.282322 1.381664 0.000000 0.851737 ... 1.128848 0.900461 0.000000 0.909753 1.108330 0.985692 0.951331 0.803252 0.865924 0.780118
3 1.105009 0.321356 1.522401 0.882808 -1.205349 0.681466 -1.070464 -0.921871 0.000000 0.800872 ... -0.373566 0.113041 0.000000 0.755856 1.361057 0.986610 0.838085 1.133295 0.872245 0.808487
10 0.739357 -0.178290 0.829934 0.504539 -0.130217 0.961051 -0.355518 -1.717399 2.173076 0.620956 ... 0.774065 0.398820 3.101961 0.944536 1.026261 0.982197 0.542115 1.250979 0.830045 0.761308
11 1.384098 0.116822 -1.179879 0.762913 -0.079782 1.019863 0.877318 1.276887 2.173076 0.331252 ... 0.846521 0.504809 3.101961 0.959325 0.807376 1.191814 1.221210 0.861141 0.929341 0.838302

5 rows × 28 columns

y_train.head()
Out[14]: 0     1.0
1     1.0
3     1.0
10    0.0
11    1.0
Name: label, dtype: float64

Model training#

There are two things to note here. First, we specify the number of rounds that triggers early stopping of training: XGBoost will stop the training process once the validation metric fails to improve for X consecutive rounds, where X is the number of rounds specified for early stopping.

Second, we use a data type called DaskDeviceQuantileDMatrix for training but DaskDMatrix for validation. DaskDeviceQuantileDMatrix is a drop-in replacement for DaskDMatrix for GPU-based training inputs that avoids extra data copies. (In recent XGBoost releases it has been renamed DaskQuantileDMatrix, hence the FutureWarning in the output below.)

def fit_model_es(client, X, y, X_valid, y_valid) -> dxgb.Booster:
    early_stopping_rounds = 5
    Xy = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
    Xy_valid = dxgb.DaskDMatrix(client, X_valid, y_valid)
    # train the model
    booster = dxgb.train(
        client,
        {
            "objective": "binary:logistic",
            "eval_metric": "error",
            "tree_method": "gpu_hist",
        },
        Xy,
        evals=[(Xy_valid, "Valid")],
        num_boost_round=1000,
        early_stopping_rounds=early_stopping_rounds,
    )["booster"]
    return booster
booster = fit_model_es(client, X=X_train, y=y_train, X_valid=X_valid, y_valid=y_valid)
booster
/databricks/python/lib/python3.10/site-packages/xgboost/dask.py:703: FutureWarning: Please use `DaskQuantileDMatrix` instead.
  warnings.warn("Please use `DaskQuantileDMatrix` instead.", FutureWarning)
Out[16]: <xgboost.core.Booster at 0x7f7c5702c4c0>
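Because early stopping was enabled, the returned booster records the round that performed best on the validation set. A minimal sketch for inspecting it (note these attributes are only set if training actually stopped early):

# Best boosting round on the validation set and the metric value it achieved
print(booster.best_iteration)
print(booster.best_score)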

Training with a customized objective and evaluation metric#

In the example below, an XGBoost model is trained with a custom logistic-regression-based objective function (logit) and a custom evaluation metric (error), together with early stopping.

Note that the objective function returns the gradient and the Hessian, which XGBoost uses to optimize the model. Additionally, a parameter named metric_name needs to be specified in our callback; it informs XGBoost that the custom error function should be used to evaluate the early stopping criterion.
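As a quick check on the expressions used in logit below: for the binary logistic loss $\ell = -[\,y \log p + (1 - y) \log(1 - p)\,]$ with $p = \sigma(\hat{y}) = 1 / (1 + e^{-\hat{y}})$ applied to the raw margin $\hat{y}$, differentiating with respect to $\hat{y}$ gives

$$\frac{\partial \ell}{\partial \hat{y}} = p - y, \qquad \frac{\partial^2 \ell}{\partial \hat{y}^2} = p\,(1 - p),$$

which are exactly the grad and hess returned by the custom objective.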

def fit_model_customized_objective(client, X, y, X_valid, y_valid) -> dxgb.Booster:
    def logit(predt: np.ndarray, Xy: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
        predt = 1.0 / (1.0 + np.exp(-predt))
        labels = Xy.get_label()
        grad = predt - labels
        hess = predt * (1.0 - predt)
        return grad, hess

    def error(predt: np.ndarray, Xy: xgb.DMatrix) -> Tuple[str, float]:
        label = Xy.get_label()
        r = np.zeros(predt.shape)
        predt = 1.0 / (1.0 + np.exp(-predt))
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return "CustomErr", float(np.average(r))

    # Use early stopping with custom objective and metric.
    early_stopping_rounds = 5
    # Specify the metric we want to use for early stopping.
    es = xgb.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True, metric_name="CustomErr"
    )

    Xy = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
    Xy_valid = dxgb.DaskDMatrix(client, X_valid, y_valid)
    booster = dxgb.train(
        client,
        {"eval_metric": "error", "tree_method": "gpu_hist"},
        Xy,
        evals=[(Xy_valid, "Valid")],
        num_boost_round=1000,
        obj=logit,  # pass the custom objective
        feval=error,  # pass the custom metric
        callbacks=[es],
    )["booster"]
    return booster
booster_custom = fit_model_customized_objective(
    client, X=X_train, y=y_train, X_valid=X_valid, y_valid=y_valid
)
booster_custom
/databricks/python/lib/python3.10/site-packages/xgboost/dask.py:703: FutureWarning: Please use `DaskQuantileDMatrix` instead.
  warnings.warn("Please use `DaskQuantileDMatrix` instead.", FutureWarning)
Out[18]: <xgboost.core.Booster at 0x7f7c5702cd30>

Running inference#

After some tuning, we arrive at a final model for performing inference on new data.

def predict(client, model, X):
    predt = dxgb.predict(client, model, X)
    return predt
preds = predict(client, booster, X_train)
preds.head()
Out[20]: 0     0.843650
1     0.975618
3     0.378462
10    0.293985
11    0.966303
Name: 0, dtype: float32
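For inference-only workloads, xgboost.dask also provides inplace_predict, which runs prediction directly on the distributed data without first constructing a DMatrix. A minimal sketch using the objects from above:

# Predict directly on the distributed (GPU) data; no DaskDMatrix needed
preds_inplace = dxgb.inplace_predict(client, booster, X_train)
preds_inplace.head()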

Clean up#

When finished, be sure to destroy your cluster to avoid incurring extra costs for idle resources.

Note: if you do forget to destroy the cluster manually, Databricks clusters automatically time out after a period of inactivity (specified when the cluster is created).

client.close()