cuda.hpp
1 /*
2  * 版权所有 (c) 2022-2025, NVIDIA CORPORATION。
3  *
4  * 根据 Apache 许可证 2.0 版(“许可证”)获得许可;
5  * 除非遵守许可证,否则您不得使用此文件。
6  * 您可以在以下位置获取许可证副本:
7  *
8  * https://apache.ac.cn/licenses/LICENSE-2.0
9  *
10  * 除非适用法律要求或书面同意,否则软件
11  * 根据许可证分发,按“原样”提供,
12  * 不附带任何明示或默示的担保或条件。
13  * 请参阅许可证了解特定语言的管理权限和
14  * 许可证下的限制。
15  */
16 #pragma once
17 
18 #include <kvikio/shim/cuda_h_wrapper.hpp>
19 #include <kvikio/shim/utils.hpp>
20 
21 namespace kvikio {
22 
30 class cudaAPI {
31  public
32  decltype(cuInit)* Init{nullptr};
33  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
34  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
35  decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{nullptr};
36  decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{nullptr};
37  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
38  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
39  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
40  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
41  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
42  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
43  decltype(cuGetErrorName)* GetErrorName{nullptr};
44  decltype(cuGetErrorString)* GetErrorString{nullptr};
45  decltype(cuDeviceGet)* DeviceGet{nullptr};
46  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
47  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
48  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
49  decltype(cuStreamCreate)* StreamCreate{nullptr};
50  decltype(cuStreamDestroy)* StreamDestroy{nullptr};
51 
52  private
53  cudaAPI();
54 
55  public
56  cudaAPI(cudaAPI const&) = delete;
57  void operator=(cudaAPI const&) = delete;
58 
59  KVIKIO_EXPORT static cudaAPI& instance();
60 };
61 
69 #ifdef KVIKIO_CUDA_FOUND
70 bool is_cuda_available();
71 #else
72 constexpr bool is_cuda_available() { return false; }
73 #endif
74 
75 } // 命名空间 kvikio
CUDA C-API 的 Shim 层。
定义: cuda.hpp:30
KvikIO 命名空间。
定义: batch.hpp:27
constexpr bool is_cuda_available()
检查 CUDA 库是否可用。
定义: cuda.hpp:72