1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/vm/executable.cc
 * \brief Implementation of the virtual machine executable APIs.
23 */
24
25#include <dmlc/memory_io.h>
26#include <tvm/runtime/c_runtime_api.h>
27#include <tvm/runtime/debug.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/runtime/vm/executable.h>
30#include <tvm/runtime/vm/vm.h>
31
32#include <algorithm>
33#include <iomanip>
34#include <iostream>
35#include <memory>
36#include <random>
37#include <sstream>
38#include <utility>
39#include <vector>
40
41#include "../file_utils.h"
42#include "../library_module.h"
43#include "serialize_utils.h"
44
45namespace tvm {
46namespace runtime {
47namespace vm {
48
49#define STREAM_CHECK(val, section) \
50 ICHECK(val) << "Invalid VM file format in the " << section << " section." \
51 << "\n";
52
53// Helper to serialize a vm instruction.
54VMInstructionSerializer SerializeInstruction(const Instruction& instr);
55// Helper to deserialize a serialized vm instruction.
56Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
57
58PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
59 if (name == "get_lib") {
60 return PackedFunc(
61 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); });
62 } else if (name == "get_bytecode") {
63 return PackedFunc(
64 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); });
65 } else if (name == "get_constants") {
66 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); });
67 } else if (name == "get_virtual_devices") {
68 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); });
69 } else if (name == "get_primitives") {
70 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); });
71 } else if (name == "get_stats") {
72 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); });
73 } else if (name == "save") {
74 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); });
75 } else if (name == "get_function_arity") {
76 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
77 std::string func_name = args[0];
78 *rv = this->GetFunctionArity(func_name);
79 });
80 } else if (name == "get_function_param_name") {
81 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
82 std::string func_name = args[0];
83 int index = args[1];
84 *rv = this->GetFunctionParameterName(func_name, index);
85 });
86 } else if (name == "vm_load_executable") {
87 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
88 auto vm = make_object<VirtualMachine>();
89 ICHECK(sptr_to_self.get() == this);
90 vm->LoadExecutable(GetObjectPtr<Executable>(this));
91 *rv = Module(vm);
92 });
93 } else if (name == "move_late_bound_consts") {
94 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
95 CHECK_EQ(args.size(), 2);
96 std::string path = args[0];
97 uint64_t byte_limit = args[1];
98 MoveLateBoundConstantsToFile(path, static_cast<size_t>(byte_limit));
99 });
100 } else if (name == "get_late_bound_consts") {
101 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
102 CHECK_EQ(args.size(), 1);
103 uint64_t byte_limit = args[0];
104 Map<String, NDArray> consts = GetLateBoundConstants(static_cast<size_t>(byte_limit));
105 *rv = consts;
106 });
107 } else if (name == "load_late_bound_consts") {
108 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
109 CHECK_EQ(args.size(), 1);
110 std::string path = args[0];
111 LoadLateBoundConstantsFromFile(path);
112 });
113 } else if (name == "load_late_bound_consts_from_map") {
114 return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
115 CHECK_EQ(args.size(), 1);
116 Map<String, NDArray> map = args[0];
117 LoadLateBoundConstantsFromMap(map);
118 });
119 } else {
120 LOG(FATAL) << "Unknown packed function: " << name;
121 }
122}
123
124const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const {
125 auto it = global_map.find(func_name);
126 ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable";
127 return functions[it->second];
128}
129
130int Executable::GetFunctionArity(std::string func_name) const {
131 const auto& func = GetVMFunctionWithName(func_name);
132 return func.params.size();
133}
134
135std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
136 const auto& func = GetVMFunctionWithName(func_name);
137 ICHECK_LT(index, func.params.size()) << "Invalid parameter index";
138 return func.params[index];
139}
140
141std::string Executable::GetBytecode() const {
142 std::ostringstream oss;
143
144 for (size_t i = 0; i < functions.size(); ++i) {
145 const auto& func = functions[i];
146 // Print the header of the function format.
147 oss << "VM Function[" << i << "]: " << func.name << "(";
148 bool first = true;
149 for (const auto& param : func.params) {
150 if (!first) {
151 oss << ", ";
152 }
153 oss << param;
154 first = false;
155 }
156 oss << ")" << std::endl;
157 oss << "# reg file size = " << func.register_file_size << std::endl;
158 oss << "# instruction count = " << func.instructions.size() << std::endl;
159
160 // Print the instructions of a `VMFunction`.
161 // The part after ";" is the instruction in text format.
162 oss << "opcode, fields # inst(text):" << std::endl;
163 for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
164 const auto& instr = func.instructions[idx];
165 const auto& serialized_instr = SerializeInstruction(instr);
166 std::ostringstream line;
167 line << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
168 for (auto it : serialized_instr.fields) {
169 line << it << " ";
170 }
171 oss << std::setw(40) << std::setfill(' ') << std::left << line.str();
172 oss << " # " << instr;
173 if (oss.str().back() != '\n') oss << std::endl;
174 }
175 oss << std::endl;
176 }
177
178 return oss.str();
179}
180
181std::string Executable::GetConstants() const {
182 std::ostringstream oss;
183 for (size_t i = 0; i < constants.size(); ++i) {
184 const auto& constant = constants[i];
185 auto ndarray = Downcast<NDArray>(constant);
186 oss << "VM Const[" << i
187 << "]: " << RuntimeObject2String(ndarray, virtual_devices[host_device_index])
188 << " on device index " << const_device_indexes[i] << std::endl;
189 }
190 return oss.str();
191}
192
193std::string Executable::GetVirtualDevices() const {
194 std::ostringstream oss;
195 for (size_t i = 0; i < virtual_devices.size(); ++i) {
196 const auto& device = virtual_devices[i];
197 oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type << " and id "
198 << device.device_id << std::endl;
199 }
200 return oss.str();
201}
202
203std::string Executable::GetPrimitives() const {
204 std::ostringstream os;
205 std::vector<std::pair<int, std::string>> entries;
206 entries.reserve(primitive_map.size());
207 for (const auto& kv : primitive_map) {
208 entries.emplace_back(kv.second, kv.first);
209 }
210 std::sort(entries.begin(), entries.end(),
211 [](const std::pair<int, std::string>& left, const std::pair<int, std::string>& right) {
212 return left.first < right.first;
213 });
214 for (const auto& entry : entries) {
215 os << "VM PackedFunc[" << entry.first << "]: " << entry.second << std::endl;
216 }
217 return os.str();
218}
219
220std::string Executable::Stats() const {
221 std::ostringstream oss;
222 oss << "Relay VM executable statistics:" << std::endl;
223
224 // Get the number of constants and the shape of each of them.
225 oss << " Constant shapes (# " << constants.size() << "): [";
226 for (const auto& it : constants) {
227 const auto constant = Downcast<NDArray>(it);
228 const auto& shape = constant.Shape();
229
230 // Scalar
231 if (shape.empty()) {
232 oss << "scalar, ";
233 continue;
234 }
235
236 oss << "[";
237 for (auto s : shape) {
238 oss << s << ", ";
239 }
240 oss.seekp(-2, oss.cur);
241 oss << "], " << std::endl;
242 }
243 if (!constants.empty()) oss.seekp(-2, oss.cur);
244 oss << "]" << std::endl;
245
246 // Get the number of globals and the name of each of them.
247 oss << " Globals (#" << global_map.size() << "): [";
248 for (const auto& it : global_map) {
249 oss << "(\"" << it.first << "\", " << it.second << ")"
250 << ", ";
251 }
252 if (!global_map.empty()) oss.seekp(-2, oss.cur);
253 oss << "]" << std::endl;
254
255 // Get the number of primitive ops and the name of each of them.
256 oss << " Primitive ops (#" << primitive_map.size() << "): [";
257 std::vector<std::string> prim_ops;
258 for (const auto& it : primitive_map) {
259 auto packed_index = static_cast<size_t>(it.second);
260 if (prim_ops.size() <= packed_index) {
261 prim_ops.resize(packed_index + 1);
262 }
263 prim_ops[packed_index] = it.first;
264 }
265 for (const auto& it : prim_ops) {
266 oss << it << ", ";
267 }
268 if (!prim_ops.empty()) oss.seekp(-2, oss.cur);
269 oss << "]" << std::endl;
270
271 return oss.str();
272}
273
274void SaveHeader(dmlc::Stream* strm) {
275 uint64_t header = kTVMVMBytecodeMagic;
276 strm->Write(header);
277 std::string version = TVM_VERSION;
278 strm->Write(version);
279}
280
281TVMByteArray Executable::Save() {
282 // Initialize the stream object.
283 code_.clear();
284 dmlc::MemoryStringStream strm(&code_);
285
286 // Save header
287 SaveHeader(&strm);
288
289 // Save virtual devices section.
290 SaveVirtualDevicesSection(&strm);
291
292 // Global section.
293 SaveGlobalSection(&strm);
294
295 // Constant section.
296 SaveConstantSection(&strm);
297
298 // Primitive names.
299 SavePrimitiveOpNames(&strm);
300
301 // Code section.
302 SaveCodeSection(&strm);
303
304 TVMByteArray arr;
305 arr.data = code_.c_str();
306 arr.size = code_.length();
307 return arr;
308}
309
310void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) {
311 strm->Write(virtual_devices);
312 strm->Write(host_device_index);
313}
314
315Map<String, NDArray> Executable::GetLateBoundConstants(size_t byte_limit) {
316 ICHECK(late_bound_constant_names.empty());
317 late_bound_constant_names.reserve(constants.size());
318 Map<String, NDArray> map;
319 size_t total_late_bound_bytes = 0;
320 for (size_t const_index = 0; const_index < constants.size(); ++const_index) {
321 const auto ndarray = Downcast<NDArray>(constants[const_index]);
322 ICHECK(ndarray.defined()) << "Undefined constant at index " << const_index;
323 size_t num_bytes = runtime::GetDataSize(*ndarray.operator->());
324 if (num_bytes < byte_limit) {
325 // Leave as immediate.
326 late_bound_constant_names.emplace_back(nullptr);
327 continue;
328 }
329 total_late_bound_bytes += num_bytes;
330 std::ostringstream os;
331 os << "const_" << const_index;
332 String name = os.str();
333 map.Set(name, Downcast<NDArray>(std::move(constants[const_index])));
334 late_bound_constant_names.emplace_back(std::move(name));
335 }
336 VLOG(1) << "moved " << map.size() << " constants of " << total_late_bound_bytes
337 << " bytes (out of " << constants.size() << " overall) to be late-bound";
338 return map;
339}
340
341void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) {
342 Map<String, NDArray> map = GetLateBoundConstants(byte_limit);
343 runtime::SaveParams(stream, map);
344}
345
346void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) {
347 tvm::runtime::SimpleBinaryFileStream stream(path, "wb");
348 MoveLateBoundConstantsToStream(&stream, byte_limit);
349}
350
351void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) {
352 if (late_bound_constant_names.empty()) {
353 VLOG(1) << "Found no late-bound constants to load";
354 return;
355 }
356 ICHECK_EQ(late_bound_constant_names.size(), constants.size());
357 Map<String, NDArray> map = runtime::LoadParams(stream);
358 VLOG(1) << "loaded " << map.size() << " late-bound constants";
359 LoadLateBoundConstantsFromMap(map);
360}
361
362void Executable::LoadLateBoundConstantsFromMap(Map<String, NDArray> map) {
363 for (size_t const_index = 0; const_index < constants.size(); ++const_index) {
364 if (!late_bound_constant_names[const_index].defined()) {
365 ICHECK(constants[const_index].defined())
366 << "Undefined immediate constant at index " << const_index;
367 continue;
368 }
369 const String& name = late_bound_constant_names[const_index];
370 ICHECK(!constants[const_index].defined()) << "Unexpected constant at index " << const_index;
371 auto itr = map.find(name);
372 ICHECK(itr != map.end()) << "No binding for late-bound constant at index " << const_index
373 << " with name '" << name << "'";
374 constants[const_index] = (*itr).second;
375 map.erase(name);
376 }
377 late_bound_constant_names.clear();
378 ICHECK(map.empty()) << "Have " << map.size() << " unused late-bound constants";
379}
380
381void Executable::LoadLateBoundConstantsFromFile(const std::string& path) {
382 tvm::runtime::SimpleBinaryFileStream stream(path, "rb");
383 LoadLateBoundConstantsFromStream(&stream);
384}
385
386void Executable::SaveGlobalSection(dmlc::Stream* strm) {
387 std::vector<std::pair<std::string, Index>> globals(this->global_map.begin(),
388 this->global_map.end());
389 auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
390 return a.second < b.second;
391 };
392 std::sort(globals.begin(), globals.end(), comp);
393
394 std::vector<std::string> glbs;
395 for (const auto& it : globals) {
396 glbs.push_back(it.first);
397 }
398 strm->Write(glbs);
399}
400
namespace {
// Each entry of the serialized constants table starts with one of these tags,
// telling the loader whether an immediate tensor or a late-bound name follows.
constexpr uint32_t kImmediateConstTag{0};
constexpr uint32_t kLateBoundConstTag{1};
}  // namespace
406
407void Executable::SaveConstantSection(dmlc::Stream* stream) {
408 // Save the overall number of constants.
409 stream->Write(static_cast<uint64_t>(constants.size()));
410
411 for (size_t const_index = 0; const_index < constants.size(); ++const_index) {
412 if (late_bound_constant_names.empty() || !late_bound_constant_names[const_index].defined()) {
413 // Tag immediate constants by 0.
414 stream->Write(kImmediateConstTag);
415 // Write as DLTensor.
416 const auto ndarray = Downcast<runtime::NDArray>(constants[const_index]);
417 ICHECK(ndarray.defined());
418 runtime::SaveDLTensor(stream, ndarray.operator->());
419 VLOG(1) << "save " << const_index << " as immediate";
420 } else {
421 // Tag late-bound constants by 1.
422 const String& name = late_bound_constant_names[const_index];
423 ICHECK(!constants[const_index].defined());
424 stream->Write(kLateBoundConstTag);
425 // Write a string.
426 stream->Write(std::string(name));
427 VLOG(1) << "save " << const_index << " as late-bound";
428 }
429 }
430
431 VLOG(1) << "saved " << constants.size() << " constants";
432
433 // Save the const to device index mapping.
434 stream->Write(const_device_indexes);
435}
436
437void Executable::LoadConstantSection(dmlc::Stream* stream) {
438 uint64_t sz;
439 // Load the overall number of constants.
440 STREAM_CHECK(stream->Read(&sz, sizeof(sz)), "constants table size");
441 size_t size = static_cast<size_t>(sz);
442
443 VLOG(1) << "loading " << size << " constants";
444
445 constants.resize(size);
446 late_bound_constant_names.resize(size);
447 bool any_late_bound = false;
448
449 // Load each of the constants.
450 for (size_t const_index = 0; const_index < size; const_index++) {
451 uint32_t tag;
452 STREAM_CHECK(stream->Read(&tag, sizeof(tag)), "constant tag");
453 if (tag == kImmediateConstTag) {
454 // Immediate constants tagged by 0.
455 VLOG(1) << "load " << const_index << " as immediate";
456 runtime::NDArray ndarray;
457 STREAM_CHECK(ndarray.Load(stream), "constant tensor");
458 constants[const_index] = std::move(ndarray);
459 late_bound_constant_names[const_index] = String(ObjectPtr<StringObj>(nullptr));
460 } else if (tag == kLateBoundConstTag) {
461 // Late-bound constants tagged by 1.
462 VLOG(1) << "load " << const_index << " as late-bound";
463 std::string name;
464 STREAM_CHECK(stream->Read(&name), "late-bound constant name");
465 constants[const_index] = NDArray(nullptr);
466 late_bound_constant_names[const_index] = std::move(name);
467 any_late_bound = true;
468 } else {
469 STREAM_CHECK(false, "constant tag");
470 }
471 }
472
473 if (!any_late_bound) {
474 late_bound_constant_names.clear();
475 }
476
477 // Load the const to device index mapping.
478 std::vector<Index> indexes;
479 indexes.reserve(size);
480 STREAM_CHECK(stream->Read(&indexes), "constant devices");
481 ICHECK_EQ(size, indexes.size());
482 const_device_indexes = std::move(indexes);
483}
484
485void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) {
486 std::vector<std::string> primitive_names;
487 for (const auto& it : this->primitive_map) {
488 auto packed_index = static_cast<size_t>(it.second);
489 if (primitive_names.size() <= packed_index) {
490 primitive_names.resize(packed_index + 1);
491 }
492 primitive_names[packed_index] = it.first;
493 }
494 strm->Write(primitive_names);
495 std::map<uint64_t, std::map<std::string, std::string>> primitive_attrs;
496 for (const auto& it : this->op_attrs) {
497 auto packed_index = static_cast<size_t>(it.first);
498 std::map<std::string, std::string> attrs;
499 for (const auto& elem : it.second) {
500 // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize
501 // strings for now
502 if (elem.second.as<StringObj>()) {
503 attrs[elem.first] = Downcast<String>(elem.second);
504 }
505 }
506 primitive_attrs[packed_index] = attrs;
507 }
508 strm->Write(primitive_attrs);
509}
510
511// Serialize a virtual machine instruction. It creates a list that contains the
512// hash, opcode, and all fields of an instruction.
513//
514// For example, the function signature used to create an `AllocTensor`
515// instruction is:
516// Instruction AllocTensor(std::vector<Index> shape, DLDataType dtype, RegName dst)
517//
518// The serialized form will be:
519// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn`
520//
521// where hash is the hash of serialized instruction that is computed internally
522// by the `VMInstructionExecutable`. It is used for sanity check before decoding.
523// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)`
524// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register`
525// is the destination register, and the rest of it together indicates the shape
526// of the tensor to be allocated.
527VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
528 std::vector<Index> fields;
529 // Save the opcode.
530 VLOG(2) << "Serializing: " << instr << std::endl;
531 switch (instr.op) {
532 case Opcode::Move: {
533 // Number of fields = 2
534 fields.assign({instr.from, instr.dst});
535 break;
536 }
537 case Opcode::Ret: {
538 // Number of fields = 1
539 fields.push_back(instr.result);
540 break;
541 }
542 case Opcode::Fatal: {
543 // Number of fields = 0
544 break;
545 }
546 case Opcode::InvokePacked: {
547 // Number of fields = 3 + instr.arity
548 // Note that arity includes both input arguments and outputs. We will
549 // put all the `arity` number of fields in the end for serialization.
550 fields.assign({instr.packed_index, instr.arity, instr.output_size});
551 // Save the args.
552 fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity);
553 break;
554 }
555 case Opcode::AllocTensor: {
556 // Number of fields = 7 + instr.alloc_tensor.ndim
557 fields.push_back(instr.alloc_tensor.storage);
558 fields.push_back(instr.alloc_tensor.offset);
559 // Save `DLDataType` and the dst register.
560 const auto& dtype = instr.alloc_tensor.dtype;
561 fields.push_back(dtype.code);
562 fields.push_back(dtype.bits);
563 fields.push_back(dtype.lanes);
564
565 // The number of dimensions is not needed for constructing an
566 // `AllocTensor` instruction as it equals to the length of the `shape`
567 // vector. However, we save it to conveniently deserialize the instruction
568 // because we will know how many fields are needed by the `shape` argument.
569 fields.push_back(instr.alloc_tensor.ndim);
570 fields.push_back(instr.dst);
571
572 // Save the shape of the tensor.
573 // Note that this field is rotated to the end of the list.
574 fields.insert(fields.end(), instr.alloc_tensor.shape,
575 instr.alloc_tensor.shape + instr.alloc_tensor.ndim);
576 break;
577 }
578 case Opcode::AllocTensorReg: {
579 // Number of fields = 7
580 fields.push_back(instr.alloc_tensor_reg.storage);
581 fields.push_back(instr.alloc_tensor_reg.offset);
582 fields.push_back(instr.alloc_tensor_reg.shape_register);
583 // Save `DLDataType` and the dst register.
584 const auto& dtype = instr.alloc_tensor_reg.dtype;
585 fields.push_back(dtype.code);
586 fields.push_back(dtype.bits);
587 fields.push_back(dtype.lanes);
588 fields.push_back(instr.dst);
589 break;
590 }
591 case Opcode::AllocStorage: {
592 fields.push_back(instr.alloc_storage.allocation_size);
593 fields.push_back(instr.alloc_storage.alignment);
594 // Save `DLDataType` and the dst register.
595 const auto& dtype = instr.alloc_storage.dtype_hint;
596 fields.push_back(dtype.code);
597 fields.push_back(dtype.bits);
598 fields.push_back(dtype.lanes);
599 fields.push_back(instr.alloc_storage.device_index);
600 fields.push_back(instr.dst);
601 break;
602 }
603 case Opcode::AllocADT: {
604 // Number of fields = 3 + instr.num_fields
605 fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});
606
607 // Save the fields.
608 fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields);
609 break;
610 }
611 case Opcode::AllocClosure: {
612 // Number of fields = 3 + instr.num_freevar
613 fields.assign({instr.clo_index, instr.num_freevar, instr.dst});
614
615 // Save the free vars.
616 fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar);
617 break;
618 }
619 case Opcode::If: {
620 // Number of fields = 4
621 fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset,
622 instr.if_op.false_offset});
623 break;
624 }
625 case Opcode::Invoke: {
626 // Number of fields = 3 + instr.num_args
627 fields.assign({instr.func_index, instr.num_args, instr.dst});
628
629 // Save the args.
630 fields.insert(fields.end(), instr.invoke_args_registers,
631 instr.invoke_args_registers + instr.num_args);
632 break;
633 }
634 case Opcode::InvokeClosure: {
635 // Number of fields = 3 + instr.num_closure_args
636 fields.assign({instr.closure, instr.num_closure_args, instr.dst});
637
638 // Save the args.
639 fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args);
640 break;
641 }
642 case Opcode::LoadConst: {
643 // Number of fields = 2
644 fields.assign({instr.const_index, instr.dst});
645 break;
646 }
647 case Opcode::LoadConsti: {
648 // Number of fields = 2
649 fields.assign({instr.load_consti.val, instr.dst});
650 break;
651 }
652 case Opcode::GetField: {
653 // Number of fields = 3
654 fields.assign({instr.object, instr.field_index, instr.dst});
655 break;
656 }
657 case Opcode::GetTag: {
658 // Number of fields = 2
659 fields.assign({instr.get_tag.object, instr.dst});
660 break;
661 }
662 case Opcode::Goto: {
663 // Number of fields = 1
664 fields.push_back(instr.pc_offset);
665 break;
666 }
667 case Opcode::ShapeOf: {
668 // Number of fields = 2
669 fields.assign({instr.shape_of.tensor, instr.dst});
670 break;
671 }
672 case Opcode::ReshapeTensor: {
673 // Number of fields = 3
674 fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
675 break;
676 }
677 case Opcode::DeviceCopy: {
678 // Number of fields = 4
679 fields.assign({instr.device_copy.src, instr.device_copy.src_device_index,
680 instr.device_copy.dst_device_index, instr.dst});
681 break;
682 }
683 case Opcode::KillRegister: {
684 fields.assign({instr.dst});
685 break;
686 }
687 default:
688 LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
689 break;
690 }
691
692 return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
693}
694
695void Executable::SaveCodeSection(dmlc::Stream* strm) {
696 // Save the number of functions.
697 strm->Write(static_cast<uint64_t>(this->functions.size()));
698 for (const auto& func : this->functions) {
699 // Save the function info.
700 VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(),
701 func.params, func.param_device_indexes);
702 func_format.Save(strm);
703
704 // Serialize each instruction.
705 for (const auto& instr : func.instructions) {
706 const auto& serialized_instr = SerializeInstruction(instr);
707 serialized_instr.Save(strm);
708 }
709 }
710}
711
712void LoadHeader(dmlc::Stream* strm) {
713 // Check header.
714 uint64_t header;
715 STREAM_CHECK(strm->Read(&header), "header");
716 STREAM_CHECK(header == kTVMVMBytecodeMagic, "header");
717
718 // Check version.
719 std::string version;
720 STREAM_CHECK(strm->Read(&version), "version");
721 STREAM_CHECK(version == TVM_VERSION, "version");
722}
723
724runtime::Module Executable::GetLib() const {
725 ICHECK_LE(this->imports_.size(), 1)
726 << "The kernel library must be imported as the only module in an Executable";
727
728 if (this->imports().size() == 0) {
729 return Module(nullptr);
730 } else {
731 return this->imports_[0];
732 }
733}
734
735void Executable::SetLib(const runtime::Module& lib) {
736 ICHECK(lib.defined()) << "the provided library can not be null";
737
738 ICHECK_EQ(this->imports_.size(), 0)
739 << "A VMExecutable should never have more than one import inside an the executable, \n"
740 << "the first import should *always* be the library containing"
741 << "the platform specific kernel code";
742
743 this->Import(lib);
744}
745
746runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
747 auto exec = make_object<Executable>();
748
749 // Support null-initialization of lib, to enable initialization during
750 // deserialization before we have deserialized the imports.
751 if (lib.defined()) {
752 exec->SetLib(lib);
753 }
754
755 exec->code_ = code;
756 dmlc::MemoryStringStream strm(&exec->code_);
757
758 // Load header.
759 LoadHeader(&strm);
760
761 // Virtual devices section
762 exec->LoadVirtualDevicesSection(&strm);
763
764 // Global section.
765 exec->LoadGlobalSection(&strm);
766
767 // Constant section.
768 exec->LoadConstantSection(&strm);
769
770 // Primitive names that will be invoked by `InvokePacked` instructions.
771 exec->LoadPrimitiveOpNames(&strm);
772
773 // Code section.
774 exec->LoadCodeSection(&strm);
775
776 return runtime::Module(exec);
777}
778
779void Executable::LoadVirtualDevicesSection(dmlc::Stream* strm) {
780 STREAM_CHECK(strm->Read(&virtual_devices), "virtual_device");
781 STREAM_CHECK(strm->Read(&host_device_index), "virtual_device");
782 ICHECK(host_device_index >= 0 && host_device_index < static_cast<int>(virtual_devices.size()));
783}
784
785void Executable::LoadGlobalSection(dmlc::Stream* strm) {
786 std::vector<std::string> globals;
787 STREAM_CHECK(strm->Read(&globals), "global");
788 for (size_t i = 0; i < globals.size(); i++) {
789 this->global_map.insert({globals[i], i});
790 }
791}
792
793void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) {
794 std::vector<std::string> primitive_names;
795 STREAM_CHECK(strm->Read(&primitive_names), "primitive name");
796 for (size_t i = 0; i < primitive_names.size(); i++) {
797 this->primitive_map.insert({primitive_names[i], i});
798 }
799
800 std::map<uint64_t, std::map<std::string, std::string>> primitive_attrs;
801 STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs");
802 for (const auto& fn : primitive_attrs) {
803 std::vector<std::pair<String, ObjectRef>> attrs;
804 for (const auto& elem : fn.second) {
805 attrs.push_back({elem.first, String(elem.second)});
806 }
807 this->op_attrs[fn.first] = Map<String, ObjectRef>(attrs.begin(), attrs.end());
808 }
809}
810
811// Extract the `cnt` number of fields started at `start` from the list
812// `instr_fields`.
813inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields, Index start,
814 Index cnt) {
815 ICHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size());
816 std::vector<Index> ret;
817 for (auto i = start; i < start + cnt; i++) {
818 ret.push_back(instr_fields[i]);
819 }
820 return ret;
821}
822
823Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
824 Opcode opcode = static_cast<Opcode>(instr.opcode);
825 switch (opcode) {
826 case Opcode::Move: {
827 // Number of fields = 2
828 DCHECK_EQ(instr.fields.size(), 2U);
829 return Instruction::Move(instr.fields[0], instr.fields[1]);
830 }
831 case Opcode::Ret: {
832 // Number of fields = 1
833 DCHECK_EQ(instr.fields.size(), 1U);
834 return Instruction::Ret(instr.fields[0]);
835 }
836 case Opcode::Fatal: {
837 // Number of fields = 0
838 DCHECK(instr.fields.empty());
839 return Instruction::Fatal();
840 }
841 case Opcode::InvokePacked: {
842 // Number of fields = 3 + instr.arity
843 DCHECK_GE(instr.fields.size(), 3U);
844 DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
845
846 Index packed_index = instr.fields[0];
847 Index arity = instr.fields[1];
848 Index output_size = instr.fields[2];
849 std::vector<RegName> args = ExtractFields(instr.fields, 3, arity);
850 return Instruction::InvokePacked(packed_index, arity, output_size, args);
851 }
852 case Opcode::AllocTensor: {
853 // Number of fields = 7 + instr.alloc_tensor.ndim
854 DCHECK_GE(instr.fields.size(), 7U);
855 DCHECK_EQ(instr.fields.size(), 7U + static_cast<size_t>(instr.fields[5]));
856
857 RegName storage_reg = instr.fields[0];
858 RegName offset = instr.fields[1];
859
860 DLDataType dtype;
861 dtype.code = instr.fields[2];
862 dtype.bits = instr.fields[3];
863 dtype.lanes = instr.fields[4];
864
865 Index ndim = instr.fields[5];
866 RegName dst = instr.fields[6];
867
868 std::vector<Index> shape = ExtractFields(instr.fields, 7, ndim);
869
870 return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst);
871 }
872 case Opcode::AllocTensorReg: {
873 // Number of fields = 7
874 DCHECK_EQ(instr.fields.size(), 7U);
875
876 RegName storage_reg = instr.fields[0];
877 RegName offset = instr.fields[1];
878 Index shape_register = instr.fields[2];
879
880 DLDataType dtype;
881 dtype.code = instr.fields[3];
882 dtype.bits = instr.fields[4];
883 dtype.lanes = instr.fields[5];
884
885 RegName dst = instr.fields[6];
886
887 return Instruction::AllocTensorReg(storage_reg, offset, shape_register, dtype, dst);
888 }
889 case Opcode::AllocADT: {
890 // Number of fields = 3 + instr.num_fields
891 DCHECK_GE(instr.fields.size(), 3U);
892 DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
893
894 Index constructor_tag = instr.fields[0];
895 Index num_fields = instr.fields[1];
896 RegName dst = instr.fields[2];
897 std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);
898
899 return Instruction::AllocADT(constructor_tag, num_fields, fields, dst);
900 }
901 case Opcode::AllocClosure: {
902 // Number of fields = 3 + instr.num_freevar
903 DCHECK_GE(instr.fields.size(), 3U);
904 DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
905
906 Index clo_index = instr.fields[0];
907 Index num_freevar = instr.fields[1];
908 RegName dst = instr.fields[2];
909 std::vector<Index> free_vars = ExtractFields(instr.fields, 3, num_freevar);
910
911 return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
912 }
913 case Opcode::AllocStorage: {
914 // Number of fields = 7
915 DCHECK_GE(instr.fields.size(), 7U);
916 Index allocation_size = instr.fields[0];
917 Index alignment = instr.fields[1];
918
919 DLDataType dtype;
920 dtype.code = instr.fields[2];
921 dtype.bits = instr.fields[3];
922 dtype.lanes = instr.fields[4];
923
924 Index device_type = instr.fields[5];
925 RegName dst = instr.fields[6];
926
927 return Instruction::AllocStorage(allocation_size, alignment, dtype, device_type, dst);
928 }
929 case Opcode::If: {
930 // Number of fields = 4
931 DCHECK_EQ(instr.fields.size(), 4U);
932 Index test = instr.fields[0];
933 Index target = instr.fields[1];
934 Index true_offset = instr.fields[2];
935 Index false_offset = instr.fields[3];
936
937 return Instruction::If(test, target, true_offset, false_offset);
938 }
939 case Opcode::Invoke: {
940 // Number of fields = 3 + instr.num_args
941 DCHECK_GE(instr.fields.size(), 3U);
942 DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
943
944 Index func_index = instr.fields[0];
945 Index num_args = instr.fields[1];
946 RegName dst = instr.fields[2];
947 std::vector<Index> args = ExtractFields(instr.fields, 3, num_args);
948
949 return Instruction::Invoke(func_index, args, dst);
950 }
951 case Opcode::InvokeClosure: {
952 // Number of fields = 3 + instr.num_closure_args
953 DCHECK_GE(instr.fields.size(), 3U);
954 DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
955
956 Index closure = instr.fields[0];
957 Index num_closure_args = instr.fields[1];
958 RegName dst = instr.fields[2];
959 std::vector<Index> args = ExtractFields(instr.fields, 3, num_closure_args);
960
961 return Instruction::InvokeClosure(closure, args, dst);
962 }
963 case Opcode::LoadConst: {
964 // Number of fields = 2
965 DCHECK_EQ(instr.fields.size(), 2U);
966 return Instruction::LoadConst(instr.fields[0], instr.fields[1]);
967 }
968 case Opcode::LoadConsti: {
969 // Number of fields = 2
970 DCHECK_EQ(instr.fields.size(), 2U);
971 return Instruction::LoadConsti(instr.fields[0], instr.fields[1]);
972 }
973 case Opcode::GetField: {
974 // Number of fields = 3
975 DCHECK_EQ(instr.fields.size(), 3U);
976 return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]);
977 }
978 case Opcode::GetTag: {
979 // Number of fields = 2
980 DCHECK_EQ(instr.fields.size(), 2U);
981 return Instruction::GetTag(instr.fields[0], instr.fields[1]);
982 }
983 case Opcode::Goto: {
984 // Number of fields = 1
985 DCHECK_EQ(instr.fields.size(), 1U);
986 return Instruction::Goto(instr.fields[0]);
987 }
988 case Opcode::ShapeOf: {
989 // Number of fields = 2
990 DCHECK_EQ(instr.fields.size(), 2U);
991 return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
992 }
993 case Opcode::ReshapeTensor: {
994 // Number of fields = 3
995 DCHECK_EQ(instr.fields.size(), 3U);
996 return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]);
997 }
998 case Opcode::DeviceCopy: {
999 // Number of fields = 4
1000 DCHECK_EQ(instr.fields.size(), 4U);
1001 return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2],
1002 instr.fields[3]);
1003 }
1004 case Opcode::KillRegister: {
1005 DCHECK_EQ(instr.fields.size(), 1U);
1006 return Instruction::KillRegister(instr.fields[0]);
1007 }
1008 default:
1009 LOG(FATAL) << "Invalid opcode" << instr.opcode;
1010 }
1011}
1012
1013void Executable::LoadCodeSection(dmlc::Stream* strm) {
1014 // Load the number of functions.
1015 uint64_t sz;
1016 STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code");
1017
1018 size_t num_funcs = static_cast<size_t>(sz);
1019 this->functions.resize(num_funcs);
1020 for (size_t i = 0; i < num_funcs; i++) {
1021 // Load the function info.
1022 VMFunctionSerializer loaded_func;
1023 STREAM_CHECK(loaded_func.Load(strm), "code/function");
1024
1025 // Load the instructions.
1026 std::vector<Instruction> instructions;
1027 for (size_t j = 0; j < loaded_func.num_instructions; j++) {
1028 VMInstructionSerializer instr;
1029 std::vector<Index> instr_fields;
1030 STREAM_CHECK(instr.Load(strm), "code/instruction");
1031 instructions.push_back(DeserializeInstruction(instr));
1032 }
1033
1034 // Create the VM function.
1035 VMFunction vm_func =
1036 VMFunction(loaded_func.name, loaded_func.params, instructions,
1037 loaded_func.register_file_size, loaded_func.param_device_indexes);
1038 auto it = this->global_map.find(loaded_func.name);
1039 ICHECK(it != this->global_map.end());
1040 ICHECK_LE(it->second, this->global_map.size());
1041 this->functions[it->second] = vm_func;
1042 }
1043}
1044
1045void Executable::SaveToBinary(dmlc::Stream* stream) {
1046 auto code_bytes = this->Save();
1047 std::string code(code_bytes.data, code_bytes.size);
1048 stream->Write(code);
1049
1050 ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization";
1051}
1052
1053Module ExecutableLoadBinary(void* strm) {
1054 dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
1055 std::string code;
1056 stream->Read(&code);
1057 auto exec = Executable::Load(code, Module());
1058 return exec;
1059}
1060
1061void Executable::SaveToFile(const std::string& path, const std::string& format) {
1062 tvm::runtime::SimpleBinaryFileStream stream(path, "wb");
1063 SaveToBinary(&stream);
1064}
1065
// Expose ExecutableLoadBinary under the global name
// "runtime.module.loadbinary_VMExecutable".
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);
1067
1068// Load module from module.
1069Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
1070 tvm::runtime::SimpleBinaryFileStream stream(file_name, "rb");
1071 auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(&stream));
1072 return exec;
1073}
1074
// Expose ExecutableLoadFile under the global name
// "runtime.module.loadfile_VMExecutable".
TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);
1076
1077TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
1078 runtime::Module mod = args[0];
1079 const auto* exec = dynamic_cast<Executable*>(mod.operator->());
1080 ICHECK(exec);
1081 *rv = static_cast<int>(exec->global_map.size());
1082});
1083
1084TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetValue* rv) {
1085 runtime::Module mod = args[0];
1086 const auto* exec = dynamic_cast<Executable*>(mod.operator->());
1087 ICHECK(exec);
1088 int idx = args[1];
1089 std::vector<std::pair<std::string, Index>> globals(exec->global_map.begin(),
1090 exec->global_map.end());
1091 auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
1092 return a.second < b.second;
1093 };
1094 std::sort(globals.begin(), globals.end(), comp);
1095 ICHECK_LT(idx, globals.size());
1096 *rv = globals[idx].first;
1097});
1098
1099TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives").set_body([](TVMArgs args, TVMRetValue* rv) {
1100 runtime::Module mod = args[0];
1101 const auto* exec = dynamic_cast<Executable*>(mod.operator->());
1102 ICHECK(exec);
1103 *rv = static_cast<int>(exec->primitive_map.size());
1104});
1105
1106TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields").set_body([](TVMArgs args, TVMRetValue* rv) {
1107 runtime::Module mod = args[0];
1108 const auto* exec = dynamic_cast<Executable*>(mod.operator->());
1109 ICHECK(exec);
1110 int idx = args[1];
1111 ICHECK_GE(idx, 0);
1112 ICHECK_LT(idx, exec->primitive_map.size());
1113
1114 for (const auto& it : exec->primitive_map) {
1115 if (idx == static_cast<int>(it.second)) {
1116 *rv = it.first;
1117 break;
1118 }
1119 }
1120});
1121
1122TVM_REGISTER_GLOBAL("runtime.Load_Executable")
1123 .set_body_typed([](std::string code, runtime::Module lib) {
1124 return Executable::Load(code, lib);
1125 });
1126
1127} // namespace vm
1128} // namespace runtime
1129} // namespace tvm
1130