failure_callback_resource_adaptor.hpp
转到此文件的文档。
1 /*
2  * Copyright (c) 2020-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * https://apache.ac.cn/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <rmm/detail/error.hpp>
19 #include <rmm/detail/export.hpp>
22 #include <rmm/resource_ref.hpp>
23 
24 #include <cstddef>
25 #include <functional>
26 #include <utility>
27 
28 namespace RMM_NAMESPACE {
29 namespace mr {
51 using failure_callback_t = std::function<bool(std::size_t, void*)>;
52 
94 template <typename Upstream, typename ExceptionType = rmm::out_of_memory>
96  public
97  using exception_type = ExceptionType;
98 
108  failure_callback_t callback,
109  void* callback_arg)
110  : upstream_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg}
111  {
112  }
113 
125  failure_callback_t callback,
126  void* callback_arg)
127  : upstream_{to_device_async_resource_ref_checked(upstream)},
128  callback_{std::move(callback)},
129  callback_arg_{callback_arg}
130  {
131  }
132 
134  ~failure_callback_resource_adaptor() override = default;
138  default;
140  default;
141 
145  [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept
146  {
147  return upstream_;
148  }
149 
150  private
162  void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
163  {
164  void* ret{};
165 
166  while (true) {
167  try {
168  ret = get_upstream_resource().allocate_async(bytes, stream);
169  break;
170  } catch (exception_type const& e) {
171  if (!callback_(bytes, callback_arg_)) { throw; }
172  }
173  }
174  return ret;
175  }
176 
184  void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
185  {
186  get_upstream_resource().deallocate_async(ptr, bytes, stream);
187  }
188 
196  [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
197  {
198  if (this == &other) { return true; }
199  auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
200  if (cast == nullptr) { return false; }
201  return get_upstream_resource() == cast->get_upstream_resource();
202  }
203 
204  // 用于满足分配请求的上游资源
205  device_async_resource_ref upstream_;
206  failure_callback_t callback_;
207  void* callback_arg_;
208 };
209  // 组结束
211 } // namespace mr
212 } // namespace RMM_NAMESPACE
CUDA 流的强类型非拥有包装器,带有默认构造函数。
定义: cuda_stream_view.hpp:39
所有 librmm 设备内存分配的基类。
定义: device_memory_resource.hpp:92
void * allocate_async(std::size_t bytes, std::size_t alignment, cuda_stream_view stream)
分配至少 bytes 大小的内存。
定义: device_memory_resource.hpp:215
当分配抛出指定的异常时,调用回调函数的设备内存资源...
定义: failure_callback_resource_adaptor.hpp:95
failure_callback_resource_adaptor(device_async_resource_ref upstream, failure_callback_t callback, void *callback_arg)
使用 upstream 构造一个新的 failure_callback_resource_adaptor 来满足分配请求。
定义: failure_callback_resource_adaptor.hpp:107
ExceptionType exception_type
此对象捕获/抛出的异常类型。
定义: failure_callback_resource_adaptor.hpp:97
failure_callback_resource_adaptor(Upstream *upstream, failure_callback_t callback, void *callback_arg)
使用 upstream 构造一个新的 failure_callback_resource_adaptor 来满足分配请求。
定义: failure_callback_resource_adaptor.hpp:124
failure_callback_resource_adaptor(failure_callback_resource_adaptor &&) noexcept=default
默认移动构造函数。
std::function< bool(std::size_t, void *)> failure_callback_t
由 failure_callback_resource_adaptor 使用的回调函数类型。
定义: failure_callback_resource_adaptor.hpp:51
cuda::mr::async_resource_ref< cuda::mr::device_accessible > device_async_resource_ref
cuda::mr::async_resource_ref 的别名,带有 cuda::mr::device_accessible 属性。
定义: resource_ref.hpp:40
device_async_resource_ref to_device_async_resource_ref_checked(Resource *res)
将内存资源指针转换为 device_async_resource_ref,并检查 nullptr
定义: resource_ref.hpp:78
每设备 device_memory_resource 的管理。