1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rpc_module.cc |
22 | * \brief RPC runtime module. |
23 | */ |
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#if defined(_M_X64) || defined(__x86_64__)
#include <immintrin.h>
#endif

#include "rpc_endpoint.h"
#include "rpc_session.h"
39 | |
40 | namespace tvm { |
41 | namespace runtime { |
42 | |
43 | // deleter of RPC remote array |
44 | static void RemoteNDArrayDeleter(Object* obj) { |
45 | auto* ptr = static_cast<NDArray::Container*>(obj); |
46 | RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data); |
47 | if (ptr->manager_ctx != nullptr) { |
48 | space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle); |
49 | } |
50 | delete space; |
51 | delete ptr; |
52 | } |
53 | |
54 | /*! |
55 | * \brief Build a local NDArray with remote backing storage. |
56 | * \param sess the RPCSession which owns the given handle. |
57 | * \param handle A pointer valid on the remote end which should form the `data` field of the |
58 | * underlying DLTensor. |
59 | * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly |
60 | * created array. Needed because it's difficult to pass a shape vector as a PackedFunc arg. |
61 | * \param dev Remote device used with this tensor. Must have non-zero RPCSessMask. |
62 | * \param remote_ndarray_handle The handle returned by RPC server to identify the NDArray. |
63 | */ |
64 | NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* handle, |
65 | DLTensor* template_tensor, Device dev, |
66 | void* remote_ndarray_handle) { |
67 | ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(dev)) |
68 | << "The Device given does not belong to the given session" ; |
69 | RemoteSpace* space = new RemoteSpace(); |
70 | space->sess = sess; |
71 | space->data = handle; |
72 | std::vector<int64_t> shape_vec{template_tensor->shape, |
73 | template_tensor->shape + template_tensor->ndim}; |
74 | NDArray::Container* data = new NDArray::Container(static_cast<void*>(space), std::move(shape_vec), |
75 | template_tensor->dtype, dev); |
76 | data->manager_ctx = remote_ndarray_handle; |
77 | data->SetDeleter(RemoteNDArrayDeleter); |
78 | return NDArray(GetObjectPtr<Object>(data)); |
79 | } |
80 | |
81 | /*! |
82 | * \brief A wrapped remote function as a PackedFunc. |
83 | */ |
84 | class RPCWrappedFunc : public Object { |
85 | public: |
86 | RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess) : handle_(handle), sess_(sess) {} |
87 | |
88 | void operator()(TVMArgs args, TVMRetValue* rv) const { |
89 | std::vector<TVMValue> values(args.values, args.values + args.size()); |
90 | std::vector<int> type_codes(args.type_codes, args.type_codes + args.size()); |
91 | std::vector<std::unique_ptr<DLTensor>> temp_dltensors; |
92 | |
93 | // scan and check whether we need rewrite these arguments |
94 | // to their remote variant. |
95 | for (int i = 0; i < args.size(); ++i) { |
96 | if (args[i].IsObjectRef<String>()) { |
97 | String str = args[i]; |
98 | type_codes[i] = kTVMStr; |
99 | values[i].v_str = str.c_str(); |
100 | continue; |
101 | } |
102 | int tcode = type_codes[i]; |
103 | switch (tcode) { |
104 | case kTVMDLTensorHandle: |
105 | case kTVMNDArrayHandle: { |
106 | // Pass NDArray as DLTensor, NDArray and DLTensor |
107 | // are compatible to each other, just need to change the index. |
108 | type_codes[i] = kTVMDLTensorHandle; |
109 | // translate to a remote view of DLTensor |
110 | auto dptr = std::make_unique<DLTensor>(*static_cast<DLTensor*>(values[i].v_handle)); |
111 | dptr->device = RemoveSessMask(dptr->device); |
112 | dptr->data = static_cast<RemoteSpace*>(dptr->data)->data; |
113 | values[i].v_handle = dptr.get(); |
114 | temp_dltensors.emplace_back(std::move(dptr)); |
115 | break; |
116 | } |
117 | case kDLDevice: { |
118 | values[i].v_device = RemoveSessMask(values[i].v_device); |
119 | break; |
120 | } |
121 | case kTVMPackedFuncHandle: |
122 | case kTVMModuleHandle: { |
123 | values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode)); |
124 | break; |
125 | } |
126 | } |
127 | } |
128 | auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); }; |
129 | sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return); |
130 | } |
131 | |
132 | ~RPCWrappedFunc() { |
133 | try { |
134 | sess_->FreeHandle(handle_, kTVMPackedFuncHandle); |
135 | } catch (const Error& e) { |
136 | // fault tolerance to remote close |
137 | } |
138 | } |
139 | |
140 | private: |
141 | // remote function handle |
142 | void* handle_{nullptr}; |
143 | // pointer to the session. |
144 | std::shared_ptr<RPCSession> sess_; |
145 | |
146 | // unwrap a remote value to the underlying handle. |
147 | void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; |
148 | // wrap a remote return via Set |
149 | void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; |
150 | |
151 | // remove a remote session mask |
152 | Device RemoveSessMask(Device dev) const { |
153 | ICHECK(IsRPCSessionDevice(dev)) << "Can not pass in local device" ; |
154 | ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) |
155 | << "Can not pass in device with a different remote session" ; |
156 | return RemoveRPCSessionMask(dev); |
157 | } |
158 | }; |
159 | |
160 | // RPC that represents a remote module session. |
161 | class RPCModuleNode final : public ModuleNode { |
162 | public: |
163 | RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess) |
164 | : module_handle_(module_handle), sess_(sess) {} |
165 | |
166 | ~RPCModuleNode() { |
167 | if (module_handle_ != nullptr) { |
168 | try { |
169 | sess_->FreeHandle(module_handle_, kTVMModuleHandle); |
170 | } catch (const Error& e) { |
171 | // fault tolerance to remote close |
172 | } |
173 | module_handle_ = nullptr; |
174 | } |
175 | } |
176 | |
177 | const char* type_key() const final { return "rpc" ; } |
178 | |
179 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
180 | if (name == "CloseRPCConnection" ) { |
181 | return PackedFunc([this](TVMArgs, TVMRetValue*) { sess_->Shutdown(); }); |
182 | } |
183 | |
184 | if (module_handle_ == nullptr) { |
185 | return WrapRemoteFunc(sess_->GetFunction(name)); |
186 | } else { |
187 | InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction" ); |
188 | return remote_mod_get_function_(GetRef<Module>(this), name, true); |
189 | } |
190 | } |
191 | |
192 | std::string GetSource(const std::string& format) final { |
193 | LOG(FATAL) << "GetSource for rpc Module is not supported" ; |
194 | } |
195 | |
196 | PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, |
197 | int min_repeat_ms, int limit_zero_time_iterations, |
198 | int cooldown_interval_ms, int repeats_to_cooldown, |
199 | const std::string& f_preproc_name) { |
200 | InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator" ); |
201 | // Remove session mask because we pass dev by parts. |
202 | ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) |
203 | << "ValueError: Need to pass the matched remote device to RPCModule.GetTimeEvaluator" ; |
204 | dev = RemoveRPCSessionMask(dev); |
205 | |
206 | if (module_handle_ != nullptr) { |
207 | return remote_get_time_evaluator_(GetRef<Module>(this), name, |
208 | static_cast<int>(dev.device_type), dev.device_id, number, |
209 | repeat, min_repeat_ms, limit_zero_time_iterations, |
210 | cooldown_interval_ms, repeats_to_cooldown, f_preproc_name); |
211 | } else { |
212 | return remote_get_time_evaluator_(Optional<Module>(nullptr), name, |
213 | static_cast<int>(dev.device_type), dev.device_id, number, |
214 | repeat, min_repeat_ms, limit_zero_time_iterations, |
215 | cooldown_interval_ms, repeats_to_cooldown, f_preproc_name); |
216 | } |
217 | } |
218 | |
219 | Module LoadModule(std::string name) { |
220 | InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module" ); |
221 | return remote_load_module_(name); |
222 | } |
223 | |
224 | void ImportModule(Module other) { |
225 | InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule" ); |
226 | remote_import_module_(GetRef<Module>(this), other); |
227 | } |
228 | |
229 | const std::shared_ptr<RPCSession>& sess() { return sess_; } |
230 | |
231 | void* module_handle() const { return module_handle_; } |
232 | |
233 | private: |
234 | template <typename FType> |
235 | void InitRemoteFunc(FType* func, const std::string& name) { |
236 | if (*func != nullptr) return; |
237 | RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); |
238 | ICHECK(handle != nullptr) << "Cannot found remote function " << name; |
239 | *func = WrapRemoteFunc(handle); |
240 | } |
241 | |
242 | PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { |
243 | if (handle == nullptr) return PackedFunc(); |
244 | auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_); |
245 | return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); |
246 | } |
247 | |
248 | // The module handle |
249 | void* module_handle_{nullptr}; |
250 | // The local channel |
251 | std::shared_ptr<RPCSession> sess_; |
252 | // remote function to get time evaluator |
253 | TypedPackedFunc<PackedFunc(Optional<Module>, std::string, int, int, int, int, int, int, int, int, |
254 | std::string)> |
255 | remote_get_time_evaluator_; |
256 | // remote function getter for modules. |
257 | TypedPackedFunc<PackedFunc(Module, std::string, bool)> remote_mod_get_function_; |
258 | // remote function getter for load module |
259 | TypedPackedFunc<Module(std::string)> remote_load_module_; |
260 | // remote function getter for load module |
261 | TypedPackedFunc<void(Module, Module)> remote_import_module_; |
262 | }; |
263 | |
264 | void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { |
265 | if (arg.type_code() == kTVMModuleHandle) { |
266 | Module mod = arg; |
267 | std::string tkey = mod->type_key(); |
268 | ICHECK_EQ(tkey, "rpc" ) << "ValueError: Cannot pass a non-RPC module to remote" ; |
269 | auto* rmod = static_cast<RPCModuleNode*>(mod.operator->()); |
270 | ICHECK(rmod->sess() == sess_) |
271 | << "ValueError: Cannot pass in module into a different remote session" ; |
272 | return rmod->module_handle(); |
273 | } else { |
274 | LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code()) |
275 | << " as an argument to the remote" ; |
276 | return nullptr; |
277 | } |
278 | } |
279 | |
280 | void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const { |
281 | int tcode = args[0]; |
282 | |
283 | if (tcode == kTVMNullptr) return; |
284 | if (tcode == kTVMPackedFuncHandle) { |
285 | ICHECK_EQ(args.size(), 2); |
286 | void* handle = args[1]; |
287 | auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_); |
288 | *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); |
289 | } else if (tcode == kTVMModuleHandle) { |
290 | ICHECK_EQ(args.size(), 2); |
291 | void* handle = args[1]; |
292 | auto n = make_object<RPCModuleNode>(handle, sess_); |
293 | *rv = Module(n); |
294 | } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { |
295 | ICHECK_EQ(args.size(), 3); |
296 | DLTensor* tensor = args[1]; |
297 | void* nd_handle = args[2]; |
298 | *rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor, |
299 | AddRPCSessionMask(tensor->device, sess_->table_index()), |
300 | nd_handle); |
301 | } else { |
302 | ICHECK_EQ(args.size(), 2); |
303 | *rv = args[1]; |
304 | } |
305 | } |
306 | |
307 | Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess) { |
308 | auto n = make_object<RPCModuleNode>(nullptr, sess); |
309 | RPCSession::InsertToSessionTable(sess); |
310 | return Module(n); |
311 | } |
312 | |
313 | std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod) { |
314 | std::string tkey = mod->type_key(); |
315 | ICHECK_EQ(tkey, "rpc" ) << "ValueError: Cannot pass a non-RPC module to remote" ; |
316 | auto* rmod = static_cast<RPCModuleNode*>(mod.operator->()); |
317 | return rmod->sess(); |
318 | } |
319 | |
/*!
 * \brief Flush the CPU data-cache lines that cover a memory range.
 * \param addr The address of data we want to flush. A null address is a no-op.
 * \param len The length of data in bytes. Zero is a no-op.
 *
 * Without flushing, repeated tuning iterations keep their inputs fully resident in cache. In
 * end-to-end runs, arrays assumed to be cached (e.g. weights) are usually evicted, so tuning
 * measurements would be optimistic compared with e2e performance.
 *
 * Only x86-64 (clflush) and AArch64 (dc civac) are supported; elsewhere this does nothing.
 */
inline void CPUCacheFlushImpl(const char* addr, size_t len) {
#if (defined(_M_X64) || defined(__x86_64__) || defined(__aarch64__))
  if (addr == nullptr || len == 0) {
    return;
  }

#if defined(__aarch64__)
  // CTR_EL0.DminLine (bits [19:16]) is log2 of the smallest data cache line, in 4-byte words.
  size_t ctr_el0 = 0;
  asm volatile("mrs %0, ctr_el0" : "=r"(ctr_el0));
  const size_t cache_line = static_cast<size_t>(4) << ((ctr_el0 >> 16) & 15);
#else
  const size_t cache_line = 64;
#endif

  // Start at the cache line containing `addr` and stop past the last byte of the range.
  const uintptr_t begin =
      reinterpret_cast<uintptr_t>(addr) & ~(static_cast<uintptr_t>(cache_line) - 1);
  const uintptr_t end = reinterpret_cast<uintptr_t>(addr) + len;
  for (uintptr_t uptr = begin; uptr < end; uptr += cache_line) {
#if defined(__aarch64__)
    asm volatile("dc civac, %0\n\t" : : "r"(reinterpret_cast<const void*>(uptr)) : "memory");
#else
    _mm_clflush(reinterpret_cast<const void*>(uptr));
#endif
  }

#if defined(__aarch64__)
  asm volatile("dmb ishst" : : : "memory");
#endif

#else
  // Unsupported architecture: nothing to do.
  static_cast<void>(addr);
  static_cast<void>(len);
#endif
}
362 | |
363 | inline void CPUCacheFlush(int begin_index, const TVMArgs& args) { |
364 | for (int i = begin_index; i < args.size(); i++) { |
365 | CPUCacheFlushImpl(static_cast<char*>((args[i].operator DLTensor*()->data)), |
366 | GetDataSize(*(args[i].operator DLTensor*()))); |
367 | } |
368 | } |
369 | |
370 | TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator" ) |
371 | .set_body_typed([](Optional<Module> opt_mod, std::string name, int device_type, int device_id, |
372 | int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, |
373 | int cooldown_interval_ms, int repeats_to_cooldown, |
374 | std::string f_preproc_name) { |
375 | Device dev; |
376 | dev.device_type = static_cast<DLDeviceType>(device_type); |
377 | dev.device_id = device_id; |
378 | if (opt_mod.defined()) { |
379 | Module m = opt_mod.value(); |
380 | std::string tkey = m->type_key(); |
381 | if (tkey == "rpc" ) { |
382 | return static_cast<RPCModuleNode*>(m.operator->()) |
383 | ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms, |
384 | limit_zero_time_iterations, cooldown_interval_ms, |
385 | repeats_to_cooldown, f_preproc_name); |
386 | } else { |
387 | PackedFunc f_preproc; |
388 | if (!f_preproc_name.empty()) { |
389 | auto* pf_preproc = runtime::Registry::Get(f_preproc_name); |
390 | ICHECK(pf_preproc != nullptr) |
391 | << "Cannot find " << f_preproc_name << " in the global function" ; |
392 | f_preproc = *pf_preproc; |
393 | } |
394 | PackedFunc pf = m.GetFunction(name, true); |
395 | CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry" ; |
396 | return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, |
397 | limit_zero_time_iterations, cooldown_interval_ms, |
398 | repeats_to_cooldown, f_preproc); |
399 | } |
400 | } else { |
401 | auto* pf = runtime::Registry::Get(name); |
402 | ICHECK(pf != nullptr) << "Cannot find " << name << " in the global function" ; |
403 | PackedFunc f_preproc; |
404 | if (!f_preproc_name.empty()) { |
405 | auto* pf_preproc = runtime::Registry::Get(f_preproc_name); |
406 | ICHECK(pf_preproc != nullptr) |
407 | << "Cannot find " << f_preproc_name << " in the global function" ; |
408 | f_preproc = *pf_preproc; |
409 | } |
410 | return profiling::WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, |
411 | limit_zero_time_iterations, cooldown_interval_ms, |
412 | repeats_to_cooldown, f_preproc); |
413 | } |
414 | }); |
415 | |
416 | TVM_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
417 | CPUCacheFlush(1, args); |
418 | }); |
419 | |
420 | // server function registration. |
421 | TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule" ).set_body_typed([](Module parent, Module child) { |
422 | parent->Import(child); |
423 | }); |
424 | |
425 | TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction" ) |
426 | .set_body_typed([](Module parent, std::string name, bool query_imports) { |
427 | return parent->GetFunction(name, query_imports); |
428 | }); |
429 | |
430 | // functions to access an RPC module. |
431 | TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule" ).set_body_typed([](Module sess, std::string name) { |
432 | std::string tkey = sess->type_key(); |
433 | ICHECK_EQ(tkey, "rpc" ); |
434 | return static_cast<RPCModuleNode*>(sess.operator->())->LoadModule(name); |
435 | }); |
436 | |
437 | TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule" ).set_body_typed([](Module parent, Module child) { |
438 | std::string tkey = parent->type_key(); |
439 | ICHECK_EQ(tkey, "rpc" ); |
440 | static_cast<RPCModuleNode*>(parent.operator->())->ImportModule(child); |
441 | }); |
442 | |
443 | TVM_REGISTER_GLOBAL("rpc.SessTableIndex" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
444 | Module m = args[0]; |
445 | std::string tkey = m->type_key(); |
446 | ICHECK_EQ(tkey, "rpc" ); |
447 | *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index(); |
448 | }); |
449 | |
450 | TVM_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle" ) |
451 | .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, |
452 | void* ndarray_handle) -> NDArray { |
453 | return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, |
454 | dev, ndarray_handle); |
455 | }); |
456 | |
457 | } // namespace runtime |
458 | } // namespace tvm |
459 | |