1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file cuda_module.cc |
22 | */ |
#include "cuda_module.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <array>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "cuda_common.h"
40 | |
41 | namespace tvm { |
42 | namespace runtime { |
43 | |
44 | // Module to support thread-safe multi-GPU execution. |
45 | // cuModule is a per-GPU module |
46 | // The runtime will contain a per-device module table |
47 | // The modules will be lazily loaded |
48 | class CUDAModuleNode : public runtime::ModuleNode { |
49 | public: |
50 | explicit CUDAModuleNode(std::string data, std::string fmt, |
51 | std::unordered_map<std::string, FunctionInfo> fmap, |
52 | std::string cuda_source) |
53 | : data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) { |
54 | std::fill(module_.begin(), module_.end(), nullptr); |
55 | } |
56 | // destructor |
57 | ~CUDAModuleNode() { |
58 | for (size_t i = 0; i < module_.size(); ++i) { |
59 | if (module_[i] != nullptr) { |
60 | CUDA_CALL(cudaSetDevice(static_cast<int>(i))); |
61 | CUDA_DRIVER_CALL(cuModuleUnload(module_[i])); |
62 | } |
63 | } |
64 | } |
65 | |
66 | const char* type_key() const final { return "cuda" ; } |
67 | |
68 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final; |
69 | |
70 | void SaveToFile(const std::string& file_name, const std::string& format) final { |
71 | std::string fmt = GetFileFormat(file_name, format); |
72 | std::string meta_file = GetMetaFilePath(file_name); |
73 | if (fmt == "cu" ) { |
74 | ICHECK_NE(cuda_source_.length(), 0); |
75 | SaveMetaDataToFile(meta_file, fmap_); |
76 | SaveBinaryToFile(file_name, cuda_source_); |
77 | } else { |
78 | ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; |
79 | SaveMetaDataToFile(meta_file, fmap_); |
80 | SaveBinaryToFile(file_name, data_); |
81 | } |
82 | } |
83 | |
84 | void SaveToBinary(dmlc::Stream* stream) final { |
85 | stream->Write(fmt_); |
86 | stream->Write(fmap_); |
87 | stream->Write(data_); |
88 | } |
89 | |
90 | std::string GetSource(const std::string& format) final { |
91 | if (format == fmt_) return data_; |
92 | if (cuda_source_.length() != 0) { |
93 | return cuda_source_; |
94 | } else { |
95 | if (fmt_ == "ptx" ) return data_; |
96 | return "" ; |
97 | } |
98 | } |
99 | |
100 | // get a CUfunction from primary context in device_id |
101 | CUfunction GetFunc(int device_id, const std::string& func_name) { |
102 | std::lock_guard<std::mutex> lock(mutex_); |
103 | // must recheck under the lock scope |
104 | if (module_[device_id] == nullptr) { |
105 | CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str())); |
106 | } |
107 | CUfunction func; |
108 | CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str()); |
109 | if (result != CUDA_SUCCESS) { |
110 | const char* msg; |
111 | cuGetErrorName(result, &msg); |
112 | LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg; |
113 | } |
114 | return func; |
115 | } |
116 | // get a global var from primary context in device_id |
117 | CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { |
118 | std::lock_guard<std::mutex> lock(mutex_); |
119 | // must recheck under the lock scope |
120 | if (module_[device_id] == nullptr) { |
121 | CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str())); |
122 | } |
123 | CUdeviceptr global; |
124 | size_t nbytes; |
125 | |
126 | CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()); |
127 | ICHECK_EQ(nbytes, expect_nbytes); |
128 | if (result != CUDA_SUCCESS) { |
129 | const char* msg; |
130 | cuGetErrorName(result, &msg); |
131 | LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg; |
132 | } |
133 | return global; |
134 | } |
135 | |
136 | private: |
137 | // the binary data |
138 | std::string data_; |
139 | // The format |
140 | std::string fmt_; |
141 | // function information table. |
142 | std::unordered_map<std::string, FunctionInfo> fmap_; |
143 | // The cuda source. |
144 | std::string cuda_source_; |
145 | // the internal modules per GPU, to be lazily initialized. |
146 | std::array<CUmodule, kMaxNumGPUs> module_; |
147 | // internal mutex when updating the module |
148 | std::mutex mutex_; |
149 | }; |
150 | |
151 | // a wrapped function class to get packed func. |
152 | class CUDAWrappedFunc { |
153 | public: |
154 | // initialize the CUDA function. |
155 | void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name, |
156 | size_t num_void_args, const std::vector<std::string>& launch_param_tags) { |
157 | m_ = m; |
158 | sptr_ = sptr; |
159 | func_name_ = func_name; |
160 | std::fill(fcache_.begin(), fcache_.end(), nullptr); |
161 | launch_param_config_.Init(num_void_args, launch_param_tags); |
162 | } |
163 | // invoke the function with void arguments |
164 | void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { |
165 | int device_id; |
166 | CUDA_CALL(cudaGetDevice(&device_id)); |
167 | ThreadWorkLoad wl = launch_param_config_.Extract(args); |
168 | |
169 | if (fcache_[device_id] == nullptr) { |
170 | fcache_[device_id] = m_->GetFunc(device_id, func_name_); |
171 | if (wl.dyn_shmem_size >= (48 << 10)) { |
172 | // Assumption: dyn_shmem_size doesn't change across different invocations of |
173 | // fcache_[device_id] |
174 | CUresult result = cuFuncSetAttribute( |
175 | fcache_[device_id], CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, wl.dyn_shmem_size); |
176 | if (result != CUDA_SUCCESS) { |
177 | LOG(FATAL) << "Failed to set the allowed dynamic shared memory size to " |
178 | << wl.dyn_shmem_size; |
179 | } |
180 | } |
181 | } |
182 | CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream); |
183 | CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), |
184 | wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), |
185 | wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); |
186 | if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { |
187 | const char* msg; |
188 | cuGetErrorName(result, &msg); |
189 | std::ostringstream os; |
190 | os << "CUDALaunch Error: " << msg << "\n" |
191 | << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " |
192 | << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) |
193 | << ")\n" ; |
194 | std::string cuda = m_->GetSource("" ); |
195 | if (cuda.length() != 0) { |
196 | os << "// func_name=" << func_name_ << "\n" |
197 | << "// CUDA Source\n" |
198 | << "// -----------\n" |
199 | << cuda; |
200 | } |
201 | LOG(FATAL) << os.str(); |
202 | } |
203 | } |
204 | |
205 | private: |
206 | // internal module |
207 | CUDAModuleNode* m_; |
208 | // the resource holder |
209 | ObjectPtr<Object> sptr_; |
210 | // The name of the function. |
211 | std::string func_name_; |
212 | // Device function cache per device. |
213 | // mark as mutable, to enable lazy initialization |
214 | mutable std::array<CUfunction, kMaxNumGPUs> fcache_; |
215 | // launch parameters configuration |
216 | LaunchParamConfig launch_param_config_; |
217 | }; |
218 | |
219 | class CUDAPrepGlobalBarrier { |
220 | public: |
221 | CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr<Object> sptr) : m_(m), sptr_(sptr) { |
222 | std::fill(pcache_.begin(), pcache_.end(), 0); |
223 | } |
224 | |
225 | void operator()(const TVMArgs& args, TVMRetValue* rv) const { |
226 | int device_id; |
227 | CUDA_CALL(cudaGetDevice(&device_id)); |
228 | if (pcache_[device_id] == 0) { |
229 | pcache_[device_id] = |
230 | m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); |
231 | } |
232 | CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1)); |
233 | } |
234 | |
235 | private: |
236 | // internal module |
237 | CUDAModuleNode* m_; |
238 | // the resource holder |
239 | ObjectPtr<Object> sptr_; |
240 | // mark as mutable, to enable lazy initialization |
241 | mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_; |
242 | }; |
243 | |
244 | PackedFunc CUDAModuleNode::GetFunction(const std::string& name, |
245 | const ObjectPtr<Object>& sptr_to_self) { |
246 | ICHECK_EQ(sptr_to_self.get(), this); |
247 | ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main" ; |
248 | if (name == symbol::tvm_prepare_global_barrier) { |
249 | return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self)); |
250 | } |
251 | auto it = fmap_.find(name); |
252 | if (it == fmap_.end()) return PackedFunc(); |
253 | const FunctionInfo& info = it->second; |
254 | CUDAWrappedFunc f; |
255 | f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); |
256 | return PackFuncVoidAddr(f, info.arg_types); |
257 | } |
258 | |
259 | Module CUDAModuleCreate(std::string data, std::string fmt, |
260 | std::unordered_map<std::string, FunctionInfo> fmap, |
261 | std::string cuda_source) { |
262 | auto n = make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source); |
263 | return Module(n); |
264 | } |
265 | |
266 | // Load module from module. |
267 | Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) { |
268 | std::string data; |
269 | std::unordered_map<std::string, FunctionInfo> fmap; |
270 | std::string fmt = GetFileFormat(file_name, format); |
271 | std::string meta_file = GetMetaFilePath(file_name); |
272 | LoadBinaryFromFile(file_name, &data); |
273 | LoadMetaDataFromFile(meta_file, &fmap); |
274 | return CUDAModuleCreate(data, fmt, fmap, std::string()); |
275 | } |
276 | |
277 | Module CUDAModuleLoadBinary(void* strm) { |
278 | dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); |
279 | std::string data; |
280 | std::unordered_map<std::string, FunctionInfo> fmap; |
281 | std::string fmt; |
282 | stream->Read(&fmt); |
283 | stream->Read(&fmap); |
284 | stream->Read(&data); |
285 | return CUDAModuleCreate(data, fmt, fmap, std::string()); |
286 | } |
287 | |
// Register loaders so the generic Module::LoadFromFile / stream deserializer
// can dispatch to this backend by format name.
TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin" ).set_body_typed(CUDAModuleLoadFile);

TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx" ).set_body_typed(CUDAModuleLoadFile);

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda" ).set_body_typed(CUDAModuleLoadBinary);
293 | } // namespace runtime |
294 | } // namespace tvm |
295 | |