1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/vm/executable.cc |
22 | * \brief The implementation of a virtual machine executable APIs. |
23 | */ |
24 | |
25 | #include <dmlc/memory_io.h> |
26 | #include <tvm/runtime/c_runtime_api.h> |
27 | #include <tvm/runtime/debug.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/runtime/vm/executable.h> |
30 | #include <tvm/runtime/vm/vm.h> |
31 | |
32 | #include <algorithm> |
33 | #include <iomanip> |
34 | #include <iostream> |
35 | #include <memory> |
36 | #include <random> |
37 | #include <sstream> |
38 | #include <utility> |
39 | #include <vector> |
40 | |
41 | #include "../file_utils.h" |
42 | #include "../library_module.h" |
43 | #include "serialize_utils.h" |
44 | |
45 | namespace tvm { |
46 | namespace runtime { |
47 | namespace vm { |
48 | |
// Fail an ICHECK with an "Invalid VM file format in the <section> section."
// message when `val` is false. At the call sites in this file `val` is usually
// the result of a dmlc::Stream read or an equality check on loaded data.
// Note: the expansion already ends in ';', so a call site's own trailing ';'
// just adds an empty statement.
#define STREAM_CHECK(val, section) \
  ICHECK(val) << "Invalid VM file format in the " << section << " section." \
              << "\n";

// Helper to serialize a vm instruction (defined later in this file).
VMInstructionSerializer SerializeInstruction(const Instruction& instr);
// Helper to deserialize a serialized vm instruction (defined later in this file).
Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
57 | |
58 | PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) { |
59 | if (name == "get_lib" ) { |
60 | return PackedFunc( |
61 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); |
62 | } else if (name == "get_bytecode" ) { |
63 | return PackedFunc( |
64 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); }); |
65 | } else if (name == "get_constants" ) { |
66 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); }); |
67 | } else if (name == "get_virtual_devices" ) { |
68 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); |
69 | } else if (name == "get_primitives" ) { |
70 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); }); |
71 | } else if (name == "get_stats" ) { |
72 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); |
73 | } else if (name == "save" ) { |
74 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); |
75 | } else if (name == "get_function_arity" ) { |
76 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
77 | std::string func_name = args[0]; |
78 | *rv = this->GetFunctionArity(func_name); |
79 | }); |
80 | } else if (name == "get_function_param_name" ) { |
81 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
82 | std::string func_name = args[0]; |
83 | int index = args[1]; |
84 | *rv = this->GetFunctionParameterName(func_name, index); |
85 | }); |
86 | } else if (name == "vm_load_executable" ) { |
87 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
88 | auto vm = make_object<VirtualMachine>(); |
89 | ICHECK(sptr_to_self.get() == this); |
90 | vm->LoadExecutable(GetObjectPtr<Executable>(this)); |
91 | *rv = Module(vm); |
92 | }); |
93 | } else if (name == "move_late_bound_consts" ) { |
94 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { |
95 | CHECK_EQ(args.size(), 2); |
96 | std::string path = args[0]; |
97 | uint64_t byte_limit = args[1]; |
98 | MoveLateBoundConstantsToFile(path, static_cast<size_t>(byte_limit)); |
99 | }); |
100 | } else if (name == "get_late_bound_consts" ) { |
101 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { |
102 | CHECK_EQ(args.size(), 1); |
103 | uint64_t byte_limit = args[0]; |
104 | Map<String, NDArray> consts = GetLateBoundConstants(static_cast<size_t>(byte_limit)); |
105 | *rv = consts; |
106 | }); |
107 | } else if (name == "load_late_bound_consts" ) { |
108 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { |
109 | CHECK_EQ(args.size(), 1); |
110 | std::string path = args[0]; |
111 | LoadLateBoundConstantsFromFile(path); |
112 | }); |
113 | } else if (name == "load_late_bound_consts_from_map" ) { |
114 | return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { |
115 | CHECK_EQ(args.size(), 1); |
116 | Map<String, NDArray> map = args[0]; |
117 | LoadLateBoundConstantsFromMap(map); |
118 | }); |
119 | } else { |
120 | LOG(FATAL) << "Unknown packed function: " << name; |
121 | } |
122 | } |
123 | |
124 | const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const { |
125 | auto it = global_map.find(func_name); |
126 | ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable" ; |
127 | return functions[it->second]; |
128 | } |
129 | |
130 | int Executable::GetFunctionArity(std::string func_name) const { |
131 | const auto& func = GetVMFunctionWithName(func_name); |
132 | return func.params.size(); |
133 | } |
134 | |
135 | std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const { |
136 | const auto& func = GetVMFunctionWithName(func_name); |
137 | ICHECK_LT(index, func.params.size()) << "Invalid parameter index" ; |
138 | return func.params[index]; |
139 | } |
140 | |
141 | std::string Executable::GetBytecode() const { |
142 | std::ostringstream oss; |
143 | |
144 | for (size_t i = 0; i < functions.size(); ++i) { |
145 | const auto& func = functions[i]; |
146 | // Print the header of the function format. |
147 | oss << "VM Function[" << i << "]: " << func.name << "(" ; |
148 | bool first = true; |
149 | for (const auto& param : func.params) { |
150 | if (!first) { |
151 | oss << ", " ; |
152 | } |
153 | oss << param; |
154 | first = false; |
155 | } |
156 | oss << ")" << std::endl; |
157 | oss << "# reg file size = " << func.register_file_size << std::endl; |
158 | oss << "# instruction count = " << func.instructions.size() << std::endl; |
159 | |
160 | // Print the instructions of a `VMFunction`. |
161 | // The part after ";" is the instruction in text format. |
162 | oss << "opcode, fields # inst(text):" << std::endl; |
163 | for (size_t idx = 0; idx < func.instructions.size(); ++idx) { |
164 | const auto& instr = func.instructions[idx]; |
165 | const auto& serialized_instr = SerializeInstruction(instr); |
166 | std::ostringstream line; |
167 | line << std::setw(2) << idx << ": " << serialized_instr.opcode << " " ; |
168 | for (auto it : serialized_instr.fields) { |
169 | line << it << " " ; |
170 | } |
171 | oss << std::setw(40) << std::setfill(' ') << std::left << line.str(); |
172 | oss << " # " << instr; |
173 | if (oss.str().back() != '\n') oss << std::endl; |
174 | } |
175 | oss << std::endl; |
176 | } |
177 | |
178 | return oss.str(); |
179 | } |
180 | |
181 | std::string Executable::GetConstants() const { |
182 | std::ostringstream oss; |
183 | for (size_t i = 0; i < constants.size(); ++i) { |
184 | const auto& constant = constants[i]; |
185 | auto ndarray = Downcast<NDArray>(constant); |
186 | oss << "VM Const[" << i |
187 | << "]: " << RuntimeObject2String(ndarray, virtual_devices[host_device_index]) |
188 | << " on device index " << const_device_indexes[i] << std::endl; |
189 | } |
190 | return oss.str(); |
191 | } |
192 | |
193 | std::string Executable::GetVirtualDevices() const { |
194 | std::ostringstream oss; |
195 | for (size_t i = 0; i < virtual_devices.size(); ++i) { |
196 | const auto& device = virtual_devices[i]; |
197 | oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type << " and id " |
198 | << device.device_id << std::endl; |
199 | } |
200 | return oss.str(); |
201 | } |
202 | |
203 | std::string Executable::GetPrimitives() const { |
204 | std::ostringstream os; |
205 | std::vector<std::pair<int, std::string>> entries; |
206 | entries.reserve(primitive_map.size()); |
207 | for (const auto& kv : primitive_map) { |
208 | entries.emplace_back(kv.second, kv.first); |
209 | } |
210 | std::sort(entries.begin(), entries.end(), |
211 | [](const std::pair<int, std::string>& left, const std::pair<int, std::string>& right) { |
212 | return left.first < right.first; |
213 | }); |
214 | for (const auto& entry : entries) { |
215 | os << "VM PackedFunc[" << entry.first << "]: " << entry.second << std::endl; |
216 | } |
217 | return os.str(); |
218 | } |
219 | |
220 | std::string Executable::Stats() const { |
221 | std::ostringstream oss; |
222 | oss << "Relay VM executable statistics:" << std::endl; |
223 | |
224 | // Get the number of constants and the shape of each of them. |
225 | oss << " Constant shapes (# " << constants.size() << "): [" ; |
226 | for (const auto& it : constants) { |
227 | const auto constant = Downcast<NDArray>(it); |
228 | const auto& shape = constant.Shape(); |
229 | |
230 | // Scalar |
231 | if (shape.empty()) { |
232 | oss << "scalar, " ; |
233 | continue; |
234 | } |
235 | |
236 | oss << "[" ; |
237 | for (auto s : shape) { |
238 | oss << s << ", " ; |
239 | } |
240 | oss.seekp(-2, oss.cur); |
241 | oss << "], " << std::endl; |
242 | } |
243 | if (!constants.empty()) oss.seekp(-2, oss.cur); |
244 | oss << "]" << std::endl; |
245 | |
246 | // Get the number of globals and the name of each of them. |
247 | oss << " Globals (#" << global_map.size() << "): [" ; |
248 | for (const auto& it : global_map) { |
249 | oss << "(\"" << it.first << "\", " << it.second << ")" |
250 | << ", " ; |
251 | } |
252 | if (!global_map.empty()) oss.seekp(-2, oss.cur); |
253 | oss << "]" << std::endl; |
254 | |
255 | // Get the number of primitive ops and the name of each of them. |
256 | oss << " Primitive ops (#" << primitive_map.size() << "): [" ; |
257 | std::vector<std::string> prim_ops; |
258 | for (const auto& it : primitive_map) { |
259 | auto packed_index = static_cast<size_t>(it.second); |
260 | if (prim_ops.size() <= packed_index) { |
261 | prim_ops.resize(packed_index + 1); |
262 | } |
263 | prim_ops[packed_index] = it.first; |
264 | } |
265 | for (const auto& it : prim_ops) { |
266 | oss << it << ", " ; |
267 | } |
268 | if (!prim_ops.empty()) oss.seekp(-2, oss.cur); |
269 | oss << "]" << std::endl; |
270 | |
271 | return oss.str(); |
272 | } |
273 | |
274 | void (dmlc::Stream* strm) { |
275 | uint64_t = kTVMVMBytecodeMagic; |
276 | strm->Write(header); |
277 | std::string version = TVM_VERSION; |
278 | strm->Write(version); |
279 | } |
280 | |
281 | TVMByteArray Executable::Save() { |
282 | // Initialize the stream object. |
283 | code_.clear(); |
284 | dmlc::MemoryStringStream strm(&code_); |
285 | |
286 | // Save header |
287 | SaveHeader(&strm); |
288 | |
289 | // Save virtual devices section. |
290 | SaveVirtualDevicesSection(&strm); |
291 | |
292 | // Global section. |
293 | SaveGlobalSection(&strm); |
294 | |
295 | // Constant section. |
296 | SaveConstantSection(&strm); |
297 | |
298 | // Primitive names. |
299 | SavePrimitiveOpNames(&strm); |
300 | |
301 | // Code section. |
302 | SaveCodeSection(&strm); |
303 | |
304 | TVMByteArray arr; |
305 | arr.data = code_.c_str(); |
306 | arr.size = code_.length(); |
307 | return arr; |
308 | } |
309 | |
310 | void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { |
311 | strm->Write(virtual_devices); |
312 | strm->Write(host_device_index); |
313 | } |
314 | |
315 | Map<String, NDArray> Executable::GetLateBoundConstants(size_t byte_limit) { |
316 | ICHECK(late_bound_constant_names.empty()); |
317 | late_bound_constant_names.reserve(constants.size()); |
318 | Map<String, NDArray> map; |
319 | size_t total_late_bound_bytes = 0; |
320 | for (size_t const_index = 0; const_index < constants.size(); ++const_index) { |
321 | const auto ndarray = Downcast<NDArray>(constants[const_index]); |
322 | ICHECK(ndarray.defined()) << "Undefined constant at index " << const_index; |
323 | size_t num_bytes = runtime::GetDataSize(*ndarray.operator->()); |
324 | if (num_bytes < byte_limit) { |
325 | // Leave as immediate. |
326 | late_bound_constant_names.emplace_back(nullptr); |
327 | continue; |
328 | } |
329 | total_late_bound_bytes += num_bytes; |
330 | std::ostringstream os; |
331 | os << "const_" << const_index; |
332 | String name = os.str(); |
333 | map.Set(name, Downcast<NDArray>(std::move(constants[const_index]))); |
334 | late_bound_constant_names.emplace_back(std::move(name)); |
335 | } |
336 | VLOG(1) << "moved " << map.size() << " constants of " << total_late_bound_bytes |
337 | << " bytes (out of " << constants.size() << " overall) to be late-bound" ; |
338 | return map; |
339 | } |
340 | |
341 | void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) { |
342 | Map<String, NDArray> map = GetLateBoundConstants(byte_limit); |
343 | runtime::SaveParams(stream, map); |
344 | } |
345 | |
346 | void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) { |
347 | tvm::runtime::SimpleBinaryFileStream stream(path, "wb" ); |
348 | MoveLateBoundConstantsToStream(&stream, byte_limit); |
349 | } |
350 | |
351 | void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) { |
352 | if (late_bound_constant_names.empty()) { |
353 | VLOG(1) << "Found no late-bound constants to load" ; |
354 | return; |
355 | } |
356 | ICHECK_EQ(late_bound_constant_names.size(), constants.size()); |
357 | Map<String, NDArray> map = runtime::LoadParams(stream); |
358 | VLOG(1) << "loaded " << map.size() << " late-bound constants" ; |
359 | LoadLateBoundConstantsFromMap(map); |
360 | } |
361 | |
362 | void Executable::LoadLateBoundConstantsFromMap(Map<String, NDArray> map) { |
363 | for (size_t const_index = 0; const_index < constants.size(); ++const_index) { |
364 | if (!late_bound_constant_names[const_index].defined()) { |
365 | ICHECK(constants[const_index].defined()) |
366 | << "Undefined immediate constant at index " << const_index; |
367 | continue; |
368 | } |
369 | const String& name = late_bound_constant_names[const_index]; |
370 | ICHECK(!constants[const_index].defined()) << "Unexpected constant at index " << const_index; |
371 | auto itr = map.find(name); |
372 | ICHECK(itr != map.end()) << "No binding for late-bound constant at index " << const_index |
373 | << " with name '" << name << "'" ; |
374 | constants[const_index] = (*itr).second; |
375 | map.erase(name); |
376 | } |
377 | late_bound_constant_names.clear(); |
378 | ICHECK(map.empty()) << "Have " << map.size() << " unused late-bound constants" ; |
379 | } |
380 | |
381 | void Executable::LoadLateBoundConstantsFromFile(const std::string& path) { |
382 | tvm::runtime::SimpleBinaryFileStream stream(path, "rb" ); |
383 | LoadLateBoundConstantsFromStream(&stream); |
384 | } |
385 | |
386 | void Executable::SaveGlobalSection(dmlc::Stream* strm) { |
387 | std::vector<std::pair<std::string, Index>> globals(this->global_map.begin(), |
388 | this->global_map.end()); |
389 | auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) { |
390 | return a.second < b.second; |
391 | }; |
392 | std::sort(globals.begin(), globals.end(), comp); |
393 | |
394 | std::vector<std::string> glbs; |
395 | for (const auto& it : globals) { |
396 | glbs.push_back(it.first); |
397 | } |
398 | strm->Write(glbs); |
399 | } |
400 | |
namespace {
// Tags written before each entry of the serialized constants table. They are
// written and read as uint32_t, so changing the type changes the file format.

/*! \brief Entry holds the tensor data inline. */
constexpr uint32_t kImmediateConstTag{0};
/*! \brief Entry holds only the binding name of a late-bound constant. */
constexpr uint32_t kLateBoundConstTag{1};
}  // namespace
406 | |
407 | void Executable::SaveConstantSection(dmlc::Stream* stream) { |
408 | // Save the overall number of constants. |
409 | stream->Write(static_cast<uint64_t>(constants.size())); |
410 | |
411 | for (size_t const_index = 0; const_index < constants.size(); ++const_index) { |
412 | if (late_bound_constant_names.empty() || !late_bound_constant_names[const_index].defined()) { |
413 | // Tag immediate constants by 0. |
414 | stream->Write(kImmediateConstTag); |
415 | // Write as DLTensor. |
416 | const auto ndarray = Downcast<runtime::NDArray>(constants[const_index]); |
417 | ICHECK(ndarray.defined()); |
418 | runtime::SaveDLTensor(stream, ndarray.operator->()); |
419 | VLOG(1) << "save " << const_index << " as immediate" ; |
420 | } else { |
421 | // Tag late-bound constants by 1. |
422 | const String& name = late_bound_constant_names[const_index]; |
423 | ICHECK(!constants[const_index].defined()); |
424 | stream->Write(kLateBoundConstTag); |
425 | // Write a string. |
426 | stream->Write(std::string(name)); |
427 | VLOG(1) << "save " << const_index << " as late-bound" ; |
428 | } |
429 | } |
430 | |
431 | VLOG(1) << "saved " << constants.size() << " constants" ; |
432 | |
433 | // Save the const to device index mapping. |
434 | stream->Write(const_device_indexes); |
435 | } |
436 | |
437 | void Executable::LoadConstantSection(dmlc::Stream* stream) { |
438 | uint64_t sz; |
439 | // Load the overall number of constants. |
440 | STREAM_CHECK(stream->Read(&sz, sizeof(sz)), "constants table size" ); |
441 | size_t size = static_cast<size_t>(sz); |
442 | |
443 | VLOG(1) << "loading " << size << " constants" ; |
444 | |
445 | constants.resize(size); |
446 | late_bound_constant_names.resize(size); |
447 | bool any_late_bound = false; |
448 | |
449 | // Load each of the constants. |
450 | for (size_t const_index = 0; const_index < size; const_index++) { |
451 | uint32_t tag; |
452 | STREAM_CHECK(stream->Read(&tag, sizeof(tag)), "constant tag" ); |
453 | if (tag == kImmediateConstTag) { |
454 | // Immediate constants tagged by 0. |
455 | VLOG(1) << "load " << const_index << " as immediate" ; |
456 | runtime::NDArray ndarray; |
457 | STREAM_CHECK(ndarray.Load(stream), "constant tensor" ); |
458 | constants[const_index] = std::move(ndarray); |
459 | late_bound_constant_names[const_index] = String(ObjectPtr<StringObj>(nullptr)); |
460 | } else if (tag == kLateBoundConstTag) { |
461 | // Late-bound constants tagged by 1. |
462 | VLOG(1) << "load " << const_index << " as late-bound" ; |
463 | std::string name; |
464 | STREAM_CHECK(stream->Read(&name), "late-bound constant name" ); |
465 | constants[const_index] = NDArray(nullptr); |
466 | late_bound_constant_names[const_index] = std::move(name); |
467 | any_late_bound = true; |
468 | } else { |
469 | STREAM_CHECK(false, "constant tag" ); |
470 | } |
471 | } |
472 | |
473 | if (!any_late_bound) { |
474 | late_bound_constant_names.clear(); |
475 | } |
476 | |
477 | // Load the const to device index mapping. |
478 | std::vector<Index> indexes; |
479 | indexes.reserve(size); |
480 | STREAM_CHECK(stream->Read(&indexes), "constant devices" ); |
481 | ICHECK_EQ(size, indexes.size()); |
482 | const_device_indexes = std::move(indexes); |
483 | } |
484 | |
485 | void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { |
486 | std::vector<std::string> primitive_names; |
487 | for (const auto& it : this->primitive_map) { |
488 | auto packed_index = static_cast<size_t>(it.second); |
489 | if (primitive_names.size() <= packed_index) { |
490 | primitive_names.resize(packed_index + 1); |
491 | } |
492 | primitive_names[packed_index] = it.first; |
493 | } |
494 | strm->Write(primitive_names); |
495 | std::map<uint64_t, std::map<std::string, std::string>> primitive_attrs; |
496 | for (const auto& it : this->op_attrs) { |
497 | auto packed_index = static_cast<size_t>(it.first); |
498 | std::map<std::string, std::string> attrs; |
499 | for (const auto& elem : it.second) { |
500 | // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize |
501 | // strings for now |
502 | if (elem.second.as<StringObj>()) { |
503 | attrs[elem.first] = Downcast<String>(elem.second); |
504 | } |
505 | } |
506 | primitive_attrs[packed_index] = attrs; |
507 | } |
508 | strm->Write(primitive_attrs); |
509 | } |
510 | |
// Serialize a virtual machine instruction into a VMInstructionSerializer built
// from the instruction's opcode and a flat list of integer `fields`. The field
// layout is fixed per opcode (see the "Number of fields" comment in each case)
// and is read back positionally by `DeserializeInstruction`, so the order of
// the pushes below is part of the file format.
//
// For example, an `AllocTensor` instruction is serialized with the fields
//   `storage offset dtype.code dtype.bits dtype.lanes ndim dst dim_0 ... dim_{ndim-1}`
// where `(dtype.code dtype.bits dtype.lanes)` encodes the `DLDataType`, `ndim`
// is the number of dimensions, `dst` is the destination register, and the
// trailing values are the shape of the tensor to be allocated.
//
// The opcode itself is not part of `fields`; it is passed separately to the
// VMInstructionSerializer constructor at the end. NOTE(review): the serializer
// (serialize_utils.h) may additionally record a hash used as a sanity check on
// load — that is not visible here; confirm there.
VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
  std::vector<Index> fields;
  // Collect the opcode-specific fields; the opcode is attached at the end.
  VLOG(2) << "Serializing: " << instr << std::endl;
  switch (instr.op) {
    case Opcode::Move: {
      // Number of fields = 2
      fields.assign({instr.from, instr.dst});
      break;
    }
    case Opcode::Ret: {
      // Number of fields = 1
      fields.push_back(instr.result);
      break;
    }
    case Opcode::Fatal: {
      // Number of fields = 0
      break;
    }
    case Opcode::InvokePacked: {
      // Number of fields = 3 + instr.arity
      // Note that arity includes both input arguments and outputs. We will
      // put all the `arity` number of fields in the end for serialization.
      fields.assign({instr.packed_index, instr.arity, instr.output_size});
      // Save the args.
      fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity);
      break;
    }
    case Opcode::AllocTensor: {
      // Number of fields = 7 + instr.alloc_tensor.ndim
      fields.push_back(instr.alloc_tensor.storage);
      fields.push_back(instr.alloc_tensor.offset);
      // Save `DLDataType`, then ndim and the dst register.
      const auto& dtype = instr.alloc_tensor.dtype;
      fields.push_back(dtype.code);
      fields.push_back(dtype.bits);
      fields.push_back(dtype.lanes);

      // The number of dimensions is not needed for constructing an
      // `AllocTensor` instruction as it equals to the length of the `shape`
      // vector. However, we save it to conveniently deserialize the instruction
      // because we will know how many fields are needed by the `shape` argument.
      fields.push_back(instr.alloc_tensor.ndim);
      fields.push_back(instr.dst);

      // Save the shape of the tensor.
      // Note that this field is rotated to the end of the list.
      fields.insert(fields.end(), instr.alloc_tensor.shape,
                    instr.alloc_tensor.shape + instr.alloc_tensor.ndim);
      break;
    }
    case Opcode::AllocTensorReg: {
      // Number of fields = 7
      fields.push_back(instr.alloc_tensor_reg.storage);
      fields.push_back(instr.alloc_tensor_reg.offset);
      fields.push_back(instr.alloc_tensor_reg.shape_register);
      // Save `DLDataType` and the dst register.
      const auto& dtype = instr.alloc_tensor_reg.dtype;
      fields.push_back(dtype.code);
      fields.push_back(dtype.bits);
      fields.push_back(dtype.lanes);
      fields.push_back(instr.dst);
      break;
    }
    case Opcode::AllocStorage: {
      // Number of fields = 7
      fields.push_back(instr.alloc_storage.allocation_size);
      fields.push_back(instr.alloc_storage.alignment);
      // Save `DLDataType` (the dtype hint), the device index and the dst register.
      const auto& dtype = instr.alloc_storage.dtype_hint;
      fields.push_back(dtype.code);
      fields.push_back(dtype.bits);
      fields.push_back(dtype.lanes);
      fields.push_back(instr.alloc_storage.device_index);
      fields.push_back(instr.dst);
      break;
    }
    case Opcode::AllocADT: {
      // Number of fields = 3 + instr.num_fields
      fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});

      // Save the fields.
      fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields);
      break;
    }
    case Opcode::AllocClosure: {
      // Number of fields = 3 + instr.num_freevar
      fields.assign({instr.clo_index, instr.num_freevar, instr.dst});

      // Save the free vars.
      fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar);
      break;
    }
    case Opcode::If: {
      // Number of fields = 4
      fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset,
                     instr.if_op.false_offset});
      break;
    }
    case Opcode::Invoke: {
      // Number of fields = 3 + instr.num_args
      fields.assign({instr.func_index, instr.num_args, instr.dst});

      // Save the args.
      fields.insert(fields.end(), instr.invoke_args_registers,
                    instr.invoke_args_registers + instr.num_args);
      break;
    }
    case Opcode::InvokeClosure: {
      // Number of fields = 3 + instr.num_closure_args
      fields.assign({instr.closure, instr.num_closure_args, instr.dst});

      // Save the args.
      fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args);
      break;
    }
    case Opcode::LoadConst: {
      // Number of fields = 2
      fields.assign({instr.const_index, instr.dst});
      break;
    }
    case Opcode::LoadConsti: {
      // Number of fields = 2
      fields.assign({instr.load_consti.val, instr.dst});
      break;
    }
    case Opcode::GetField: {
      // Number of fields = 3
      fields.assign({instr.object, instr.field_index, instr.dst});
      break;
    }
    case Opcode::GetTag: {
      // Number of fields = 2
      fields.assign({instr.get_tag.object, instr.dst});
      break;
    }
    case Opcode::Goto: {
      // Number of fields = 1
      fields.push_back(instr.pc_offset);
      break;
    }
    case Opcode::ShapeOf: {
      // Number of fields = 2
      fields.assign({instr.shape_of.tensor, instr.dst});
      break;
    }
    case Opcode::ReshapeTensor: {
      // Number of fields = 3
      fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
      break;
    }
    case Opcode::DeviceCopy: {
      // Number of fields = 4
      fields.assign({instr.device_copy.src, instr.device_copy.src_device_index,
                     instr.device_copy.dst_device_index, instr.dst});
      break;
    }
    case Opcode::KillRegister: {
      // Number of fields = 1
      fields.assign({instr.dst});
      break;
    }
    default:
      LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
      break;
  }

  return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
}
694 | |
695 | void Executable::SaveCodeSection(dmlc::Stream* strm) { |
696 | // Save the number of functions. |
697 | strm->Write(static_cast<uint64_t>(this->functions.size())); |
698 | for (const auto& func : this->functions) { |
699 | // Save the function info. |
700 | VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), |
701 | func.params, func.param_device_indexes); |
702 | func_format.Save(strm); |
703 | |
704 | // Serialize each instruction. |
705 | for (const auto& instr : func.instructions) { |
706 | const auto& serialized_instr = SerializeInstruction(instr); |
707 | serialized_instr.Save(strm); |
708 | } |
709 | } |
710 | } |
711 | |
712 | void (dmlc::Stream* strm) { |
713 | // Check header. |
714 | uint64_t ; |
715 | STREAM_CHECK(strm->Read(&header), "header" ); |
716 | STREAM_CHECK(header == kTVMVMBytecodeMagic, "header" ); |
717 | |
718 | // Check version. |
719 | std::string version; |
720 | STREAM_CHECK(strm->Read(&version), "version" ); |
721 | STREAM_CHECK(version == TVM_VERSION, "version" ); |
722 | } |
723 | |
724 | runtime::Module Executable::GetLib() const { |
725 | ICHECK_LE(this->imports_.size(), 1) |
726 | << "The kernel library must be imported as the only module in an Executable" ; |
727 | |
728 | if (this->imports().size() == 0) { |
729 | return Module(nullptr); |
730 | } else { |
731 | return this->imports_[0]; |
732 | } |
733 | } |
734 | |
735 | void Executable::SetLib(const runtime::Module& lib) { |
736 | ICHECK(lib.defined()) << "the provided library can not be null" ; |
737 | |
738 | ICHECK_EQ(this->imports_.size(), 0) |
739 | << "A VMExecutable should never have more than one import inside an the executable, \n" |
740 | << "the first import should *always* be the library containing" |
741 | << "the platform specific kernel code" ; |
742 | |
743 | this->Import(lib); |
744 | } |
745 | |
746 | runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { |
747 | auto exec = make_object<Executable>(); |
748 | |
749 | // Support null-initialization of lib, to enable initialization during |
750 | // deserialization before we have deserialized the imports. |
751 | if (lib.defined()) { |
752 | exec->SetLib(lib); |
753 | } |
754 | |
755 | exec->code_ = code; |
756 | dmlc::MemoryStringStream strm(&exec->code_); |
757 | |
758 | // Load header. |
759 | LoadHeader(&strm); |
760 | |
761 | // Virtual devices section |
762 | exec->LoadVirtualDevicesSection(&strm); |
763 | |
764 | // Global section. |
765 | exec->LoadGlobalSection(&strm); |
766 | |
767 | // Constant section. |
768 | exec->LoadConstantSection(&strm); |
769 | |
770 | // Primitive names that will be invoked by `InvokePacked` instructions. |
771 | exec->LoadPrimitiveOpNames(&strm); |
772 | |
773 | // Code section. |
774 | exec->LoadCodeSection(&strm); |
775 | |
776 | return runtime::Module(exec); |
777 | } |
778 | |
// Read the virtual devices section written by SaveVirtualDevicesSection: the
// device list followed by the host device index, which must be a valid
// position in that list.
void Executable::LoadVirtualDevicesSection(dmlc::Stream* strm) {
  STREAM_CHECK(strm->Read(&virtual_devices), "virtual_device");
  STREAM_CHECK(strm->Read(&host_device_index), "virtual_device");
  // Reject an out-of-range host index up front; other code in this file
  // indexes virtual_devices[host_device_index] directly (e.g. GetConstants).
  ICHECK(host_device_index >= 0 && host_device_index < static_cast<int>(virtual_devices.size()));
}
784 | |
785 | void Executable::LoadGlobalSection(dmlc::Stream* strm) { |
786 | std::vector<std::string> globals; |
787 | STREAM_CHECK(strm->Read(&globals), "global" ); |
788 | for (size_t i = 0; i < globals.size(); i++) { |
789 | this->global_map.insert({globals[i], i}); |
790 | } |
791 | } |
792 | |
793 | void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { |
794 | std::vector<std::string> primitive_names; |
795 | STREAM_CHECK(strm->Read(&primitive_names), "primitive name" ); |
796 | for (size_t i = 0; i < primitive_names.size(); i++) { |
797 | this->primitive_map.insert({primitive_names[i], i}); |
798 | } |
799 | |
800 | std::map<uint64_t, std::map<std::string, std::string>> primitive_attrs; |
801 | STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs" ); |
802 | for (const auto& fn : primitive_attrs) { |
803 | std::vector<std::pair<String, ObjectRef>> attrs; |
804 | for (const auto& elem : fn.second) { |
805 | attrs.push_back({elem.first, String(elem.second)}); |
806 | } |
807 | this->op_attrs[fn.first] = Map<String, ObjectRef>(attrs.begin(), attrs.end()); |
808 | } |
809 | } |
810 | |
811 | // Extract the `cnt` number of fields started at `start` from the list |
812 | // `instr_fields`. |
813 | inline std::vector<Index> (const std::vector<Index>& instr_fields, Index start, |
814 | Index cnt) { |
815 | ICHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size()); |
816 | std::vector<Index> ret; |
817 | for (auto i = start; i < start + cnt; i++) { |
818 | ret.push_back(instr_fields[i]); |
819 | } |
820 | return ret; |
821 | } |
822 | |
823 | Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { |
824 | Opcode opcode = static_cast<Opcode>(instr.opcode); |
825 | switch (opcode) { |
826 | case Opcode::Move: { |
827 | // Number of fields = 2 |
828 | DCHECK_EQ(instr.fields.size(), 2U); |
829 | return Instruction::Move(instr.fields[0], instr.fields[1]); |
830 | } |
831 | case Opcode::Ret: { |
832 | // Number of fields = 1 |
833 | DCHECK_EQ(instr.fields.size(), 1U); |
834 | return Instruction::Ret(instr.fields[0]); |
835 | } |
836 | case Opcode::Fatal: { |
837 | // Number of fields = 0 |
838 | DCHECK(instr.fields.empty()); |
839 | return Instruction::Fatal(); |
840 | } |
841 | case Opcode::InvokePacked: { |
842 | // Number of fields = 3 + instr.arity |
843 | DCHECK_GE(instr.fields.size(), 3U); |
844 | DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); |
845 | |
846 | Index packed_index = instr.fields[0]; |
847 | Index arity = instr.fields[1]; |
848 | Index output_size = instr.fields[2]; |
849 | std::vector<RegName> args = ExtractFields(instr.fields, 3, arity); |
850 | return Instruction::InvokePacked(packed_index, arity, output_size, args); |
851 | } |
852 | case Opcode::AllocTensor: { |
853 | // Number of fields = 7 + instr.alloc_tensor.ndim |
854 | DCHECK_GE(instr.fields.size(), 7U); |
855 | DCHECK_EQ(instr.fields.size(), 7U + static_cast<size_t>(instr.fields[5])); |
856 | |
857 | RegName storage_reg = instr.fields[0]; |
858 | RegName offset = instr.fields[1]; |
859 | |
860 | DLDataType dtype; |
861 | dtype.code = instr.fields[2]; |
862 | dtype.bits = instr.fields[3]; |
863 | dtype.lanes = instr.fields[4]; |
864 | |
865 | Index ndim = instr.fields[5]; |
866 | RegName dst = instr.fields[6]; |
867 | |
868 | std::vector<Index> shape = ExtractFields(instr.fields, 7, ndim); |
869 | |
870 | return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst); |
871 | } |
872 | case Opcode::AllocTensorReg: { |
873 | // Number of fields = 7 |
874 | DCHECK_EQ(instr.fields.size(), 7U); |
875 | |
876 | RegName storage_reg = instr.fields[0]; |
877 | RegName offset = instr.fields[1]; |
878 | Index shape_register = instr.fields[2]; |
879 | |
880 | DLDataType dtype; |
881 | dtype.code = instr.fields[3]; |
882 | dtype.bits = instr.fields[4]; |
883 | dtype.lanes = instr.fields[5]; |
884 | |
885 | RegName dst = instr.fields[6]; |
886 | |
887 | return Instruction::AllocTensorReg(storage_reg, offset, shape_register, dtype, dst); |
888 | } |
889 | case Opcode::AllocADT: { |
890 | // Number of fields = 3 + instr.num_fields |
891 | DCHECK_GE(instr.fields.size(), 3U); |
892 | DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); |
893 | |
894 | Index constructor_tag = instr.fields[0]; |
895 | Index num_fields = instr.fields[1]; |
896 | RegName dst = instr.fields[2]; |
897 | std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields); |
898 | |
899 | return Instruction::AllocADT(constructor_tag, num_fields, fields, dst); |
900 | } |
901 | case Opcode::AllocClosure: { |
902 | // Number of fields = 3 + instr.num_freevar |
903 | DCHECK_GE(instr.fields.size(), 3U); |
904 | DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); |
905 | |
906 | Index clo_index = instr.fields[0]; |
907 | Index num_freevar = instr.fields[1]; |
908 | RegName dst = instr.fields[2]; |
909 | std::vector<Index> free_vars = ExtractFields(instr.fields, 3, num_freevar); |
910 | |
911 | return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); |
912 | } |
913 | case Opcode::AllocStorage: { |
914 | // Number of fields = 7 |
915 | DCHECK_GE(instr.fields.size(), 7U); |
916 | Index allocation_size = instr.fields[0]; |
917 | Index alignment = instr.fields[1]; |
918 | |
919 | DLDataType dtype; |
920 | dtype.code = instr.fields[2]; |
921 | dtype.bits = instr.fields[3]; |
922 | dtype.lanes = instr.fields[4]; |
923 | |
924 | Index device_type = instr.fields[5]; |
925 | RegName dst = instr.fields[6]; |
926 | |
927 | return Instruction::AllocStorage(allocation_size, alignment, dtype, device_type, dst); |
928 | } |
929 | case Opcode::If: { |
930 | // Number of fields = 4 |
931 | DCHECK_EQ(instr.fields.size(), 4U); |
932 | Index test = instr.fields[0]; |
933 | Index target = instr.fields[1]; |
934 | Index true_offset = instr.fields[2]; |
935 | Index false_offset = instr.fields[3]; |
936 | |
937 | return Instruction::If(test, target, true_offset, false_offset); |
938 | } |
939 | case Opcode::Invoke: { |
940 | // Number of fields = 3 + instr.num_args |
941 | DCHECK_GE(instr.fields.size(), 3U); |
942 | DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); |
943 | |
944 | Index func_index = instr.fields[0]; |
945 | Index num_args = instr.fields[1]; |
946 | RegName dst = instr.fields[2]; |
947 | std::vector<Index> args = ExtractFields(instr.fields, 3, num_args); |
948 | |
949 | return Instruction::Invoke(func_index, args, dst); |
950 | } |
951 | case Opcode::InvokeClosure: { |
952 | // Number of fields = 3 + instr.num_closure_args |
953 | DCHECK_GE(instr.fields.size(), 3U); |
954 | DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); |
955 | |
956 | Index closure = instr.fields[0]; |
957 | Index num_closure_args = instr.fields[1]; |
958 | RegName dst = instr.fields[2]; |
959 | std::vector<Index> args = ExtractFields(instr.fields, 3, num_closure_args); |
960 | |
961 | return Instruction::InvokeClosure(closure, args, dst); |
962 | } |
963 | case Opcode::LoadConst: { |
964 | // Number of fields = 2 |
965 | DCHECK_EQ(instr.fields.size(), 2U); |
966 | return Instruction::LoadConst(instr.fields[0], instr.fields[1]); |
967 | } |
968 | case Opcode::LoadConsti: { |
969 | // Number of fields = 2 |
970 | DCHECK_EQ(instr.fields.size(), 2U); |
971 | return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); |
972 | } |
973 | case Opcode::GetField: { |
974 | // Number of fields = 3 |
975 | DCHECK_EQ(instr.fields.size(), 3U); |
976 | return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); |
977 | } |
978 | case Opcode::GetTag: { |
979 | // Number of fields = 2 |
980 | DCHECK_EQ(instr.fields.size(), 2U); |
981 | return Instruction::GetTag(instr.fields[0], instr.fields[1]); |
982 | } |
983 | case Opcode::Goto: { |
984 | // Number of fields = 1 |
985 | DCHECK_EQ(instr.fields.size(), 1U); |
986 | return Instruction::Goto(instr.fields[0]); |
987 | } |
988 | case Opcode::ShapeOf: { |
989 | // Number of fields = 2 |
990 | DCHECK_EQ(instr.fields.size(), 2U); |
991 | return Instruction::ShapeOf(instr.fields[0], instr.fields[1]); |
992 | } |
993 | case Opcode::ReshapeTensor: { |
994 | // Number of fields = 3 |
995 | DCHECK_EQ(instr.fields.size(), 3U); |
996 | return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]); |
997 | } |
998 | case Opcode::DeviceCopy: { |
999 | // Number of fields = 4 |
1000 | DCHECK_EQ(instr.fields.size(), 4U); |
1001 | return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2], |
1002 | instr.fields[3]); |
1003 | } |
1004 | case Opcode::KillRegister: { |
1005 | DCHECK_EQ(instr.fields.size(), 1U); |
1006 | return Instruction::KillRegister(instr.fields[0]); |
1007 | } |
1008 | default: |
1009 | LOG(FATAL) << "Invalid opcode" << instr.opcode; |
1010 | } |
1011 | } |
1012 | |
1013 | void Executable::LoadCodeSection(dmlc::Stream* strm) { |
1014 | // Load the number of functions. |
1015 | uint64_t sz; |
1016 | STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code" ); |
1017 | |
1018 | size_t num_funcs = static_cast<size_t>(sz); |
1019 | this->functions.resize(num_funcs); |
1020 | for (size_t i = 0; i < num_funcs; i++) { |
1021 | // Load the function info. |
1022 | VMFunctionSerializer loaded_func; |
1023 | STREAM_CHECK(loaded_func.Load(strm), "code/function" ); |
1024 | |
1025 | // Load the instructions. |
1026 | std::vector<Instruction> instructions; |
1027 | for (size_t j = 0; j < loaded_func.num_instructions; j++) { |
1028 | VMInstructionSerializer instr; |
1029 | std::vector<Index> instr_fields; |
1030 | STREAM_CHECK(instr.Load(strm), "code/instruction" ); |
1031 | instructions.push_back(DeserializeInstruction(instr)); |
1032 | } |
1033 | |
1034 | // Create the VM function. |
1035 | VMFunction vm_func = |
1036 | VMFunction(loaded_func.name, loaded_func.params, instructions, |
1037 | loaded_func.register_file_size, loaded_func.param_device_indexes); |
1038 | auto it = this->global_map.find(loaded_func.name); |
1039 | ICHECK(it != this->global_map.end()); |
1040 | ICHECK_LE(it->second, this->global_map.size()); |
1041 | this->functions[it->second] = vm_func; |
1042 | } |
1043 | } |
1044 | |
1045 | void Executable::SaveToBinary(dmlc::Stream* stream) { |
1046 | auto code_bytes = this->Save(); |
1047 | std::string code(code_bytes.data, code_bytes.size); |
1048 | stream->Write(code); |
1049 | |
1050 | ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization" ; |
1051 | } |
1052 | |
1053 | Module ExecutableLoadBinary(void* strm) { |
1054 | dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); |
1055 | std::string code; |
1056 | stream->Read(&code); |
1057 | auto exec = Executable::Load(code, Module()); |
1058 | return exec; |
1059 | } |
1060 | |
1061 | void Executable::SaveToFile(const std::string& path, const std::string& format) { |
1062 | tvm::runtime::SimpleBinaryFileStream stream(path, "wb" ); |
1063 | SaveToBinary(&stream); |
1064 | } |
1065 | |
1066 | TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable" ).set_body_typed(ExecutableLoadBinary); |
1067 | |
1068 | // Load module from module. |
1069 | Module ExecutableLoadFile(const std::string& file_name, const std::string& format) { |
1070 | tvm::runtime::SimpleBinaryFileStream stream(file_name, "rb" ); |
1071 | auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(&stream)); |
1072 | return exec; |
1073 | } |
1074 | |
1075 | TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable" ).set_body_typed(ExecutableLoadFile); |
1076 | |
1077 | TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
1078 | runtime::Module mod = args[0]; |
1079 | const auto* exec = dynamic_cast<Executable*>(mod.operator->()); |
1080 | ICHECK(exec); |
1081 | *rv = static_cast<int>(exec->global_map.size()); |
1082 | }); |
1083 | |
1084 | TVM_REGISTER_GLOBAL("runtime.GetGlobalFields" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
1085 | runtime::Module mod = args[0]; |
1086 | const auto* exec = dynamic_cast<Executable*>(mod.operator->()); |
1087 | ICHECK(exec); |
1088 | int idx = args[1]; |
1089 | std::vector<std::pair<std::string, Index>> globals(exec->global_map.begin(), |
1090 | exec->global_map.end()); |
1091 | auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) { |
1092 | return a.second < b.second; |
1093 | }; |
1094 | std::sort(globals.begin(), globals.end(), comp); |
1095 | ICHECK_LT(idx, globals.size()); |
1096 | *rv = globals[idx].first; |
1097 | }); |
1098 | |
1099 | TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
1100 | runtime::Module mod = args[0]; |
1101 | const auto* exec = dynamic_cast<Executable*>(mod.operator->()); |
1102 | ICHECK(exec); |
1103 | *rv = static_cast<int>(exec->primitive_map.size()); |
1104 | }); |
1105 | |
1106 | TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
1107 | runtime::Module mod = args[0]; |
1108 | const auto* exec = dynamic_cast<Executable*>(mod.operator->()); |
1109 | ICHECK(exec); |
1110 | int idx = args[1]; |
1111 | ICHECK_GE(idx, 0); |
1112 | ICHECK_LT(idx, exec->primitive_map.size()); |
1113 | |
1114 | for (const auto& it : exec->primitive_map) { |
1115 | if (idx == static_cast<int>(it.second)) { |
1116 | *rv = it.first; |
1117 | break; |
1118 | } |
1119 | } |
1120 | }); |
1121 | |
1122 | TVM_REGISTER_GLOBAL("runtime.Load_Executable" ) |
1123 | .set_body_typed([](std::string code, runtime::Module lib) { |
1124 | return Executable::Load(code, lib); |
1125 | }); |
1126 | |
1127 | } // namespace vm |
1128 | } // namespace runtime |
1129 | } // namespace tvm |
1130 | |