1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file c_runtime_api.cc
22 * \brief Device specific implementations
23 */
24#include <dmlc/thread_local.h>
25#include <tvm/runtime/c_backend_api.h>
26#include <tvm/runtime/c_runtime_api.h>
27#include <tvm/runtime/device_api.h>
28#include <tvm/runtime/module.h>
29#include <tvm/runtime/packed_func.h>
30#include <tvm/runtime/registry.h>
31
32#include <algorithm>
33#include <array>
34#include <cctype>
35#include <cstdlib>
36#include <sstream>
37#include <string>
38
39#include "object_internal.h"
40#include "runtime_base.h"
41
42namespace tvm {
43namespace runtime {
44
45std::string GetCustomTypeName(uint8_t type_code) {
46 auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name");
47 ICHECK(f) << "Function runtime._datatype_get_type_name not found";
48 return (*f)(type_code).operator std::string();
49}
50
51uint8_t GetCustomTypeCode(const std::string& type_name) {
52 auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code");
53 ICHECK(f) << "Function runtime._datatype_get_type_code not found";
54 return (*f)(type_name).operator int();
55}
56
57bool GetCustomTypeRegistered(uint8_t type_code) {
58 auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered");
59 ICHECK(f) << "Function runtime._datatype_get_type_registered not found";
60 return (*f)(type_code).operator bool();
61}
62
63uint8_t ParseCustomDatatype(const std::string& s, const char** scan) {
64 ICHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string";
65
66 auto tmp = s.c_str();
67
68 ICHECK(s.c_str() == tmp);
69 *scan = s.c_str() + 6;
70 ICHECK(s.c_str() == tmp);
71 if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s;
72 ICHECK(s.c_str() == tmp);
73 *scan += 1;
74 ICHECK(s.c_str() == tmp);
75 size_t custom_name_len = 0;
76 ICHECK(s.c_str() == tmp);
77 while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']')
78 ++custom_name_len;
79 ICHECK(s.c_str() == tmp);
80 if (*(*scan + custom_name_len) != ']')
81 LOG(FATAL) << "expected closing brace after 'custom' type in" << s;
82 ICHECK(s.c_str() == tmp);
83 *scan += custom_name_len + 1;
84 ICHECK(s.c_str() == tmp);
85
86 auto type_name = s.substr(7, custom_name_len);
87 ICHECK(s.c_str() == tmp);
88 return GetCustomTypeCode(type_name);
89}
90
91class DeviceAPIManager {
92 public:
93 static const int kMaxDeviceAPI = 32;
94 // Get API
95 static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); }
96 static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
97 return Global()->GetAPI(dev_type, allow_missing);
98 }
99
100 private:
101 std::array<DeviceAPI*, kMaxDeviceAPI> api_;
102 DeviceAPI* rpc_api_{nullptr};
103 std::mutex mutex_;
104 // constructor
105 DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
106 // Global static variable.
107 static DeviceAPIManager* Global() {
108 static DeviceAPIManager* inst = new DeviceAPIManager();
109 return inst;
110 }
111 // Get or initialize API.
112 DeviceAPI* GetAPI(int type, bool allow_missing) {
113 if (type < kRPCSessMask) {
114 if (api_[type] != nullptr) return api_[type];
115 std::lock_guard<std::mutex> lock(mutex_);
116 if (api_[type] != nullptr) return api_[type];
117 api_[type] = GetAPI(DeviceName(type), allow_missing);
118 return api_[type];
119 } else {
120 if (rpc_api_ != nullptr) return rpc_api_;
121 std::lock_guard<std::mutex> lock(mutex_);
122 if (rpc_api_ != nullptr) return rpc_api_;
123 rpc_api_ = GetAPI("rpc", allow_missing);
124 return rpc_api_;
125 }
126 }
127 DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
128 std::string factory = "device_api." + name;
129 auto* f = Registry::Get(factory);
130 if (f == nullptr) {
131 ICHECK(allow_missing) << "Device API " << name << " is not enabled.";
132 return nullptr;
133 }
134 void* ptr = (*f)();
135 return static_cast<DeviceAPI*>(ptr);
136 }
137};
138
139DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) {
140 return DeviceAPIManager::Get(static_cast<int>(dev.device_type), allow_missing);
141}
142
143void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
144 return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint);
145}
146
147static size_t GetDataAlignment(const DLDataType dtype) {
148 size_t align = (dtype.bits / 8) * dtype.lanes;
149 if (align < kAllocAlignment) return kAllocAlignment;
150 return align;
151}
152
153void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
154 Optional<String> mem_scope) {
155 if (!mem_scope.defined() || mem_scope.value() == "global") {
156 // by default, we can always redirect to the flat memory allocations
157 DLTensor temp;
158 temp.data = nullptr;
159 temp.device = dev;
160 temp.ndim = ndim;
161 temp.dtype = dtype;
162 temp.shape = const_cast<int64_t*>(shape);
163 temp.strides = nullptr;
164 temp.byte_offset = 0;
165 size_t size = GetDataSize(temp);
166 size_t alignment = GetDataAlignment(temp.dtype);
167 return AllocDataSpace(dev, size, alignment, dtype);
168 }
169 LOG(FATAL) << "Device does not support allocate data space with "
170 << "specified memory scope: " << mem_scope.value();
171 return nullptr;
172}
173
174void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
175 // by default, we can always redirect to the flat memory copy operation.
176 size_t nbytes = GetDataSize(*from);
177 ICHECK_EQ(nbytes, GetDataSize(*to));
178
179 ICHECK(IsContiguous(*from) && IsContiguous(*to))
180 << "CopyDataFromTo only support contiguous array for now";
181 CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device,
182 to->device, from->dtype, stream);
183}
184
185void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
186 size_t num_bytes, Device dev_from, Device dev_to,
187 DLDataType type_hint, TVMStreamHandle stream) {
188 LOG(FATAL) << "Device does not support CopyDataFromTo.";
189}
190
191void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); }
192
193TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }
194
195void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
196
197void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
198}
199
200//--------------------------------------------------------
201// Error handling mechanism
202// -------------------------------------------------------
203// Standard error message format, {} means optional
204//--------------------------------------------------------
205// {error_type:} {message0}
206// {message1}
207// {message2}
// {Stack trace:}  // stack traces follow this line
209// {trace 0} // two spaces in the beginning.
210// {trace 1}
211// {trace 2}
212//--------------------------------------------------------
213/*!
214 * \brief Normalize error message
215 *
 * Parse the header generated by LOG(FATAL) and ICHECK
217 * and reformat the message into the standard format.
218 *
219 * This function will also merge all the stack traces into
220 * one trace and trim them.
221 *
222 * \param err_msg The error message.
223 * \return normalized message.
224 */
std::string NormalizeError(std::string err_msg) {
  // ------------------------------------------------------------------------
  // log with header, {} indicates optional
  //-------------------------------------------------------------------------
  // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
  // {message1}
  // Stack trace:
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  // Normalized version
  //-------------------------------------------------------------------------
  // error_type: check_msg message0
  // {message1}
  // Stack trace:
  //   File file_name, line lineno
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  int line_number = 0;
  std::istringstream is(err_msg);
  std::string line, file_name, error_type, check_msg;

  // Parse the log header and fill line / file_name / line_number / check_msg.
  // Returns true if there is no "[timestamp]" header at all, or if the header
  // parses correctly. Returns false if a header is present but malformed.
  auto parse_log_header = [&]() {
    // No header: the first line is the message itself.
    if (is.peek() != '[') {
      getline(is, line);
      return true;
    }
    // skip timestamp
    if (!(is >> line)) return false;
    // get filename
    while (is.peek() == ' ') is.get();
#ifdef _MSC_VER  // handle volume separator ":" in Windows path
    std::string drive;
    if (!getline(is, drive, ':')) return false;
    if (!getline(is, file_name, ':')) return false;
    file_name = drive + ":" + file_name;
#else
    if (!getline(is, file_name, ':')) return false;
#endif
    // get line number
    if (!(is >> line_number)) return false;
    // get rest of the message.
    while (is.peek() == ' ' || is.peek() == ':') is.get();
    if (!getline(is, line)) return false;
    // detect check message, rewrite to remove the extra ':'
    if (line.compare(0, 13, "Check failed:") == 0) {
      std::string ending = ": ";
      size_t end_pos = line.find(ending, 13);
      if (end_pos == std::string::npos) return false;
      check_msg = line.substr(0, end_pos + ending.size());
      line = line.substr(end_pos + ending.size());
    }
    return true;
  };
  // if not in correct format, do not do any rewrite.
  if (!parse_log_header()) return err_msg;
  // Parse error type: a leading [A-Za-z0-9_.]+ token immediately followed by ':'.
  {
    size_t start_pos = 0, end_pos;
    for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {
    }
    for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
      // <cctype> classifiers require a value representable as unsigned char;
      // passing a negative plain char (any non-ASCII byte) is undefined behavior.
      const unsigned char ch = static_cast<unsigned char>(line[end_pos]);
      if (ch == ':') {
        error_type = line.substr(start_pos, end_pos - start_pos);
        break;
      }
      // [A-Z0-9a-z_.]
      if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break;
    }
    if (error_type.length() != 0) {
      // if we successfully detected error_type: trim the following space.
      for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' ';
           ++start_pos) {
      }
      line = line.substr(start_pos);
    } else {
      // did not detect error_type, use default value.
      line = line.substr(start_pos);
      error_type = "TVMError";
    }
  }
  // Separate out stack trace.
  std::ostringstream os;
  os << error_type << ": " << check_msg << line << '\n';

  // Trace lines are the ones indented by two spaces; "Stack trace" headers
  // (re)enter trace mode, any other line is kept as part of the message.
  bool trace_mode = true;
  std::vector<std::string> stack_trace;
  while (getline(is, line)) {
    if (trace_mode) {
      if (line.compare(0, 2, "  ") == 0) {
        stack_trace.push_back(line);
      } else {
        trace_mode = false;
        // remove EOL trailing stacktrace.
        if (line.length() == 0) continue;
      }
    }
    if (!trace_mode) {
      if (line.compare(0, 11, "Stack trace") == 0) {
        trace_mode = true;
      } else {
        os << line << '\n';
      }
    }
  }
  if (stack_trace.size() != 0 || file_name.length() != 0) {
    os << "Stack trace:\n";
    if (file_name.length() != 0) {
      os << "  File \"" << file_name << "\", line " << line_number << "\n";
    }
    // Print out stack traces, optionally trim the c++ traces
    // about the frontends (as they will be provided by the frontends).
    bool ffi_boundary = false;
    for (const auto& frame : stack_trace) {
      // Heuristic to detect python ffi.
      if (frame.find("libffi.so") != std::string::npos ||
          frame.find("core.cpython") != std::string::npos) {
        ffi_boundary = true;
      }
      // If the backtrace is not a c++ backtrace with the prefix "  [bt]",
      // then we can stop trimming.
      if (ffi_boundary && frame.compare(0, 6, "  [bt]") != 0) {
        ffi_boundary = false;
      }
      if (!ffi_boundary) {
        os << frame << '\n';
      }
      // The line after TVMFuncCall could be in FFI.
      if (frame.find("(TVMFuncCall") != std::string::npos) {
        ffi_boundary = true;
      }
    }
  }
  return os.str();
}
365
366} // namespace runtime
367} // namespace tvm
368
369using namespace tvm::runtime;
370
371struct TVMRuntimeEntry {
372 std::string ret_str;
373 std::string last_error;
374 TVMByteArray ret_bytes;
375};
376
377typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
378
379const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); }
380
381int TVMAPIHandleException(const std::exception& e) {
382 TVMAPISetLastError(NormalizeError(e.what()).c_str());
383 return -1;
384}
385
386void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; }
387
388int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) {
389 API_BEGIN();
390 TVMRetValue ret;
391 ret = Module::LoadFromFile(file_name, format);
392 TVMValue val;
393 int type_code;
394 ret.MoveToCHost(&val, &type_code);
395 *out = val.v_handle;
396 API_END();
397}
398
399int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) {
400 API_BEGIN();
401 ObjectInternal::GetModuleNode(mod)->Import(GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
402 API_END();
403}
404
405int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
406 TVMFunctionHandle* func) {
407 API_BEGIN();
408 PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0);
409 if (pf != nullptr) {
410 tvm::runtime::TVMRetValue ret;
411 ret = pf;
412 TVMValue val;
413 int type_code;
414 ret.MoveToCHost(&val, &type_code);
415 *func = val.v_handle;
416 } else {
417 *func = nullptr;
418 }
419 API_END();
420}
421
422int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); }
423
424int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) {
425 API_BEGIN();
426 *func = (TVMFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name))->get();
427 API_END();
428}
429
430void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
431 int dtype_bits_hint) {
432 DLDevice dev;
433 dev.device_type = static_cast<DLDeviceType>(device_type);
434 dev.device_id = device_id;
435
436 DLDataType type_hint;
437 type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
438 type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
439 type_hint.lanes = 1;
440
441 return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast<size_t>(size), type_hint);
442}
443
444int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
445 DLDevice dev;
446 dev.device_type = static_cast<DLDeviceType>(device_type);
447 dev.device_id = device_id;
448 DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr);
449 return 0;
450}
451
// Run f(cdata) only if *handle is still null, marking *handle non-null first.
// Returns f's result on the first run and 0 afterwards. nbytes is unused.
// This check is not synchronized.
int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}
459
460int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); }
461
462int TVMByteArrayFree(TVMByteArray* arr) {
463 if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) {
464 return 0; // Thread-local storage does not need explicit deleting.
465 }
466
467 delete arr;
468 return 0;
469}
470
471int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
472 TVMValue* ret_val, int* ret_type_code) {
473 API_BEGIN();
474
475 TVMRetValue rv;
476 (static_cast<const PackedFuncObj*>(func))
477 ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);
478 // handle return string.
479 if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) {
480 TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
481 if (rv.type_code() != kTVMDataType) {
482 e->ret_str = *rv.ptr<std::string>();
483 } else {
484 e->ret_str = rv.operator std::string();
485 }
486 if (rv.type_code() == kTVMBytes) {
487 e->ret_bytes.data = e->ret_str.c_str();
488 e->ret_bytes.size = e->ret_str.length();
489 *ret_type_code = kTVMBytes;
490 ret_val->v_handle = &(e->ret_bytes);
491 } else {
492 *ret_type_code = kTVMStr;
493 ret_val->v_str = e->ret_str.c_str();
494 }
495 } else {
496 rv.MoveToCHost(ret_val, ret_type_code);
497 }
498 API_END();
499}
500
501int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) {
502 API_BEGIN();
503 ICHECK_EQ(num_ret, 1);
504 TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
505 *rv = TVMArgValue(value[0], type_code[0]);
506 API_END();
507}
508
509int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin,
510 TVMFunctionHandle* out) {
511 API_BEGIN();
512 if (fin == nullptr) {
513 tvm::runtime::TVMRetValue ret;
514 ret = PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) {
515 int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
516 args.num_args, rv, resource_handle);
517 if (ret != 0) {
518 throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace());
519 }
520 });
521 TVMValue val;
522 int type_code;
523 ret.MoveToCHost(&val, &type_code);
524 *out = val.v_handle;
525 } else {
526 // wrap it in a shared_ptr, with fin as deleter.
527 // so fin will be called when the lambda went out of scope.
528 std::shared_ptr<void> rpack(resource_handle, fin);
529 tvm::runtime::TVMRetValue ret;
530 ret = PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) {
531 int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
532 args.num_args, rv, rpack.get());
533 if (ret != 0) {
534 throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace());
535 }
536 });
537 TVMValue val;
538 int type_code;
539 ret.MoveToCHost(&val, &type_code);
540 *out = val.v_handle;
541 }
542 API_END();
543}
544
545int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
546 API_BEGIN();
547 DLDevice dev;
548 dev.device_type = static_cast<DLDeviceType>(device_type);
549 dev.device_id = device_id;
550 *out = DeviceAPIManager::Get(dev)->CreateStream(dev);
551 API_END();
552}
553
554int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
555 API_BEGIN();
556 DLDevice dev;
557 dev.device_type = static_cast<DLDeviceType>(device_type);
558 dev.device_id = device_id;
559 DeviceAPIManager::Get(dev)->FreeStream(dev, stream);
560 API_END();
561}
562
563int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
564 API_BEGIN();
565 DLDevice dev;
566 dev.device_type = static_cast<DLDeviceType>(device_type);
567 dev.device_id = device_id;
568 DeviceAPIManager::Get(dev)->SetStream(dev, stream);
569 API_END();
570}
571
572int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
573 API_BEGIN();
574 DLDevice dev;
575 dev.device_type = static_cast<DLDeviceType>(device_type);
576 dev.device_id = device_id;
577 DeviceAPIManager::Get(dev)->StreamSync(dev, stream);
578 API_END();
579}
580
581int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src,
582 TVMStreamHandle dst) {
583 API_BEGIN();
584 DLDevice dev;
585 dev.device_type = static_cast<DLDeviceType>(device_type);
586 dev.device_id = device_id;
587 DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, src, dst);
588 API_END();
589}
590
591int TVMCbArgToReturn(TVMValue* value, int* code) {
592 API_BEGIN();
593 tvm::runtime::TVMRetValue rv;
594 rv = tvm::runtime::TVMMovableArgValue_(*value, *code);
595 rv.MoveToCHost(value, code);
596 API_END();
597}
598
599int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint,
600 void** out_data) {
601 API_BEGIN();
602 out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint);
603 API_END();
604}
605
606int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, DLDataType dtype,
607 const char* mem_scope, void** out_data) {
608 API_BEGIN();
609 Optional<String> scope;
610 if (mem_scope != nullptr) {
611 scope = String(std::string(mem_scope));
612 }
613 out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, ndim, shape, dtype, scope);
614 API_END();
615}
616
617int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) {
618 API_BEGIN();
619 DeviceAPIManager::Get(dev)->FreeDataSpace(dev, ptr);
620 API_END();
621}
622
623int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
624 API_BEGIN();
625 DLDevice dev_from = from->device;
626 DLDevice dev_to = to->device;
627 DLDevice dev = dev_from.device_type != kDLCPU ? dev_from : dev_to;
628 DeviceAPIManager::Get(dev)->CopyDataFromTo(from, to, stream);
629 API_END();
630}
631
632// set device api
633TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
634 .set_body([](TVMArgs args, TVMRetValue* ret) {
635 DLDevice dev;
636 dev.device_type = static_cast<DLDeviceType>(args[0].operator int());
637 dev.device_id = args[1];
638 DeviceAPIManager::Get(dev)->SetDevice(dev);
639 });
640
641// set device api
642TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) {
643 DLDevice dev;
644 dev.device_type = static_cast<DLDeviceType>(args[0].operator int());
645 dev.device_id = args[1];
646
647 DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
648 if (kind == kExist) {
649 DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true);
650 if (api != nullptr) {
651 api->GetAttr(dev, kind, ret);
652 } else {
653 *ret = 0;
654 }
655 } else {
656 DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret);
657 }
658});
659
660TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream);
661