1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/vm/compiler.h |
22 | * \brief A compiler from relay::Module to the VM byte code. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_ |
26 | #define TVM_RELAY_BACKEND_VM_COMPILER_H_ |
27 | |
28 | #include <tvm/relay/error.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/interpreter.h> |
31 | #include <tvm/relay/transform.h> |
32 | #include <tvm/runtime/logging.h> |
33 | #include <tvm/runtime/vm/vm.h> |
34 | #include <tvm/tir/function.h> |
35 | |
36 | #include <iostream> |
37 | #include <memory> |
38 | #include <string> |
39 | #include <unordered_map> |
40 | #include <unordered_set> |
41 | #include <utility> |
42 | #include <vector> |
43 | |
44 | #include "../../../runtime/vm/naive_allocator.h" |
45 | #include "../../../runtime/vm/profiler/vm.h" |
46 | #include "../../transforms/pass_utils.h" |
47 | #include "../te_compiler.h" |
48 | #include "../te_compiler_cache.h" |
49 | |
50 | namespace tvm { |
51 | namespace relay { |
52 | namespace vm { |
53 | |
54 | using namespace tvm::runtime; |
55 | using namespace tvm::runtime::vm; |
56 | using namespace relay::transform; |
57 | |
58 | template <typename T, typename U> |
59 | using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>; |
60 | using TagMap = NodeMap<tvm::relay::Constructor, Index>; |
61 | using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>; |
62 | using GlobalMap = NodeMap<GlobalVar, Index>; |
63 | using ConstMap = NodeMap<Constant, Index>; |
64 | using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>; |
65 | |
/*!
 * \brief State accumulated during a VM compilation.
 *
 * A plain aggregate of public fields; VMCompiler holds one instance as its context_ member.
 */
struct VMCompilerContext {
  /*! \brief The module context for the compilation. */
  IRModule module;
  /*! \brief Error reporter. */
  ErrorReporter err_reporter;
  /*! \brief Map from a unique integer to ADT constructor tag. */
  TagNameMap tag_index_map;
  /*! \brief Map from ADT constructor tag to a unique integer. */
  TagMap tag_map;
  /*! \brief Map from global var to a unique integer. */
  GlobalMap global_map;
  /*! \brief List of constants. */
  std::vector<NDArray> constants;
  /*!
   * \brief Device indexes for constants.
   * NOTE(review): presumably parallel to \p constants (one entry per constant) -- confirm.
   */
  std::vector<Index> const_device_indexes;
  /*!
   * \brief Map from names of primitive functions already allocated to their primitive
   *  function index.
   */
  std::unordered_map<std::string, Index> primitive_map;
  /*!
   * \brief The virtual devices corresponding to each device index.
   * NOTE(review): the only field here with a trailing underscore; the name is kept because
   * it is part of the struct's public interface.
   */
  std::vector<VirtualDevice> virtual_devices_;
};
86 | |
87 | class VMCompiler : public runtime::ModuleNode { |
88 | public: |
89 | VMCompiler() = default; |
90 | virtual ~VMCompiler() = default; |
91 | |
92 | virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self); |
93 | |
94 | const char* type_key() const final { return "VMCompiler" ; } |
95 | |
96 | /*! |
97 | * \brief Set the parameters |
98 | * |
99 | * \param name name of parameter |
100 | * \param data_in input DLTensor |
101 | */ |
102 | void SetParam(const std::string& name, runtime::NDArray data_in); |
103 | |
104 | /*! |
105 | * \brief Lower the functions in a Module. |
106 | * |
107 | * ---------------------------------------------------------------------------------- |
108 | * | This is the main entry point for the VM compilation flow. | |
109 | * | - Preceded by \p SetParam for the global params. | |
110 | * | - Followed by \p Codegen() to finalize the executable. | |
111 | * | - Then the result runtime::Module can be constructed by GetExecutable. | |
112 | * ---------------------------------------------------------------------------------- |
113 | * |
114 | * \param mod Relay Module |
115 | * \param raw_targets List of available targets for running kernels. Any host target should |
116 | * be conveyed by the 'host' target field. |
117 | */ |
118 | void Lower(IRModule mod, const Array<Target>& raw_targets); |
119 | |
120 | /* |
121 | * \brief Perform a series of optimizations on the input IR module. Can be used instead |
122 | * of Lower if wish to stop and observe optimized IRModule. Otherwise not needed on |
123 | * regular compilation flow. |
124 | * |
125 | * \param mod The input IRModule. |
126 | * \param raw_targets List of available target for running kernels. |
127 | * |
128 | * \return The optimized IRModule. |
129 | */ |
130 | IRModule OptimizeModule(IRModule mod, const Array<Target>& raw_targets); |
131 | |
132 | /*! \brief Generate the machine code for lowered functions. */ |
133 | void Codegen(); |
134 | |
135 | /*! \brief Returns the runtime::Module containing the compiled VM code. */ |
136 | runtime::Module GetExecutable() const; |
137 | |
138 | protected: |
139 | /*! \brief Builds the executor and compilation config to match \p raw_targets. */ |
140 | void Setup(const Array<Target>& raw_targets); |
141 | |
142 | /*! \brief Internal implementation of \p Lower. */ |
143 | void LowerImpl(IRModule mod); |
144 | |
145 | /*! \brief Internal implementation of \p OptimizeModule. */ |
146 | IRModule OptimizeModuleImpl(IRModule mod); |
147 | |
148 | /*! \brief Returns the passes which layout memory. */ |
149 | transform::Sequential MemoryOpt(const CompilationConfig& config); |
150 | |
151 | /*! \brief Returns the passes which fuse then lower Relay primitive operators. */ |
152 | transform::Sequential FuseAndLowerOperators(const CompilationConfig& config); |
153 | |
154 | /*! |
155 | * \brief Populate the global function names in a map where the value is used |
156 | * as the index by the VMFunctions. Returns the number of functions. |
157 | */ |
158 | size_t PopulateGlobalMap(); |
159 | |
160 | protected: |
161 | /*! \brief Targets and scopes needed for compilation. */ |
162 | CompilationConfig config_; |
163 | /*! \brief Global shared meta data */ |
164 | VMCompilerContext context_; |
165 | /*! \brief Compiled executable. */ |
166 | ObjectPtr<Executable> exec_; |
167 | /*! \brief parameters */ |
168 | std::unordered_map<std::string, runtime::NDArray> params_; |
169 | }; |
170 | |
171 | } // namespace vm |
172 | } // namespace relay |
173 | } // namespace tvm |
174 | |
175 | #endif // TVM_RELAY_BACKEND_VM_COMPILER_H_ |
176 | |