过滤向量索引

cuVS 根据所使用的向量索引类型支持不同的过滤方式。所有向量索引采用的主要方法是预过滤(pre-filtering):在计算最近邻之前就应用向量过滤,从而省去为被过滤向量计算距离的开销。

位集合

位集合是一个位数组,其中每个位有两个可能的值:0 和 1。在过滤场景中,这些值表示对应的样本是否应被过滤:0 表示相应的向量将被过滤,因此不会出现在搜索结果中;1 表示该向量会被保留。这种机制经过优化,占用的内存空间尽可能小,并通过 RAFT 库提供(请查阅 RAFT 的 bitset API 文档)。调用 ANN 索引的搜索函数时,位集合的长度应与数据库中向量的数量一致。

位图

位图与位集合基于相同的原理,但采用二维结构:每一行对应一个搜索查询,每一列对应数据库中的一个向量。这使用户能够为每个搜索查询提供不同的位集合。请查阅 RAFT 的 bitmap API 文档。

示例

在 CAGRA 索引上使用位集合过滤

#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/core/bitset.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cstdint>
#include <vector>

// Example: pre-filter a CAGRA search with a bitset.
// `load_queries`, `get_invalid_indices`, `n_queries`, `dim` and `k` are
// placeholders supplied by the application.
using namespace cuvs::neighbors;

// The resources handle is needed to construct the index, so create it first.
raft::device_resources res;

// `cagra::index` is a class template over the data type and the graph index type.
cagra::index<float, uint32_t> index(res);

// ... build index ...

cagra::search_params search_params;
// `cagra::search` reads the queries through a read-only view with int64_t extents.
raft::device_matrix_view<const float, int64_t> queries = load_queries(n_queries, dim);
// The outputs need real device storage: `make_device_matrix_view` only wraps an
// existing pointer, so allocate owning matrices and hand their views to `search`.
auto neighbors = raft::make_device_matrix<uint32_t, int64_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, int64_t>(res, n_queries, k);

// Load a list of all the samples that will get filtered
std::vector<uint32_t> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
      raft::make_device_vector<uint32_t, uint32_t>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(), removed_indices_host.data(),
           removed_indices_host.size(), raft::resource::get_cuda_stream(res));

// Create a bitset covering every vector of the index (`index.size()` bits) in
// which the listed samples are marked as filtered (0 = filtered).
cuvs::core::bitset<uint32_t, uint32_t> removed_indices_bitset(
    res, removed_indices_device.view(), index.size());
// Use a `bitset_filter` in the `cagra::search` function call.
auto bitset_filter =
      cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cagra::search(res,
              search_params,
              index,
              queries,
              neighbors.view(),
              distances.view(),
              bitset_filter);

在暴力搜索索引上使用位图过滤

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/core/bitmap.hpp>
#include <cuvs/core/bitset.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cstdint>
#include <vector>

// Example: per-query pre-filtering of a brute-force search with a bitmap.
// `load_dataset`, `load_queries`, `get_invalid_indices`, `n_vectors`,
// `n_queries`, `dim` and `k` are placeholders supplied by the application.
using namespace cuvs::neighbors;
using indexing_dtype = int64_t;

// ... build index ...
brute_force::index_params index_params;
brute_force::search_params search_params;
raft::device_resources res;
raft::device_matrix_view<float, indexing_dtype> dataset = load_dataset(n_vectors, dim);
raft::device_matrix_view<float, indexing_dtype> queries = load_queries(n_queries, dim);
// `dataset` is already an mdspan view (views have no `.view()` member); only a
// const-qualified view of it is required.
auto index = brute_force::build(res, index_params, raft::make_const_mdspan(dataset));

// Load a list of all the (query, vector) pairs that will get filtered. Each entry
// is a position in the flattened, row-major n_queries x n_vectors bitmap:
// query_id * n_vectors + vector_id. `indexing_dtype` is used because these
// positions can exceed the 32-bit range and must match the bitset's index type.
std::vector<indexing_dtype> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
      raft::make_device_vector<indexing_dtype, indexing_dtype>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(), removed_indices_host.data(),
           removed_indices_host.size(), raft::resource::get_cuda_stream(res));

// Create a bitmap with the list of samples to filter: first fill a bitset with
// n_queries * n_vectors bits (0 = filtered), then view that storage as a
// 2-D bitmap with one row per query and one column per dataset vector.
cuvs::core::bitset<uint32_t, indexing_dtype> removed_indices_bitset(
  res, removed_indices_device.view(), n_queries * n_vectors);
cuvs::core::bitmap_view<const uint32_t, indexing_dtype> removed_indices_bitmap(
    removed_indices_bitset.data(), n_queries, n_vectors);

// Use a `bitmap_filter` in the `brute_force::search` function call.
auto bitmap_filter =
      cuvs::neighbors::filtering::bitmap_filter(removed_indices_bitmap);

// Allocate owning output buffers (`make_device_matrix_view` only wraps an existing
// pointer). Brute-force search reports neighbor ids as int64_t.
auto neighbors = raft::make_device_matrix<indexing_dtype, indexing_dtype>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, indexing_dtype>(res, n_queries, k);
brute_force::search(res,
                    search_params,
                    index,
                    raft::make_const_mdspan(queries),
                    neighbors.view(),
                    distances.view(),
                    bitmap_filter);