1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/runtime/vm/serialize_utils.h |
22 | * \brief Definitions of helpers for serializing and deserializing a Relay VM. |
23 | */ |
24 | #ifndef TVM_RUNTIME_VM_SERIALIZE_UTILS_H_ |
25 | #define TVM_RUNTIME_VM_SERIALIZE_UTILS_H_ |
26 | |
27 | #include <dmlc/memory_io.h> |
28 | #include <tvm/runtime/vm/executable.h> |
29 | |
30 | #include <functional> |
31 | #include <string> |
32 | #include <vector> |
33 | |
34 | #include "../../support/utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | namespace vm { |
39 | |
40 | /*! \brief The magic number for the serialized VM bytecode file */ |
41 | constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; |
42 | |
43 | template <typename T> |
44 | static inline uint64_t VectorHash(uint64_t key, const std::vector<T>& values) { |
45 | for (const auto& it : values) { |
46 | key = support::HashCombine(key, it); |
47 | } |
48 | return key; |
49 | } |
50 | |
51 | // A struct to hold the funciton info in the code section. |
52 | struct VMFunctionSerializer { |
53 | /*! \brief The name of the VMFunction. */ |
54 | std::string name; |
55 | /*! \brief The number of registers used by the VMFunction. */ |
56 | Index register_file_size; |
57 | /*! \brief The number of instructions in the VMFunction. */ |
58 | size_t num_instructions; |
59 | /*! \brief The parameters of the VMFunction. */ |
60 | std::vector<std::string> params; |
61 | /*! \brief The index for the devices holding each parameter of the VMFunction. */ |
62 | std::vector<Index> param_device_indexes; |
63 | |
64 | VMFunctionSerializer() = default; |
65 | |
66 | VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, |
67 | const std::vector<std::string>& params, |
68 | const std::vector<Index>& param_device_indexes) |
69 | : name(name), |
70 | register_file_size(register_file_size), |
71 | num_instructions(num_instructions), |
72 | params(params), |
73 | param_device_indexes(param_device_indexes) {} |
74 | |
75 | /*! |
76 | * \brief Load the serialized function header. |
77 | * \param strm The stream used to load data. |
78 | * \return True if successful. Otherwise, false. |
79 | */ |
80 | bool Load(dmlc::Stream* strm) { |
81 | std::vector<std::string> func_info; |
82 | if (!strm->Read(&func_info)) return false; |
83 | ICHECK_EQ(func_info.size(), 3U) << "Failed to decode the vm function." |
84 | << "\n" ; |
85 | name = func_info[0]; |
86 | register_file_size = std::stoll(func_info[1]); |
87 | // Get the number of instructions. |
88 | num_instructions = static_cast<size_t>(std::stoll(func_info[2])); |
89 | if (!strm->Read(¶ms)) return false; |
90 | if (!strm->Read(¶m_device_indexes)) return false; |
91 | return true; |
92 | } |
93 | |
94 | /*! |
95 | * \brief Save the VM function header into the serialized form. |
96 | * \param strm The stream used to save data. |
97 | */ |
98 | void Save(dmlc::Stream* strm) const { |
99 | std::vector<std::string> func_info; |
100 | func_info.push_back(name); |
101 | func_info.push_back(std::to_string(register_file_size)); |
102 | func_info.push_back(std::to_string(num_instructions)); |
103 | strm->Write(func_info); |
104 | strm->Write(params); |
105 | strm->Write(param_device_indexes); |
106 | } |
107 | }; |
108 | |
109 | struct VMInstructionSerializer { |
110 | /*! \brief The opcode of the instruction. */ |
111 | Index opcode; |
112 | /*! \brief The fields of the instruction. */ |
113 | std::vector<Index> fields; |
114 | |
115 | VMInstructionSerializer() = default; |
116 | |
117 | VMInstructionSerializer(Index opcode, const std::vector<Index>& fields) |
118 | : opcode(opcode), fields(fields) {} |
119 | |
120 | /*! |
121 | * \brief Compute the hash of the serialized instruction. |
122 | * \return The hash that combines the opcode and all fields of the VM |
123 | * instruction. |
124 | */ |
125 | Index Hash() const { |
126 | uint64_t key = static_cast<uint64_t>(opcode); |
127 | key = VectorHash(key, fields); |
128 | return key; |
129 | } |
130 | |
131 | /*! |
132 | * \brief Load the serialized instruction. |
133 | * \param strm The stream used to load data. |
134 | * \return True if successful. Otherwise, false. |
135 | */ |
136 | bool Load(dmlc::Stream* strm) { |
137 | std::vector<Index> instr; |
138 | if (!strm->Read(&instr)) return false; |
139 | ICHECK_GE(instr.size(), 2U); |
140 | Index loaded_hash = instr[0]; |
141 | opcode = instr[1]; |
142 | |
143 | for (size_t i = 2; i < instr.size(); i++) { |
144 | fields.push_back(instr[i]); |
145 | } |
146 | |
147 | Index hash = Hash(); |
148 | ICHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n" ; |
149 | return true; |
150 | } |
151 | |
152 | /*! |
153 | * \brief Save the instruction into the serialized form. |
154 | * \param strm The stream used to save data. |
155 | */ |
156 | void Save(dmlc::Stream* strm) const { |
157 | Index hash = Hash(); |
158 | std::vector<Index> serialized({hash, opcode}); |
159 | serialized.insert(serialized.end(), fields.begin(), fields.end()); |
160 | strm->Write(serialized); |
161 | } |
162 | }; |
163 | |
164 | } // namespace vm |
165 | } // namespace runtime |
166 | } // namespace tvm |
167 | |
168 | #endif // TVM_RUNTIME_VM_SERIALIZE_UTILS_H_ |
169 | |