1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rpc_module.cc
22 * \brief RPC runtime module.
23 */
24#include <tvm/runtime/container/string.h>
25#include <tvm/runtime/device_api.h>
26#include <tvm/runtime/profiling.h>
27#include <tvm/runtime/registry.h>
28
29#include <chrono>
30#include <cstring>
31#include <memory>
32#include <thread>
33#if defined(_M_X64) || defined(__x86_64__)
34#include <immintrin.h>
35#endif
36
37#include "rpc_endpoint.h"
38#include "rpc_session.h"
39
40namespace tvm {
41namespace runtime {
42
43// deleter of RPC remote array
44static void RemoteNDArrayDeleter(Object* obj) {
45 auto* ptr = static_cast<NDArray::Container*>(obj);
46 RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
47 if (ptr->manager_ctx != nullptr) {
48 space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
49 }
50 delete space;
51 delete ptr;
52}
53
54/*!
55 * \brief Build a local NDArray with remote backing storage.
56 * \param sess the RPCSession which owns the given handle.
57 * \param handle A pointer valid on the remote end which should form the `data` field of the
58 * underlying DLTensor.
59 * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly
60 * created array. Needed because it's difficult to pass a shape vector as a PackedFunc arg.
61 * \param dev Remote device used with this tensor. Must have non-zero RPCSessMask.
62 * \param remote_ndarray_handle The handle returned by RPC server to identify the NDArray.
63 */
64NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* handle,
65 DLTensor* template_tensor, Device dev,
66 void* remote_ndarray_handle) {
67 ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(dev))
68 << "The Device given does not belong to the given session";
69 RemoteSpace* space = new RemoteSpace();
70 space->sess = sess;
71 space->data = handle;
72 std::vector<int64_t> shape_vec{template_tensor->shape,
73 template_tensor->shape + template_tensor->ndim};
74 NDArray::Container* data = new NDArray::Container(static_cast<void*>(space), std::move(shape_vec),
75 template_tensor->dtype, dev);
76 data->manager_ctx = remote_ndarray_handle;
77 data->SetDeleter(RemoteNDArrayDeleter);
78 return NDArray(GetObjectPtr<Object>(data));
79}
80
81/*!
82 * \brief A wrapped remote function as a PackedFunc.
83 */
84class RPCWrappedFunc : public Object {
85 public:
86 RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess) : handle_(handle), sess_(sess) {}
87
88 void operator()(TVMArgs args, TVMRetValue* rv) const {
89 std::vector<TVMValue> values(args.values, args.values + args.size());
90 std::vector<int> type_codes(args.type_codes, args.type_codes + args.size());
91 std::vector<std::unique_ptr<DLTensor>> temp_dltensors;
92
93 // scan and check whether we need rewrite these arguments
94 // to their remote variant.
95 for (int i = 0; i < args.size(); ++i) {
96 if (args[i].IsObjectRef<String>()) {
97 String str = args[i];
98 type_codes[i] = kTVMStr;
99 values[i].v_str = str.c_str();
100 continue;
101 }
102 int tcode = type_codes[i];
103 switch (tcode) {
104 case kTVMDLTensorHandle:
105 case kTVMNDArrayHandle: {
106 // Pass NDArray as DLTensor, NDArray and DLTensor
107 // are compatible to each other, just need to change the index.
108 type_codes[i] = kTVMDLTensorHandle;
109 // translate to a remote view of DLTensor
110 auto dptr = std::make_unique<DLTensor>(*static_cast<DLTensor*>(values[i].v_handle));
111 dptr->device = RemoveSessMask(dptr->device);
112 dptr->data = static_cast<RemoteSpace*>(dptr->data)->data;
113 values[i].v_handle = dptr.get();
114 temp_dltensors.emplace_back(std::move(dptr));
115 break;
116 }
117 case kDLDevice: {
118 values[i].v_device = RemoveSessMask(values[i].v_device);
119 break;
120 }
121 case kTVMPackedFuncHandle:
122 case kTVMModuleHandle: {
123 values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode));
124 break;
125 }
126 }
127 }
128 auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); };
129 sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return);
130 }
131
132 ~RPCWrappedFunc() {
133 try {
134 sess_->FreeHandle(handle_, kTVMPackedFuncHandle);
135 } catch (const Error& e) {
136 // fault tolerance to remote close
137 }
138 }
139
140 private:
141 // remote function handle
142 void* handle_{nullptr};
143 // pointer to the session.
144 std::shared_ptr<RPCSession> sess_;
145
146 // unwrap a remote value to the underlying handle.
147 void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const;
148 // wrap a remote return via Set
149 void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const;
150
151 // remove a remote session mask
152 Device RemoveSessMask(Device dev) const {
153 ICHECK(IsRPCSessionDevice(dev)) << "Can not pass in local device";
154 ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index())
155 << "Can not pass in device with a different remote session";
156 return RemoveRPCSessionMask(dev);
157 }
158};
159
160// RPC that represents a remote module session.
161class RPCModuleNode final : public ModuleNode {
162 public:
163 RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
164 : module_handle_(module_handle), sess_(sess) {}
165
166 ~RPCModuleNode() {
167 if (module_handle_ != nullptr) {
168 try {
169 sess_->FreeHandle(module_handle_, kTVMModuleHandle);
170 } catch (const Error& e) {
171 // fault tolerance to remote close
172 }
173 module_handle_ = nullptr;
174 }
175 }
176
177 const char* type_key() const final { return "rpc"; }
178
179 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
180 if (name == "CloseRPCConnection") {
181 return PackedFunc([this](TVMArgs, TVMRetValue*) { sess_->Shutdown(); });
182 }
183
184 if (module_handle_ == nullptr) {
185 return WrapRemoteFunc(sess_->GetFunction(name));
186 } else {
187 InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction");
188 return remote_mod_get_function_(GetRef<Module>(this), name, true);
189 }
190 }
191
192 std::string GetSource(const std::string& format) final {
193 LOG(FATAL) << "GetSource for rpc Module is not supported";
194 }
195
196 PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat,
197 int min_repeat_ms, int limit_zero_time_iterations,
198 int cooldown_interval_ms, int repeats_to_cooldown,
199 const std::string& f_preproc_name) {
200 InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator");
201 // Remove session mask because we pass dev by parts.
202 ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index())
203 << "ValueError: Need to pass the matched remote device to RPCModule.GetTimeEvaluator";
204 dev = RemoveRPCSessionMask(dev);
205
206 if (module_handle_ != nullptr) {
207 return remote_get_time_evaluator_(GetRef<Module>(this), name,
208 static_cast<int>(dev.device_type), dev.device_id, number,
209 repeat, min_repeat_ms, limit_zero_time_iterations,
210 cooldown_interval_ms, repeats_to_cooldown, f_preproc_name);
211 } else {
212 return remote_get_time_evaluator_(Optional<Module>(nullptr), name,
213 static_cast<int>(dev.device_type), dev.device_id, number,
214 repeat, min_repeat_ms, limit_zero_time_iterations,
215 cooldown_interval_ms, repeats_to_cooldown, f_preproc_name);
216 }
217 }
218
219 Module LoadModule(std::string name) {
220 InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module");
221 return remote_load_module_(name);
222 }
223
224 void ImportModule(Module other) {
225 InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule");
226 remote_import_module_(GetRef<Module>(this), other);
227 }
228
229 const std::shared_ptr<RPCSession>& sess() { return sess_; }
230
231 void* module_handle() const { return module_handle_; }
232
233 private:
234 template <typename FType>
235 void InitRemoteFunc(FType* func, const std::string& name) {
236 if (*func != nullptr) return;
237 RPCSession::PackedFuncHandle handle = sess_->GetFunction(name);
238 ICHECK(handle != nullptr) << "Cannot found remote function " << name;
239 *func = WrapRemoteFunc(handle);
240 }
241
242 PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) {
243 if (handle == nullptr) return PackedFunc();
244 auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
245 return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
246 }
247
248 // The module handle
249 void* module_handle_{nullptr};
250 // The local channel
251 std::shared_ptr<RPCSession> sess_;
252 // remote function to get time evaluator
253 TypedPackedFunc<PackedFunc(Optional<Module>, std::string, int, int, int, int, int, int, int, int,
254 std::string)>
255 remote_get_time_evaluator_;
256 // remote function getter for modules.
257 TypedPackedFunc<PackedFunc(Module, std::string, bool)> remote_mod_get_function_;
258 // remote function getter for load module
259 TypedPackedFunc<Module(std::string)> remote_load_module_;
260 // remote function getter for load module
261 TypedPackedFunc<void(Module, Module)> remote_import_module_;
262};
263
264void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const {
265 if (arg.type_code() == kTVMModuleHandle) {
266 Module mod = arg;
267 std::string tkey = mod->type_key();
268 ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
269 auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
270 ICHECK(rmod->sess() == sess_)
271 << "ValueError: Cannot pass in module into a different remote session";
272 return rmod->module_handle();
273 } else {
274 LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code())
275 << " as an argument to the remote";
276 return nullptr;
277 }
278}
279
280void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const {
281 int tcode = args[0];
282
283 if (tcode == kTVMNullptr) return;
284 if (tcode == kTVMPackedFuncHandle) {
285 ICHECK_EQ(args.size(), 2);
286 void* handle = args[1];
287 auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
288 *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
289 } else if (tcode == kTVMModuleHandle) {
290 ICHECK_EQ(args.size(), 2);
291 void* handle = args[1];
292 auto n = make_object<RPCModuleNode>(handle, sess_);
293 *rv = Module(n);
294 } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
295 ICHECK_EQ(args.size(), 3);
296 DLTensor* tensor = args[1];
297 void* nd_handle = args[2];
298 *rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor,
299 AddRPCSessionMask(tensor->device, sess_->table_index()),
300 nd_handle);
301 } else {
302 ICHECK_EQ(args.size(), 2);
303 *rv = args[1];
304 }
305}
306
307Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess) {
308 auto n = make_object<RPCModuleNode>(nullptr, sess);
309 RPCSession::InsertToSessionTable(sess);
310 return Module(n);
311}
312
313std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod) {
314 std::string tkey = mod->type_key();
315 ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
316 auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
317 return rmod->sess();
318}
319
/*!
 * \brief Flush the data cache lines covering [addr, addr + len).
 *
 * Uses `dc civac` followed by `dmb ishst` on AArch64, and `clflush` on x86-64. On other
 * architectures this is a no-op.
 *
 * Rationale: tuning measures a kernel repeatedly without flushing, so data it reads
 * (e.g. weights) can stay resident in cache across iterations. Those arrays are evicted during
 * end-to-end runs, so unflushed measurements can report performance that is not reproduced
 * end to end.
 *
 * \param addr The address of the data to flush; nullptr is ignored.
 * \param len The length of the data in bytes; 0 is ignored. size_t so large tensors are not
 *  truncated.
 */
inline void CPUCacheFlushImpl(const char* addr, size_t len) {
#if (defined(_M_X64) || defined(__x86_64__) || defined(__aarch64__))
  if (addr == nullptr || len == 0) {
    return;
  }

#if defined(__aarch64__)
  // CTR_EL0.DminLine (bits [19:16]) is log2 of the smallest data cache line, in 4-byte words.
  size_t ctr_el0 = 0;
  asm volatile("mrs %0, ctr_el0" : "=r"(ctr_el0));
  const size_t cache_line = static_cast<size_t>(4) << ((ctr_el0 >> 16) & 15);
#else
  const size_t cache_line = 64;
#endif

  const uintptr_t start = reinterpret_cast<uintptr_t>(addr);
  const uintptr_t end = start + len;
  // Start at the cache line containing addr so a partially covered first line is flushed too.
  for (uintptr_t line = start & ~(cache_line - 1); line < end; line += cache_line) {
#if defined(__aarch64__)
    asm volatile("dc civac, %0\n\t" : : "r"(reinterpret_cast<const void*>(line)) : "memory");
#else
    _mm_clflush(reinterpret_cast<const void*>(line));
#endif
  }

#if defined(__aarch64__)
  asm volatile("dmb ishst" : : : "memory");
#endif

#else
  // No cache flush implementation for this architecture.
  (void)addr;
  (void)len;
#endif
}
362
363inline void CPUCacheFlush(int begin_index, const TVMArgs& args) {
364 for (int i = begin_index; i < args.size(); i++) {
365 CPUCacheFlushImpl(static_cast<char*>((args[i].operator DLTensor*()->data)),
366 GetDataSize(*(args[i].operator DLTensor*())));
367 }
368}
369
370TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
371 .set_body_typed([](Optional<Module> opt_mod, std::string name, int device_type, int device_id,
372 int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations,
373 int cooldown_interval_ms, int repeats_to_cooldown,
374 std::string f_preproc_name) {
375 Device dev;
376 dev.device_type = static_cast<DLDeviceType>(device_type);
377 dev.device_id = device_id;
378 if (opt_mod.defined()) {
379 Module m = opt_mod.value();
380 std::string tkey = m->type_key();
381 if (tkey == "rpc") {
382 return static_cast<RPCModuleNode*>(m.operator->())
383 ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms,
384 limit_zero_time_iterations, cooldown_interval_ms,
385 repeats_to_cooldown, f_preproc_name);
386 } else {
387 PackedFunc f_preproc;
388 if (!f_preproc_name.empty()) {
389 auto* pf_preproc = runtime::Registry::Get(f_preproc_name);
390 ICHECK(pf_preproc != nullptr)
391 << "Cannot find " << f_preproc_name << " in the global function";
392 f_preproc = *pf_preproc;
393 }
394 PackedFunc pf = m.GetFunction(name, true);
395 CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry";
396 return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms,
397 limit_zero_time_iterations, cooldown_interval_ms,
398 repeats_to_cooldown, f_preproc);
399 }
400 } else {
401 auto* pf = runtime::Registry::Get(name);
402 ICHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
403 PackedFunc f_preproc;
404 if (!f_preproc_name.empty()) {
405 auto* pf_preproc = runtime::Registry::Get(f_preproc_name);
406 ICHECK(pf_preproc != nullptr)
407 << "Cannot find " << f_preproc_name << " in the global function";
408 f_preproc = *pf_preproc;
409 }
410 return profiling::WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms,
411 limit_zero_time_iterations, cooldown_interval_ms,
412 repeats_to_cooldown, f_preproc);
413 }
414 });
415
416TVM_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg").set_body([](TVMArgs args, TVMRetValue* rv) {
417 CPUCacheFlush(1, args);
418});
419
420// server function registration.
421TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) {
422 parent->Import(child);
423});
424
425TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction")
426 .set_body_typed([](Module parent, std::string name, bool query_imports) {
427 return parent->GetFunction(name, query_imports);
428 });
429
430// functions to access an RPC module.
431TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) {
432 std::string tkey = sess->type_key();
433 ICHECK_EQ(tkey, "rpc");
434 return static_cast<RPCModuleNode*>(sess.operator->())->LoadModule(name);
435});
436
437TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) {
438 std::string tkey = parent->type_key();
439 ICHECK_EQ(tkey, "rpc");
440 static_cast<RPCModuleNode*>(parent.operator->())->ImportModule(child);
441});
442
443TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) {
444 Module m = args[0];
445 std::string tkey = m->type_key();
446 ICHECK_EQ(tkey, "rpc");
447 *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
448});
449
450TVM_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle")
451 .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, Device dev,
452 void* ndarray_handle) -> NDArray {
453 return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor,
454 dev, ndarray_handle);
455 });
456
457} // namespace runtime
458} // namespace tvm
459