1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/vm/compiler.h
 * \brief A compiler from a Relay IRModule to VM bytecode.
23 */
24
25#ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
26#define TVM_RELAY_BACKEND_VM_COMPILER_H_
27
28#include <tvm/relay/error.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/interpreter.h>
31#include <tvm/relay/transform.h>
32#include <tvm/runtime/logging.h>
33#include <tvm/runtime/vm/vm.h>
34#include <tvm/tir/function.h>
35
36#include <iostream>
37#include <memory>
38#include <string>
39#include <unordered_map>
40#include <unordered_set>
41#include <utility>
42#include <vector>
43
44#include "../../../runtime/vm/naive_allocator.h"
45#include "../../../runtime/vm/profiler/vm.h"
46#include "../../transforms/pass_utils.h"
47#include "../te_compiler.h"
48#include "../te_compiler_cache.h"
49
50namespace tvm {
51namespace relay {
52namespace vm {
53
54using namespace tvm::runtime;
55using namespace tvm::runtime::vm;
56using namespace relay::transform;
57
58template <typename T, typename U>
59using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
60using TagMap = NodeMap<tvm::relay::Constructor, Index>;
61using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
62using GlobalMap = NodeMap<GlobalVar, Index>;
63using ConstMap = NodeMap<Constant, Index>;
64using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
65
/*!
 * \brief Shared state accumulated while compiling a module to VM bytecode.
 *
 * A plain aggregate; all members are default-constructed.
 */
struct VMCompilerContext {
  // The module context for the compilation
  IRModule module;
  // Error reporter
  ErrorReporter err_reporter;
  // Map from a unique integer to the ADT constructor it identifies
  // (reverse direction of tag_map).
  TagNameMap tag_index_map;
  // Map from an ADT constructor to its unique integer tag.
  TagMap tag_map;
  // Map from global var to a unique integer index.
  GlobalMap global_map;
  // List of constants.
  std::vector<NDArray> constants;
  // Device indexes for constants.
  // NOTE(review): presumably parallel to `constants` by position; confirm in compiler.cc.
  std::vector<Index> const_device_indexes;
  // Map from names of primitive functions already allocated to their primitive function index.
  std::unordered_map<std::string, Index> primitive_map;
  // The virtual devices corresponding to each device index.
  // NOTE(review): the trailing underscore is inconsistent with the other (public) members,
  // but renaming would break existing users.
  std::vector<VirtualDevice> virtual_devices_;
};
86
87class VMCompiler : public runtime::ModuleNode {
88 public:
89 VMCompiler() = default;
90 virtual ~VMCompiler() = default;
91
92 virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
93
94 const char* type_key() const final { return "VMCompiler"; }
95
96 /*!
97 * \brief Set the parameters
98 *
99 * \param name name of parameter
100 * \param data_in input DLTensor
101 */
102 void SetParam(const std::string& name, runtime::NDArray data_in);
103
104 /*!
105 * \brief Lower the functions in a Module.
106 *
107 * ----------------------------------------------------------------------------------
108 * | This is the main entry point for the VM compilation flow. |
109 * | - Preceded by \p SetParam for the global params. |
110 * | - Followed by \p Codegen() to finalize the executable. |
111 * | - Then the result runtime::Module can be constructed by GetExecutable. |
112 * ----------------------------------------------------------------------------------
113 *
114 * \param mod Relay Module
115 * \param raw_targets List of available targets for running kernels. Any host target should
116 * be conveyed by the 'host' target field.
117 */
118 void Lower(IRModule mod, const Array<Target>& raw_targets);
119
120 /*
121 * \brief Perform a series of optimizations on the input IR module. Can be used instead
122 * of Lower if wish to stop and observe optimized IRModule. Otherwise not needed on
123 * regular compilation flow.
124 *
125 * \param mod The input IRModule.
126 * \param raw_targets List of available target for running kernels.
127 *
128 * \return The optimized IRModule.
129 */
130 IRModule OptimizeModule(IRModule mod, const Array<Target>& raw_targets);
131
132 /*! \brief Generate the machine code for lowered functions. */
133 void Codegen();
134
135 /*! \brief Returns the runtime::Module containing the compiled VM code. */
136 runtime::Module GetExecutable() const;
137
138 protected:
139 /*! \brief Builds the executor and compilation config to match \p raw_targets. */
140 void Setup(const Array<Target>& raw_targets);
141
142 /*! \brief Internal implementation of \p Lower. */
143 void LowerImpl(IRModule mod);
144
145 /*! \brief Internal implementation of \p OptimizeModule. */
146 IRModule OptimizeModuleImpl(IRModule mod);
147
148 /*! \brief Returns the passes which layout memory. */
149 transform::Sequential MemoryOpt(const CompilationConfig& config);
150
151 /*! \brief Returns the passes which fuse then lower Relay primitive operators. */
152 transform::Sequential FuseAndLowerOperators(const CompilationConfig& config);
153
154 /*!
155 * \brief Populate the global function names in a map where the value is used
156 * as the index by the VMFunctions. Returns the number of functions.
157 */
158 size_t PopulateGlobalMap();
159
160 protected:
161 /*! \brief Targets and scopes needed for compilation. */
162 CompilationConfig config_;
163 /*! \brief Global shared meta data */
164 VMCompilerContext context_;
165 /*! \brief Compiled executable. */
166 ObjectPtr<Executable> exec_;
167 /*! \brief parameters */
168 std::unordered_map<std::string, runtime::NDArray> params_;
169};
170
171} // namespace vm
172} // namespace relay
173} // namespace tvm
174
175#endif // TVM_RELAY_BACKEND_VM_COMPILER_H_
176