1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/runtime/vm/serialize_utils.h
22 * \brief Definitions of helpers for serializing and deserializing a Relay VM.
23 */
24#ifndef TVM_RUNTIME_VM_SERIALIZE_UTILS_H_
25#define TVM_RUNTIME_VM_SERIALIZE_UTILS_H_
26
27#include <dmlc/memory_io.h>
28#include <tvm/runtime/vm/executable.h>
29
30#include <functional>
31#include <string>
32#include <vector>
33
34#include "../../support/utils.h"
35
36namespace tvm {
37namespace runtime {
38namespace vm {
39
40/*! \brief The magic number for the serialized VM bytecode file */
41constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D;
42
43template <typename T>
44static inline uint64_t VectorHash(uint64_t key, const std::vector<T>& values) {
45 for (const auto& it : values) {
46 key = support::HashCombine(key, it);
47 }
48 return key;
49}
50
51// A struct to hold the funciton info in the code section.
52struct VMFunctionSerializer {
53 /*! \brief The name of the VMFunction. */
54 std::string name;
55 /*! \brief The number of registers used by the VMFunction. */
56 Index register_file_size;
57 /*! \brief The number of instructions in the VMFunction. */
58 size_t num_instructions;
59 /*! \brief The parameters of the VMFunction. */
60 std::vector<std::string> params;
61 /*! \brief The index for the devices holding each parameter of the VMFunction. */
62 std::vector<Index> param_device_indexes;
63
64 VMFunctionSerializer() = default;
65
66 VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions,
67 const std::vector<std::string>& params,
68 const std::vector<Index>& param_device_indexes)
69 : name(name),
70 register_file_size(register_file_size),
71 num_instructions(num_instructions),
72 params(params),
73 param_device_indexes(param_device_indexes) {}
74
75 /*!
76 * \brief Load the serialized function header.
77 * \param strm The stream used to load data.
78 * \return True if successful. Otherwise, false.
79 */
80 bool Load(dmlc::Stream* strm) {
81 std::vector<std::string> func_info;
82 if (!strm->Read(&func_info)) return false;
83 ICHECK_EQ(func_info.size(), 3U) << "Failed to decode the vm function."
84 << "\n";
85 name = func_info[0];
86 register_file_size = std::stoll(func_info[1]);
87 // Get the number of instructions.
88 num_instructions = static_cast<size_t>(std::stoll(func_info[2]));
89 if (!strm->Read(&params)) return false;
90 if (!strm->Read(&param_device_indexes)) return false;
91 return true;
92 }
93
94 /*!
95 * \brief Save the VM function header into the serialized form.
96 * \param strm The stream used to save data.
97 */
98 void Save(dmlc::Stream* strm) const {
99 std::vector<std::string> func_info;
100 func_info.push_back(name);
101 func_info.push_back(std::to_string(register_file_size));
102 func_info.push_back(std::to_string(num_instructions));
103 strm->Write(func_info);
104 strm->Write(params);
105 strm->Write(param_device_indexes);
106 }
107};
108
109struct VMInstructionSerializer {
110 /*! \brief The opcode of the instruction. */
111 Index opcode;
112 /*! \brief The fields of the instruction. */
113 std::vector<Index> fields;
114
115 VMInstructionSerializer() = default;
116
117 VMInstructionSerializer(Index opcode, const std::vector<Index>& fields)
118 : opcode(opcode), fields(fields) {}
119
120 /*!
121 * \brief Compute the hash of the serialized instruction.
122 * \return The hash that combines the opcode and all fields of the VM
123 * instruction.
124 */
125 Index Hash() const {
126 uint64_t key = static_cast<uint64_t>(opcode);
127 key = VectorHash(key, fields);
128 return key;
129 }
130
131 /*!
132 * \brief Load the serialized instruction.
133 * \param strm The stream used to load data.
134 * \return True if successful. Otherwise, false.
135 */
136 bool Load(dmlc::Stream* strm) {
137 std::vector<Index> instr;
138 if (!strm->Read(&instr)) return false;
139 ICHECK_GE(instr.size(), 2U);
140 Index loaded_hash = instr[0];
141 opcode = instr[1];
142
143 for (size_t i = 2; i < instr.size(); i++) {
144 fields.push_back(instr[i]);
145 }
146
147 Index hash = Hash();
148 ICHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n";
149 return true;
150 }
151
152 /*!
153 * \brief Save the instruction into the serialized form.
154 * \param strm The stream used to save data.
155 */
156 void Save(dmlc::Stream* strm) const {
157 Index hash = Hash();
158 std::vector<Index> serialized({hash, opcode});
159 serialized.insert(serialized.end(), fields.begin(), fields.end());
160 strm->Write(serialized);
161 }
162};
163
164} // namespace vm
165} // namespace runtime
166} // namespace tvm
167
168#endif // TVM_RUNTIME_VM_SERIALIZE_UTILS_H_
169