过滤向量索引
cuVS 支持多种过滤方式,具体取决于所使用的向量索引类型。所有向量索引采用的主要方法是预过滤(pre-filtering):在计算最近邻之前就考虑向量的过滤条件,从而省去对被过滤向量的距离计算,节省计算量。
位集合
位集合(bitset)是一个位数组,其中每个位有两个可能的取值:0 和 1。在过滤上下文中,每个位表示对应样本是否应被过滤:0 表示相应的向量将被过滤,因此不会出现在搜索结果中;1 表示该向量被保留。这种机制经过优化,占用尽可能少的内存空间,并通过 RAFT 库提供(请查阅 RAFT 的 bitset API 文档)。调用 ANN 索引的搜索函数时,位集合的长度应与数据库中的向量数量一致。
位图
位图(bitmap)与位集合基于相同的原理,但采用二维结构:每一行对应一个搜索查询,每一列对应数据库中的一个向量。因此用户可以为每个搜索查询提供不同的位集合。请查阅 RAFT 的 bitmap API 文档。
示例
在 CAGRA 索引上使用位集合过滤
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/core/bitset.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cstdint>
#include <vector>

using namespace cuvs::neighbors;

// NOTE: load_dataset(), load_queries(), get_invalid_indices(), n_vectors,
// n_queries, dim and k are user-supplied placeholders, not cuVS API.
raft::device_resources res;

// cagra::index is a class template without a default constructor, so obtain
// it from cagra::build rather than default-constructing it.
cagra::index_params index_params;
raft::device_matrix_view<const float, int64_t> dataset = load_dataset(n_vectors, dim);
cagra::index<float, uint32_t> index = cagra::build(res, index_params, dataset);

cagra::search_params search_params;
// cagra::search takes int64_t-indexed views for queries and both outputs.
raft::device_matrix_view<const float, int64_t> queries = load_queries(n_queries, dim);
// make_device_matrix allocates device memory owned by the returned mdarray;
// make_device_matrix_view would only wrap an already-existing pointer.
auto neighbors = raft::make_device_matrix<uint32_t, int64_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, int64_t>(res, n_queries, k);

// Load a list of all the samples that will get filtered
std::vector<uint32_t> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
  raft::make_device_vector<uint32_t, uint32_t>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(),
           removed_indices_host.data(),
           removed_indices_host.size(),
           raft::resource::get_cuda_stream(res));

// Create a bitset with the list of samples to filter. The listed indices are
// the ones marked as filtered; the bitset length must equal the number of
// vectors in the index.
cuvs::core::bitset<uint32_t, uint32_t> removed_indices_bitset(
  res, removed_indices_device.view(), index.size());
// Use a `bitset_filter` in the `cagra::search` function call.
auto bitset_filter =
  cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cagra::search(res,
              search_params,
              index,
              queries,
              neighbors.view(),
              distances.view(),
              bitset_filter);
在暴力搜索索引上使用位图过滤
#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/core/bitmap.hpp>
#include <cuvs/core/bitset.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cstdint>
#include <vector>

using namespace cuvs::neighbors;
using indexing_dtype = int64_t;

// NOTE: load_dataset(), load_queries(), get_invalid_indices(), n_vectors,
// n_queries, dim and k are user-supplied placeholders, not cuVS API.
raft::device_resources res;

// Build the index.
brute_force::index_params index_params;
brute_force::search_params search_params;
raft::device_matrix_view<float, indexing_dtype> dataset = load_dataset(n_vectors, dim);
raft::device_matrix_view<float, indexing_dtype> queries = load_queries(n_queries, dim);
// `dataset` is already a view: pass it straight to make_const_mdspan (views
// have no `.view()` member).
auto index = brute_force::build(res, index_params, raft::make_const_mdspan(dataset));

// Load a list of all the bitmap positions that will get filtered. The bitmap
// is n_queries x n_vectors bits in row-major order, so each entry is the
// flattened position `query_row * n_vectors + vector_id`. The element/index
// type must match the bitset's index type (indexing_dtype); 64-bit also keeps
// positions beyond 2^32 representable.
std::vector<indexing_dtype> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
  raft::make_device_vector<indexing_dtype, indexing_dtype>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(),
           removed_indices_host.data(),
           removed_indices_host.size(),
           raft::resource::get_cuda_stream(res));

// Create a bitmap with the list of samples to filter: a flat bitset of
// n_queries * n_vectors bits, reinterpreted as a 2-D bitmap_view. Widen before
// multiplying so the product cannot overflow a 32-bit type.
const indexing_dtype bitmap_len = static_cast<indexing_dtype>(n_queries) * n_vectors;
cuvs::core::bitset<uint32_t, indexing_dtype> removed_indices_bitset(
  res, removed_indices_device.view(), bitmap_len);
cuvs::core::bitmap_view<const uint32_t, indexing_dtype> removed_indices_bitmap(
  removed_indices_bitset.data(), n_queries, n_vectors);

// Use a `bitmap_filter` in the `brute_force::search` function call.
auto bitmap_filter =
  cuvs::neighbors::filtering::bitmap_filter(removed_indices_bitmap);

// brute_force::search writes int64_t neighbor ids; allocate owning device
// matrices for the outputs and pass their views.
auto neighbors = raft::make_device_matrix<int64_t, indexing_dtype>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, indexing_dtype>(res, n_queries, k);
brute_force::search(res,
                    search_params,
                    index,
                    raft::make_const_mdspan(queries),
                    neighbors.view(),
                    distances.view(),
                    bitmap_filter);