1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file relay/backend/te_compiler.h
 * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
 *
 * This represents the new design of the Relay compilation flow and will replace the interface
 * contained in compile_engine.h as we migrate towards a standard pass-based lowering of
 * Relay functions.
 *
 * This file provides an internal API which lowers Relay programs to components which
 * can be combined with TVM-produced kernels to compile an entire program.
 *
 * The result of lowering contains a combination of `runtime::Module`s produced by external
 * compilers and a set of lowered PrimFns which can be code generated for targets.
 */
35 | #ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_ |
36 | #define TVM_RELAY_BACKEND_TE_COMPILER_H_ |
37 | |
38 | #include <tvm/node/structural_equal.h> |
39 | #include <tvm/node/structural_hash.h> |
40 | #include <tvm/relay/analysis.h> |
41 | #include <tvm/relay/attrs/memory.h> |
42 | #include <tvm/relay/expr.h> |
43 | #include <tvm/relay/op_strategy.h> |
44 | #include <tvm/relay/transform.h> |
45 | #include <tvm/runtime/module.h> |
46 | #include <tvm/topi/elemwise.h> |
47 | |
48 | #include <functional> |
49 | #include <string> |
50 | #include <unordered_map> |
51 | |
52 | #include "../transforms/infer_layout_utils.h" |
53 | #include "../transforms/pass_utils.h" |
54 | #include "./te_compiler_cache.h" |
55 | #include "./utils.h" |
56 | |
57 | namespace tvm { |
58 | namespace relay { |
59 | namespace tec { |
60 | |
61 | using ProcessFn = std::function<void(BaseFunc)>; |
62 | |
63 | /*! |
64 | * \brief A compiler which lowers primitive Relay functions to tensor expressions |
65 | * and schedules them into TIR functions. |
66 | */ |
67 | class TECompilerNode : public Object { |
68 | public: |
69 | /*! \brief destructor */ |
70 | virtual ~TECompilerNode() {} |
71 | /*! |
72 | * \brief Get lowered result. |
73 | * \param key The key to the cached function. |
74 | * \return The result. |
75 | */ |
76 | virtual CachedFunc Lower(const CCacheKey& key) = 0; |
77 | |
78 | /*! |
79 | * \brief Get lowered result. |
80 | * \param key The key to the cached function. |
81 | * \return The result. |
82 | */ |
83 | virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; |
84 | |
85 | /* Return all functions which have been lowered by the compiler in an IRModule, annotated with |
86 | * their target. */ |
87 | virtual IRModule GetLoweredFunctions() = 0; |
88 | |
89 | /*! |
90 | * \brief Just in time compile to get a PackedFunc. |
91 | * \param key The key to the cached function. |
92 | * \return The result. |
93 | */ |
94 | virtual PackedFunc JIT(const CCacheKey& key) = 0; |
95 | /*! |
96 | * \brief Lower the shape function. |
97 | * \param key The key to the cached function. |
98 | * \return The result. |
99 | */ |
100 | virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; |
101 | /*! |
102 | * \brief Lower the external function using external codegen tools. |
103 | * \return The runtime modules for each needed external codegen tool. |
104 | */ |
105 | virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0; |
106 | |
107 | /*! |
108 | * \brief Update \p module to remove functions marked with the "Compiler" attribute and replace |
109 | * them with their 'external' representation using the "ExternalSymbol" attribute. |
110 | * |
111 | * TODO(mbs): This is a stepping stone while we migrate to a more official representation |
112 | * of 'external functions' in the IRModule and allow lowering to incrementally updatethe |
113 | * module stead of forcing everything via the cache. |
114 | * |
115 | */ |
116 | virtual void AddExterns(IRModule module) = 0; |
117 | |
118 | /*! |
119 | * \brief Get C Device API context mapping |
120 | * \return Map of GlobalVar to associated C Device API context name (either Target or kCompiler |
121 | * annotated) |
122 | */ |
123 | virtual Map<GlobalVar, String> GetDeviceContexts() = 0; |
124 | virtual void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) = 0; |
125 | |
126 | virtual Map<String, Integer> GetOpWeights() const = 0; |
127 | |
128 | /*! \brief clear the cache. */ |
129 | virtual void Clear() = 0; |
130 | |
131 | void VisitAttrs(AttrVisitor*) {} |
132 | |
133 | static constexpr const char* _type_key = "relay.TECompiler" ; |
134 | TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object); |
135 | }; |
136 | |
137 | /*! \brief cache entry used in compile engine */ |
138 | class TECompiler : public ObjectRef { |
139 | public: |
140 | explicit TECompiler(Optional<IRModule> opt_mod = {}, Optional<String> mod_name = {}); |
141 | explicit TECompiler(ObjectPtr<Object> n) : ObjectRef(n) {} |
142 | TECompilerNode* operator->() { return static_cast<TECompilerNode*>(get_mutable()); } |
143 | using ContainerType = TECompilerNode; |
144 | TVM_DLL static TECompiler& Global(); |
145 | }; |
146 | |
147 | /*! |
148 | * \brief A function to create the function metadata for an input function (ie calculate buffer |
149 | * input/output sizes) |
150 | * \param func The function to calculate function metadata for |
151 | * \param function_metadata The map that stores all the function metadatas |
152 | * \param workspace_byte_alignment Byte alignment for allocations |
153 | */ |
154 | void UpdateFunctionMetadata(BaseFunc relay_func, |
155 | Map<String, backend::FunctionInfo>& function_metadata, // NOLINT(*) |
156 | Integer workspace_byte_alignment = 16); |
157 | |
158 | /*! |
159 | * \brief Update the "main" control function's metadata |
160 | * |
161 | * \param mod The module |
162 | * \param config All the available targets. |
163 | * \return function_infos Function info for each function in the module |
164 | */ |
165 | backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const CompilationConfig& config, |
166 | Map<Expr, backend::StorageInfo> storage_info_map); |
167 | |
168 | /*! \brief Returns all the global \p PrimFunc functions in \p mod, but separated into an \p IRModule |
169 | * per \p Target. |
170 | * |
171 | * \param mod The IRModule to extract the per target module from |
172 | * \return The map from Target to IRModule |
173 | */ |
174 | Map<Target, IRModule> GetPerTargetModules(IRModule mod); |
175 | |
176 | inline void DefaultProcessFn(BaseFunc) {} |
177 | |
178 | /*! |
179 | * \brief Pass to lower an IRModule's primitive functions to TIR. |
180 | * |
181 | * This is the "back half" of the Relay compiler which lowers "primitive functions" |
182 | * to TE expressions, schedules them, and emits PrimFuncs. |
183 | * |
184 | * \param module_name The name of this module, used as a prefix for generated globals. |
185 | * \param config All available targets. |
186 | * \param process_fn Callback allowing one-level up code generators to process |
187 | * each function that we lower (default is no-op). |
188 | * \returns The pass which lowers primitive functions to TIR |
189 | */ |
190 | transform::Pass LowerTE(String module_name, CompilationConfig config, |
191 | ProcessFn process_fn = DefaultProcessFn); |
192 | |
193 | } // namespace tec |
194 | } // namespace relay |
195 | } // namespace tvm |
196 | |
197 | #endif // TVM_RELAY_BACKEND_TE_COMPILER_H_ |
198 | |