注意
RAFT 中的向量搜索和聚类算法正在迁移到一个专门用于向量搜索的新库,名为 cuVS。在此迁移期间,我们将继续支持 RAFT 中的向量搜索算法,但在 RAPIDS 24.06(6 月)版本之后将不再更新它们。我们计划在 RAPIDS 24.10(10 月)版本之前完成迁移,并在 24.12(12 月)版本中将它们从 RAFT 中完全移除。
回归模型评分#
信息准则#
#include <raft/stats/information_criterion.cuh>
namespace raft::stats
-
template<typename value_t, typename idx_t>
void information_criterion_batched(raft::resources const &handle, raft::device_vector_view<const value_t, idx_t> d_loglikelihood, raft::device_vector_view<value_t, idx_t> d_ic, IC_Type ic_type, idx_t n_params, idx_t n_samples)# 计算给定类型的信息准则
注意
: 进行就地计算是安全的(即输入和输出使用相同的指针)参见
- 模板参数:
value_t – 数据类型
idx_t – 索引类型
- 参数:
handle – [in] raft 句柄
d_loglikelihood – [in] 每条序列的对数似然 (设备端) 长度: batch_size
d_ic – [out] 返回的每条序列的信息准则 (设备端) 长度: batch_size
ic_type – [in] 要计算的准则类型。参见 IC_Type
n_params – [in] 模型中的参数数量
n_samples – [in] 每条序列中的样本数量
R2 分数#
#include <raft/stats/r2_score.cuh>
namespace raft::stats
-
template<typename value_t, typename idx_t>
device_vector_view<const value_t, idx_t> y, raft::device_vector_view<const value_t, idx_t> y_hat)# 计算“决定系数”(R 方)分数,通过总平方和对平方误差和进行归一化。
该分数表示线性回归模型中自变量解释的预期响应变量变异的比例。R 方值越大,线性回归模型解释的变异越多。
注意
y 和 y_hat 的 const 属性目前被强制取消。
- 模板参数:
value_t – 数据类型
idx_t – 索引类型
- 参数:
handle – [in] raft 句柄
y – [in] 真实响应变量数组
y_hat – [in] 预测响应变量数组
- 返回值:
: R 方值。
回归指标#
#include <raft/stats/regression_metrics.cuh>
namespace raft::stats
-
template<typename value_t, typename idx_t>
void regression_metrics(raft::resources const &handle, raft::device_vector_view<const value_t, idx_t> predictions, raft::device_vector_view<const value_t, idx_t> ref_predictions, raft::host_scalar_view<double> mean_abs_error, raft::host_scalar_view<double> mean_squared_error, raft::host_scalar_view<double> median_abs_error)# 计算回归指标:平均绝对误差、均方误差、中位数绝对误差。
- 模板参数:
value_t – 预测值的数据类型(例如,回归使用 float 或 double)。
idx_t – 索引类型
- 参数:
handle – [in] raft 句柄
predictions – [in] 预测值数组。
ref_predictions – [in] 参考(真实)预测值数组。
mean_abs_error – [out] 平均绝对误差。(|predictions[i] - ref_predictions[i]|) 在 n 上的总和 / n。
mean_squared_error – [out] 均方误差。((predictions[i] - ref_predictions[i])^2) 在 n 上的总和 / n。
median_abs_error – [out] 中位数绝对误差。|predictions[i] - ref_predictions[i]| 在 i ∈ [0, n) 范围内的中位数。