1 | /* |
2 | * Copyright 2021 NVIDIA Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 with the LLVM exception |
5 | * (the "License"); you may not use this file except in compliance with |
6 | * the License. |
7 | * |
8 | * You may obtain a copy of the License at |
9 | * |
10 | * http://llvm.org/foundation/relicensing/LICENSE.txt |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, software |
13 | * distributed under the License is distributed on an "AS IS" BASIS, |
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | * See the License for the specific language governing permissions and |
16 | * limitations under the License. |
17 | * \file l2_cache_flush.h |
 * \brief Utility class to flush the L2 cache via the CUDA runtime API, adapted from nvbench.
19 | */ |
20 | #ifndef L2_CACHE_FLUSH_H_ |
21 | #define L2_CACHE_FLUSH_H_ |
22 | |
23 | #include <cuda.h> |
24 | #include <cuda_runtime.h> |
25 | #include <dmlc/logging.h> |
26 | |
27 | namespace tvm { |
28 | namespace runtime { |
29 | |
/*!
 * \brief Evaluate a CUDA runtime call and ICHECK-fail with the CUDA error
 *        string unless it returned cudaSuccess or cudaErrorCudartUnloading.
 *
 * NOTE(review): cudaErrorCudartUnloading is presumably tolerated so that calls
 * made while the CUDA runtime is shutting down (e.g. from destructors run at
 * process/thread exit) do not fail — confirm.
 *
 * The body is wrapped in do { ... } while (0) so that `CUDA_CALL(x);` is a
 * single statement, which keeps it valid as the body of an unbraced if/else.
 */
#define CUDA_CALL(func)                                       \
  do {                                                        \
    cudaError_t e = (func);                                   \
    ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
        << "CUDA: " << cudaGetErrorString(e);                 \
  } while (0)
36 | |
/*!
 * \brief Evicts the L2 cache of a CUDA device by overwriting a device buffer
 *        whose size equals the device's L2 cache size (technique adapted
 *        from nvbench).
 *
 * The buffer is allocated lazily, on the device that is current during the
 * first Flush() call, and released in the destructor. Because the object owns
 * a raw device allocation it is non-copyable: a copy would cudaFree the same
 * pointer twice.
 */
class L2Flush {
 public:
  L2Flush() = default;

  // Owning a raw device pointer: copying would lead to a double cudaFree.
  L2Flush(const L2Flush&) = delete;
  L2Flush& operator=(const L2Flush&) = delete;

  ~L2Flush() {
    // Only release what was actually allocated.
    if (l2_buffer_ != nullptr) {
      CUDA_CALL(cudaFree(l2_buffer_));
    }
  }

  /*!
   * \brief Enqueue an overwrite (zero fill) of the L2-sized buffer on \p stream.
   *
   * On the first call, queries the L2 cache size of the current device and
   * allocates the buffer. If the device reports an L2 size of 0 no buffer is
   * allocated and this call does nothing. The memset is issued with
   * cudaMemsetAsync, so it is asynchronous with respect to the host.
   *
   * \param stream Stream on which the memset is enqueued.
   */
  void Flush(cudaStream_t stream) {
    if (!initialized_) {
      Initialize();
    }
    if (l2_buffer_ != nullptr) {
      CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream));
    }
  }

  // Defined out of line. NOTE(review): presumably returns a per-thread
  // instance, judging by the name — confirm against the definition.
  static L2Flush* ThreadLocal();

 private:
  /*!
   * \brief Query the current device's L2 cache size and allocate a buffer of
   *        that many bytes.
   *
   * Members are written only after every CUDA call has returned, so if a
   * CUDA_CALL check fails and control leaves this function, the object stays
   * uninitialized (and bufferless) and the next Flush() retries instead of
   * issuing a memset on an unallocated pointer.
   */
  void Initialize() {
    int device_id = 0;
    CUDA_CALL(cudaGetDevice(&device_id));
    int l2_size = 0;
    CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, device_id));
    void* buffer = nullptr;
    if (l2_size > 0) {
      CUDA_CALL(cudaMalloc(&buffer, l2_size));
    }
    l2_buffer_ = static_cast<int*>(buffer);
    l2_size_ = l2_size;
    initialized_ = true;
  }

  // Whether Initialize() has completed.
  bool initialized_ = false;
  // Value of cudaDevAttrL2CacheSize for the device used at initialization; 0 before.
  int l2_size_ = 0;
  // Device buffer of l2_size_ bytes, or nullptr when nothing was allocated.
  int* l2_buffer_ = nullptr;
};
70 | |
71 | } // namespace runtime |
72 | } // namespace tvm |
73 | |
74 | #endif // L2_CACHE_FLUSH_H_ |
75 | |