1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file c_runtime_api.cc |
22 | * \brief Device specific implementations |
23 | */ |
24 | #include <dmlc/thread_local.h> |
25 | #include <tvm/runtime/c_backend_api.h> |
26 | #include <tvm/runtime/c_runtime_api.h> |
27 | #include <tvm/runtime/device_api.h> |
28 | #include <tvm/runtime/module.h> |
29 | #include <tvm/runtime/packed_func.h> |
30 | #include <tvm/runtime/registry.h> |
31 | |
32 | #include <algorithm> |
33 | #include <array> |
34 | #include <cctype> |
35 | #include <cstdlib> |
36 | #include <sstream> |
37 | #include <string> |
38 | |
39 | #include "object_internal.h" |
40 | #include "runtime_base.h" |
41 | |
42 | namespace tvm { |
43 | namespace runtime { |
44 | |
45 | std::string GetCustomTypeName(uint8_t type_code) { |
46 | auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name" ); |
47 | ICHECK(f) << "Function runtime._datatype_get_type_name not found" ; |
48 | return (*f)(type_code).operator std::string(); |
49 | } |
50 | |
51 | uint8_t GetCustomTypeCode(const std::string& type_name) { |
52 | auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code" ); |
53 | ICHECK(f) << "Function runtime._datatype_get_type_code not found" ; |
54 | return (*f)(type_name).operator int(); |
55 | } |
56 | |
57 | bool GetCustomTypeRegistered(uint8_t type_code) { |
58 | auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered" ); |
59 | ICHECK(f) << "Function runtime._datatype_get_type_registered not found" ; |
60 | return (*f)(type_code).operator bool(); |
61 | } |
62 | |
63 | uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { |
64 | ICHECK(s.substr(0, 6) == "custom" ) << "Not a valid custom datatype string" ; |
65 | |
66 | auto tmp = s.c_str(); |
67 | |
68 | ICHECK(s.c_str() == tmp); |
69 | *scan = s.c_str() + 6; |
70 | ICHECK(s.c_str() == tmp); |
71 | if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; |
72 | ICHECK(s.c_str() == tmp); |
73 | *scan += 1; |
74 | ICHECK(s.c_str() == tmp); |
75 | size_t custom_name_len = 0; |
76 | ICHECK(s.c_str() == tmp); |
77 | while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') |
78 | ++custom_name_len; |
79 | ICHECK(s.c_str() == tmp); |
80 | if (*(*scan + custom_name_len) != ']') |
81 | LOG(FATAL) << "expected closing brace after 'custom' type in" << s; |
82 | ICHECK(s.c_str() == tmp); |
83 | *scan += custom_name_len + 1; |
84 | ICHECK(s.c_str() == tmp); |
85 | |
86 | auto type_name = s.substr(7, custom_name_len); |
87 | ICHECK(s.c_str() == tmp); |
88 | return GetCustomTypeCode(type_name); |
89 | } |
90 | |
91 | class DeviceAPIManager { |
92 | public: |
93 | static const int kMaxDeviceAPI = 32; |
94 | // Get API |
95 | static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } |
96 | static DeviceAPI* Get(int dev_type, bool allow_missing = false) { |
97 | return Global()->GetAPI(dev_type, allow_missing); |
98 | } |
99 | |
100 | private: |
101 | std::array<DeviceAPI*, kMaxDeviceAPI> api_; |
102 | DeviceAPI* rpc_api_{nullptr}; |
103 | std::mutex mutex_; |
104 | // constructor |
105 | DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } |
106 | // Global static variable. |
107 | static DeviceAPIManager* Global() { |
108 | static DeviceAPIManager* inst = new DeviceAPIManager(); |
109 | return inst; |
110 | } |
111 | // Get or initialize API. |
112 | DeviceAPI* GetAPI(int type, bool allow_missing) { |
113 | if (type < kRPCSessMask) { |
114 | if (api_[type] != nullptr) return api_[type]; |
115 | std::lock_guard<std::mutex> lock(mutex_); |
116 | if (api_[type] != nullptr) return api_[type]; |
117 | api_[type] = GetAPI(DeviceName(type), allow_missing); |
118 | return api_[type]; |
119 | } else { |
120 | if (rpc_api_ != nullptr) return rpc_api_; |
121 | std::lock_guard<std::mutex> lock(mutex_); |
122 | if (rpc_api_ != nullptr) return rpc_api_; |
123 | rpc_api_ = GetAPI("rpc" , allow_missing); |
124 | return rpc_api_; |
125 | } |
126 | } |
127 | DeviceAPI* GetAPI(const std::string name, bool allow_missing) { |
128 | std::string factory = "device_api." + name; |
129 | auto* f = Registry::Get(factory); |
130 | if (f == nullptr) { |
131 | ICHECK(allow_missing) << "Device API " << name << " is not enabled." ; |
132 | return nullptr; |
133 | } |
134 | void* ptr = (*f)(); |
135 | return static_cast<DeviceAPI*>(ptr); |
136 | } |
137 | }; |
138 | |
139 | DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { |
140 | return DeviceAPIManager::Get(static_cast<int>(dev.device_type), allow_missing); |
141 | } |
142 | |
143 | void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { |
144 | return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); |
145 | } |
146 | |
147 | static size_t GetDataAlignment(const DLDataType dtype) { |
148 | size_t align = (dtype.bits / 8) * dtype.lanes; |
149 | if (align < kAllocAlignment) return kAllocAlignment; |
150 | return align; |
151 | } |
152 | |
153 | void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, |
154 | Optional<String> mem_scope) { |
155 | if (!mem_scope.defined() || mem_scope.value() == "global" ) { |
156 | // by default, we can always redirect to the flat memory allocations |
157 | DLTensor temp; |
158 | temp.data = nullptr; |
159 | temp.device = dev; |
160 | temp.ndim = ndim; |
161 | temp.dtype = dtype; |
162 | temp.shape = const_cast<int64_t*>(shape); |
163 | temp.strides = nullptr; |
164 | temp.byte_offset = 0; |
165 | size_t size = GetDataSize(temp); |
166 | size_t alignment = GetDataAlignment(temp.dtype); |
167 | return AllocDataSpace(dev, size, alignment, dtype); |
168 | } |
169 | LOG(FATAL) << "Device does not support allocate data space with " |
170 | << "specified memory scope: " << mem_scope.value(); |
171 | return nullptr; |
172 | } |
173 | |
174 | void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { |
175 | // by default, we can always redirect to the flat memory copy operation. |
176 | size_t nbytes = GetDataSize(*from); |
177 | ICHECK_EQ(nbytes, GetDataSize(*to)); |
178 | |
179 | ICHECK(IsContiguous(*from) && IsContiguous(*to)) |
180 | << "CopyDataFromTo only support contiguous array for now" ; |
181 | CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, |
182 | to->device, from->dtype, stream); |
183 | } |
184 | |
185 | void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, |
186 | size_t num_bytes, Device dev_from, Device dev_to, |
187 | DLDataType type_hint, TVMStreamHandle stream) { |
188 | LOG(FATAL) << "Device does not support CopyDataFromTo." ; |
189 | } |
190 | |
191 | void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } |
192 | |
193 | TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } |
194 | |
195 | void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} |
196 | |
197 | void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { |
198 | } |
199 | |
200 | //-------------------------------------------------------- |
201 | // Error handling mechanism |
202 | // ------------------------------------------------------- |
203 | // Standard error message format, {} means optional |
204 | //-------------------------------------------------------- |
205 | // {error_type:} {message0} |
206 | // {message1} |
207 | // {message2} |
// {Stack trace:}    // stack traces follow this line
209 | // {trace 0} // two spaces in the beginning. |
210 | // {trace 1} |
211 | // {trace 2} |
212 | //-------------------------------------------------------- |
213 | /*! |
214 | * \brief Normalize error message |
215 | * |
 * Parse the header generated by LOG(FATAL) and ICHECK
217 | * and reformat the message into the standard format. |
218 | * |
219 | * This function will also merge all the stack traces into |
220 | * one trace and trim them. |
221 | * |
222 | * \param err_msg The error message. |
223 | * \return normalized message. |
224 | */ |
std::string NormalizeError(std::string err_msg) {
  // ------------------------------------------------------------------------
  // log with header, {} indicates optional
  //-------------------------------------------------------------------------
  // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
  // {message1}
  // Stack trace:
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  // Normalized version
  //-------------------------------------------------------------------------
  // error_type: check_msg message0
  // {message1}
  // Stack trace:
  //   File file_name, line lineno
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  int line_number = 0;
  std::istringstream is(err_msg);
  std::string line, file_name, error_type, check_msg;

  // Parse the log header and fill file_name, line_number and check_msg; the first
  // message line is left in `line`.
  // Returns true if the log is in the expected format (or has no header at all),
  // false if the header is malformed.
  auto parse_log_header = [&]() {
    // No "[timestamp]" prefix: there is no header, take the first line as-is.
    if (is.peek() != '[') {
      std::getline(is, line);
      return true;
    }
    // skip timestamp
    if (!(is >> line)) return false;
    // get filename
    while (is.peek() == ' ') is.get();
#ifdef _MSC_VER  // handle volume separator ":" in Windows path
    std::string drive;
    if (!std::getline(is, drive, ':')) return false;
    if (!std::getline(is, file_name, ':')) return false;
    file_name = drive + ":" + file_name;
#else
    if (!std::getline(is, file_name, ':')) return false;
#endif
    // get line number
    if (!(is >> line_number)) return false;
    // get rest of the message.
    while (is.peek() == ' ' || is.peek() == ':') is.get();
    if (!std::getline(is, line)) return false;
    // detect check message, rewrite to remove the extra ':'
    if (line.compare(0, 13, "Check failed:") == 0) {
      std::string ending = ": ";
      size_t end_pos = line.find(ending, 13);
      if (end_pos == std::string::npos) return false;
      check_msg = line.substr(0, end_pos + ending.size());
      line = line.substr(end_pos + ending.size());
    }
    return true;
  };
  // if not in correct format, do not do any rewrite.
  if (!parse_log_header()) return err_msg;
  // Parse error type: a leading [A-Za-z0-9_.]+ token terminated by ':'.
  {
    size_t start_pos = 0, end_pos;
    for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {
    }
    for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
      // <cctype> functions require a value representable as unsigned char; passing a
      // negative plain char (non-ASCII byte) is undefined behavior.
      unsigned char ch = static_cast<unsigned char>(line[end_pos]);
      if (ch == ':') {
        error_type = line.substr(start_pos, end_pos - start_pos);
        break;
      }
      // [A-Z0-9a-z_.]
      if (!std::isalnum(ch) && ch != '_' && ch != '.') break;
    }
    if (error_type.length() != 0) {
      // if we successfully detected error_type: trim the following space.
      for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' ';
           ++start_pos) {
      }
      line = line.substr(start_pos);
    } else {
      // did not detect error_type, use default value.
      line = line.substr(start_pos);
      error_type = "TVMError";
    }
  }
  // Separate out stack trace.
  std::ostringstream os;
  os << error_type << ": " << check_msg << line << '\n';

  bool trace_mode = true;
  std::vector<std::string> stack_trace;
  while (std::getline(is, line)) {
    if (trace_mode) {
      if (line.compare(0, 2, "  ") == 0) {
        stack_trace.push_back(line);
      } else {
        trace_mode = false;
        // remove EOL trailing stacktrace.
        if (line.length() == 0) continue;
      }
    }
    if (!trace_mode) {
      if (line.compare(0, 11, "Stack trace") == 0) {
        trace_mode = true;
      } else {
        os << line << '\n';
      }
    }
  }
  if (stack_trace.size() != 0 || file_name.length() != 0) {
    os << "Stack trace:\n";
    if (file_name.length() != 0) {
      os << "  File \"" << file_name << "\", line " << line_number << "\n";
    }
    // Print out stack traces, optionally trim the c++ traces
    // about the frontends (as they will be provided by the frontends).
    bool ffi_boundary = false;
    for (const auto& frame : stack_trace) {
      // Heuristic to detect python ffi.
      if (frame.find("libffi.so") != std::string::npos ||
          frame.find("core.cpython") != std::string::npos) {
        ffi_boundary = true;
      }
      // If the backtrace is not c++ backtrace with the prefix "  [bt]",
      // then we can stop trimming.
      if (ffi_boundary && frame.compare(0, 6, "  [bt]") != 0) {
        ffi_boundary = false;
      }
      if (!ffi_boundary) {
        os << frame << '\n';
      }
      // The line after TVMFuncCall could be in FFI.
      if (frame.find("(TVMFuncCall") != std::string::npos) {
        ffi_boundary = true;
      }
    }
  }
  return os.str();
}
365 | |
366 | } // namespace runtime |
367 | } // namespace tvm |
368 | |
369 | using namespace tvm::runtime; |
370 | |
371 | struct TVMRuntimeEntry { |
372 | std::string ret_str; |
373 | std::string last_error; |
374 | TVMByteArray ret_bytes; |
375 | }; |
376 | |
377 | typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore; |
378 | |
379 | const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } |
380 | |
381 | int TVMAPIHandleException(const std::exception& e) { |
382 | TVMAPISetLastError(NormalizeError(e.what()).c_str()); |
383 | return -1; |
384 | } |
385 | |
386 | void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; } |
387 | |
388 | int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { |
389 | API_BEGIN(); |
390 | TVMRetValue ret; |
391 | ret = Module::LoadFromFile(file_name, format); |
392 | TVMValue val; |
393 | int type_code; |
394 | ret.MoveToCHost(&val, &type_code); |
395 | *out = val.v_handle; |
396 | API_END(); |
397 | } |
398 | |
399 | int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { |
400 | API_BEGIN(); |
401 | ObjectInternal::GetModuleNode(mod)->Import(GetRef<Module>(ObjectInternal::GetModuleNode(dep))); |
402 | API_END(); |
403 | } |
404 | |
405 | int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, |
406 | TVMFunctionHandle* func) { |
407 | API_BEGIN(); |
408 | PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); |
409 | if (pf != nullptr) { |
410 | tvm::runtime::TVMRetValue ret; |
411 | ret = pf; |
412 | TVMValue val; |
413 | int type_code; |
414 | ret.MoveToCHost(&val, &type_code); |
415 | *func = val.v_handle; |
416 | } else { |
417 | *func = nullptr; |
418 | } |
419 | API_END(); |
420 | } |
421 | |
422 | int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } |
423 | |
424 | int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { |
425 | API_BEGIN(); |
426 | *func = (TVMFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name))->get(); |
427 | API_END(); |
428 | } |
429 | |
430 | void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, |
431 | int dtype_bits_hint) { |
432 | DLDevice dev; |
433 | dev.device_type = static_cast<DLDeviceType>(device_type); |
434 | dev.device_id = device_id; |
435 | |
436 | DLDataType type_hint; |
437 | type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); |
438 | type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); |
439 | type_hint.lanes = 1; |
440 | |
441 | return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast<size_t>(size), type_hint); |
442 | } |
443 | |
444 | int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { |
445 | DLDevice dev; |
446 | dev.device_type = static_cast<DLDeviceType>(device_type); |
447 | dev.device_id = device_id; |
448 | DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); |
449 | return 0; |
450 | } |
451 | |
int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  // A non-null *handle marks `f` as already run; nbytes is not used.
  // NOTE(review): the check-then-set is not atomic, so concurrent first calls could both run f.
  if (*handle != nullptr) return 0;
  *handle = reinterpret_cast<void*>(1);
  return (*f)(cdata);
}
459 | |
460 | int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } |
461 | |
462 | int TVMByteArrayFree(TVMByteArray* arr) { |
463 | if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) { |
464 | return 0; // Thread-local storage does not need explicit deleting. |
465 | } |
466 | |
467 | delete arr; |
468 | return 0; |
469 | } |
470 | |
471 | int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, |
472 | TVMValue* ret_val, int* ret_type_code) { |
473 | API_BEGIN(); |
474 | |
475 | TVMRetValue rv; |
476 | (static_cast<const PackedFuncObj*>(func)) |
477 | ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); |
478 | // handle return string. |
479 | if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { |
480 | TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); |
481 | if (rv.type_code() != kTVMDataType) { |
482 | e->ret_str = *rv.ptr<std::string>(); |
483 | } else { |
484 | e->ret_str = rv.operator std::string(); |
485 | } |
486 | if (rv.type_code() == kTVMBytes) { |
487 | e->ret_bytes.data = e->ret_str.c_str(); |
488 | e->ret_bytes.size = e->ret_str.length(); |
489 | *ret_type_code = kTVMBytes; |
490 | ret_val->v_handle = &(e->ret_bytes); |
491 | } else { |
492 | *ret_type_code = kTVMStr; |
493 | ret_val->v_str = e->ret_str.c_str(); |
494 | } |
495 | } else { |
496 | rv.MoveToCHost(ret_val, ret_type_code); |
497 | } |
498 | API_END(); |
499 | } |
500 | |
501 | int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { |
502 | API_BEGIN(); |
503 | ICHECK_EQ(num_ret, 1); |
504 | TVMRetValue* rv = static_cast<TVMRetValue*>(ret); |
505 | *rv = TVMArgValue(value[0], type_code[0]); |
506 | API_END(); |
507 | } |
508 | |
509 | int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, |
510 | TVMFunctionHandle* out) { |
511 | API_BEGIN(); |
512 | if (fin == nullptr) { |
513 | tvm::runtime::TVMRetValue ret; |
514 | ret = PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { |
515 | int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes), |
516 | args.num_args, rv, resource_handle); |
517 | if (ret != 0) { |
518 | throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); |
519 | } |
520 | }); |
521 | TVMValue val; |
522 | int type_code; |
523 | ret.MoveToCHost(&val, &type_code); |
524 | *out = val.v_handle; |
525 | } else { |
526 | // wrap it in a shared_ptr, with fin as deleter. |
527 | // so fin will be called when the lambda went out of scope. |
528 | std::shared_ptr<void> rpack(resource_handle, fin); |
529 | tvm::runtime::TVMRetValue ret; |
530 | ret = PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { |
531 | int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes), |
532 | args.num_args, rv, rpack.get()); |
533 | if (ret != 0) { |
534 | throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); |
535 | } |
536 | }); |
537 | TVMValue val; |
538 | int type_code; |
539 | ret.MoveToCHost(&val, &type_code); |
540 | *out = val.v_handle; |
541 | } |
542 | API_END(); |
543 | } |
544 | |
545 | int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { |
546 | API_BEGIN(); |
547 | DLDevice dev; |
548 | dev.device_type = static_cast<DLDeviceType>(device_type); |
549 | dev.device_id = device_id; |
550 | *out = DeviceAPIManager::Get(dev)->CreateStream(dev); |
551 | API_END(); |
552 | } |
553 | |
554 | int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { |
555 | API_BEGIN(); |
556 | DLDevice dev; |
557 | dev.device_type = static_cast<DLDeviceType>(device_type); |
558 | dev.device_id = device_id; |
559 | DeviceAPIManager::Get(dev)->FreeStream(dev, stream); |
560 | API_END(); |
561 | } |
562 | |
563 | int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { |
564 | API_BEGIN(); |
565 | DLDevice dev; |
566 | dev.device_type = static_cast<DLDeviceType>(device_type); |
567 | dev.device_id = device_id; |
568 | DeviceAPIManager::Get(dev)->SetStream(dev, stream); |
569 | API_END(); |
570 | } |
571 | |
572 | int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { |
573 | API_BEGIN(); |
574 | DLDevice dev; |
575 | dev.device_type = static_cast<DLDeviceType>(device_type); |
576 | dev.device_id = device_id; |
577 | DeviceAPIManager::Get(dev)->StreamSync(dev, stream); |
578 | API_END(); |
579 | } |
580 | |
581 | int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, |
582 | TVMStreamHandle dst) { |
583 | API_BEGIN(); |
584 | DLDevice dev; |
585 | dev.device_type = static_cast<DLDeviceType>(device_type); |
586 | dev.device_id = device_id; |
587 | DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, src, dst); |
588 | API_END(); |
589 | } |
590 | |
591 | int TVMCbArgToReturn(TVMValue* value, int* code) { |
592 | API_BEGIN(); |
593 | tvm::runtime::TVMRetValue rv; |
594 | rv = tvm::runtime::TVMMovableArgValue_(*value, *code); |
595 | rv.MoveToCHost(value, code); |
596 | API_END(); |
597 | } |
598 | |
599 | int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint, |
600 | void** out_data) { |
601 | API_BEGIN(); |
602 | out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint); |
603 | API_END(); |
604 | } |
605 | |
606 | int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, DLDataType dtype, |
607 | const char* mem_scope, void** out_data) { |
608 | API_BEGIN(); |
609 | Optional<String> scope; |
610 | if (mem_scope != nullptr) { |
611 | scope = String(std::string(mem_scope)); |
612 | } |
613 | out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, ndim, shape, dtype, scope); |
614 | API_END(); |
615 | } |
616 | |
617 | int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) { |
618 | API_BEGIN(); |
619 | DeviceAPIManager::Get(dev)->FreeDataSpace(dev, ptr); |
620 | API_END(); |
621 | } |
622 | |
623 | int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { |
624 | API_BEGIN(); |
625 | DLDevice dev_from = from->device; |
626 | DLDevice dev_to = to->device; |
627 | DLDevice dev = dev_from.device_type != kDLCPU ? dev_from : dev_to; |
628 | DeviceAPIManager::Get(dev)->CopyDataFromTo(from, to, stream); |
629 | API_END(); |
630 | } |
631 | |
632 | // set device api |
633 | TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) |
634 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
635 | DLDevice dev; |
636 | dev.device_type = static_cast<DLDeviceType>(args[0].operator int()); |
637 | dev.device_id = args[1]; |
638 | DeviceAPIManager::Get(dev)->SetDevice(dev); |
639 | }); |
640 | |
// get device attribute
642 | TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
643 | DLDevice dev; |
644 | dev.device_type = static_cast<DLDeviceType>(args[0].operator int()); |
645 | dev.device_id = args[1]; |
646 | |
647 | DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int()); |
648 | if (kind == kExist) { |
649 | DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); |
650 | if (api != nullptr) { |
651 | api->GetAttr(dev, kind, ret); |
652 | } else { |
653 | *ret = 0; |
654 | } |
655 | } else { |
656 | DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); |
657 | } |
658 | }); |
659 | |
// Expose the C entry point TVMSetStream as the global "runtime.TVMSetStream";
// set_body_typed adapts its (device_type, device_id, stream) signature to a packed function.
TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream);
661 | |