1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/backend/te_compiler.h
22 * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
23 *
24 *
25 * This represents the new design of the Relay compilation flow and will replace the interface
26 * contained in compile_engine.h as we migrate towards a standard pass based lowering of
27 * Relay functions.
28 *
 * This file provides an internal API which lowers Relay programs to components which
 * can be combined with TVM-produced kernels to compile an entire program.
31 *
32 * The result of lowering contains a combination of `runtime::Module`s produced by external
33 * compilers and a set of lowered PrimFns which can be code generated for targets.
34 */
35#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
36#define TVM_RELAY_BACKEND_TE_COMPILER_H_
37
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/module.h>
#include <tvm/topi/elemwise.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

#include "../transforms/infer_layout_utils.h"
#include "../transforms/pass_utils.h"
#include "./te_compiler_cache.h"
#include "./utils.h"
56
57namespace tvm {
58namespace relay {
59namespace tec {
60
/*!
 * \brief Callback type applied to a function. \p LowerTE accepts one as its \p process_fn
 * argument so that one-level-up code generators can process each function that is lowered.
 */
using ProcessFn = std::function<void(BaseFunc)>;
62
63/*!
64 * \brief A compiler which lowers primitive Relay functions to tensor expressions
65 * and schedules them into TIR functions.
66 */
67class TECompilerNode : public Object {
68 public:
69 /*! \brief destructor */
70 virtual ~TECompilerNode() {}
71 /*!
72 * \brief Get lowered result.
73 * \param key The key to the cached function.
74 * \return The result.
75 */
76 virtual CachedFunc Lower(const CCacheKey& key) = 0;
77
78 /*!
79 * \brief Get lowered result.
80 * \param key The key to the cached function.
81 * \return The result.
82 */
83 virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0;
84
85 /* Return all functions which have been lowered by the compiler in an IRModule, annotated with
86 * their target. */
87 virtual IRModule GetLoweredFunctions() = 0;
88
89 /*!
90 * \brief Just in time compile to get a PackedFunc.
91 * \param key The key to the cached function.
92 * \return The result.
93 */
94 virtual PackedFunc JIT(const CCacheKey& key) = 0;
95 /*!
96 * \brief Lower the shape function.
97 * \param key The key to the cached function.
98 * \return The result.
99 */
100 virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
101 /*!
102 * \brief Lower the external function using external codegen tools.
103 * \return The runtime modules for each needed external codegen tool.
104 */
105 virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;
106
107 /*!
108 * \brief Update \p module to remove functions marked with the "Compiler" attribute and replace
109 * them with their 'external' representation using the "ExternalSymbol" attribute.
110 *
111 * TODO(mbs): This is a stepping stone while we migrate to a more official representation
112 * of 'external functions' in the IRModule and allow lowering to incrementally updatethe
113 * module stead of forcing everything via the cache.
114 *
115 */
116 virtual void AddExterns(IRModule module) = 0;
117
118 /*!
119 * \brief Get C Device API context mapping
120 * \return Map of GlobalVar to associated C Device API context name (either Target or kCompiler
121 * annotated)
122 */
123 virtual Map<GlobalVar, String> GetDeviceContexts() = 0;
124 virtual void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) = 0;
125
126 virtual Map<String, Integer> GetOpWeights() const = 0;
127
128 /*! \brief clear the cache. */
129 virtual void Clear() = 0;
130
131 void VisitAttrs(AttrVisitor*) {}
132
133 static constexpr const char* _type_key = "relay.TECompiler";
134 TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object);
135};
136
137/*! \brief cache entry used in compile engine */
138class TECompiler : public ObjectRef {
139 public:
140 explicit TECompiler(Optional<IRModule> opt_mod = {}, Optional<String> mod_name = {});
141 explicit TECompiler(ObjectPtr<Object> n) : ObjectRef(n) {}
142 TECompilerNode* operator->() { return static_cast<TECompilerNode*>(get_mutable()); }
143 using ContainerType = TECompilerNode;
144 TVM_DLL static TECompiler& Global();
145};
146
147/*!
148 * \brief A function to create the function metadata for an input function (ie calculate buffer
149 * input/output sizes)
150 * \param func The function to calculate function metadata for
151 * \param function_metadata The map that stores all the function metadatas
152 * \param workspace_byte_alignment Byte alignment for allocations
153 */
154void UpdateFunctionMetadata(BaseFunc relay_func,
155 Map<String, backend::FunctionInfo>& function_metadata, // NOLINT(*)
156 Integer workspace_byte_alignment = 16);
157
158/*!
159 * \brief Update the "main" control function's metadata
160 *
161 * \param mod The module
162 * \param config All the available targets.
163 * \return function_infos Function info for each function in the module
164 */
165backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const CompilationConfig& config,
166 Map<Expr, backend::StorageInfo> storage_info_map);
167
168/*! \brief Returns all the global \p PrimFunc functions in \p mod, but separated into an \p IRModule
169 * per \p Target.
170 *
171 * \param mod The IRModule to extract the per target module from
172 * \return The map from Target to IRModule
173 */
174Map<Target, IRModule> GetPerTargetModules(IRModule mod);
175
176inline void DefaultProcessFn(BaseFunc) {}
177
178/*!
179 * \brief Pass to lower an IRModule's primitive functions to TIR.
180 *
181 * This is the "back half" of the Relay compiler which lowers "primitive functions"
182 * to TE expressions, schedules them, and emits PrimFuncs.
183 *
184 * \param module_name The name of this module, used as a prefix for generated globals.
185 * \param config All available targets.
186 * \param process_fn Callback allowing one-level up code generators to process
187 * each function that we lower (default is no-op).
188 * \returns The pass which lowers primitive functions to TIR
189 */
190transform::Pass LowerTE(String module_name, CompilationConfig config,
191 ProcessFn process_fn = DefaultProcessFn);
192
193} // namespace tec
194} // namespace relay
195} // namespace tvm
196
197#endif // TVM_RELAY_BACKEND_TE_COMPILER_H_
198