1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/transform.h |
22 | * |
 * This file implements a pass manager. The pass manager manages a sequence
 * of IRModule -> IRModule transformation passes over a particular unit of AST. The
 * design is largely inspired by LLVM's pass manager and modern deep learning
 * frameworks that perform tensor->tensor transformations.
27 | * |
 * The responsibilities of a traditional compiler pass manager usually involve:
 *  - Organizing the execution order of optimization passes, though not
 * necessarily in the optimal sequence.
 *  - Collecting required analysis information and keeping it up-to-date.
 *  - Reducing the effort required to implement new passes for compiler
 * developers, etc.
34 | * |
 * Similar to LLVM's pass manager, we designed the Relay pass manager to work
 * at different granularities, i.e. module level, function level, and even sequential
 * passes that contain a host of passes.
38 | * |
39 | * However, we also extend the functionality of the traditional pass manager |
40 | * with the consideration of requirements/convention from deep learning |
41 | * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass |
42 | * manager performs the IRModule -> IRModule transformation. All |
43 | * different types of passes, including the sequential-level pass object, are |
44 | * essentially pass objects. This design, therefore, effectively provides users |
45 | * a consistent and convenient interface, i.e. Pass, to play with. It offers a |
46 | * means to ease the development and testing of Relay passes. For example, with |
47 | * the pass manager, external users will be able to have custom passes correctly |
48 | * scheduled without having to modify a single handcrafted pass order. |
49 | * |
50 | * In the future we need to describe constraints between passes. For example, |
51 | * we may want to preserve dependencies between different passes and validate |
52 | * them on the completion of a certain pass. |
53 | * |
54 | * We also need to store side information and import the error reporting system. |
55 | */ |
56 | #ifndef TVM_IR_TRANSFORM_H_ |
57 | #define TVM_IR_TRANSFORM_H_ |
58 | |
59 | #include <tvm/ir/diagnostic.h> |
60 | #include <tvm/ir/instrument.h> |
61 | #include <tvm/ir/module.h> |
62 | #include <tvm/runtime/container/array.h> |
63 | #include <tvm/runtime/container/string.h> |
64 | #include <tvm/support/with.h> |
65 | |
66 | #include <string> |
67 | #include <utility> |
68 | |
69 | namespace tvm { |
70 | namespace transform { |
71 | |
72 | /*! |
73 | * \brief PassContextNode contains the information that a pass can rely on, |
74 | * such as analysis results. |
75 | * \sa PassContext |
76 | */ |
77 | class PassContextNode : public Object { |
78 | public: |
79 | /*! \brief The default optimization level. */ |
80 | int opt_level{2}; |
81 | |
82 | /*! \brief The list of required passes. */ |
83 | Array<String> required_pass; |
84 | /*! \brief The list of disabled passes. */ |
85 | Array<String> disabled_pass; |
86 | /*! \brief The diagnostic context. */ |
87 | mutable Optional<DiagnosticContext> diag_ctx; |
88 | /*! \brief Pass specific configurations. */ |
89 | Map<String, ObjectRef> config; |
90 | |
91 | /*! \brief A list of pass instrument implementations. */ |
92 | Array<instrument::PassInstrument> instruments; |
93 | |
94 | PassContextNode() = default; |
95 | |
96 | /*! |
97 | * \brief Get a config value from the pass context. |
98 | * |
99 | * \param key The config key. |
100 | * \param default_value The default value if the key does not exist, defaults to nullptr. |
101 | * |
102 | * \return The result |
103 | * |
104 | * \tparam TOBjectRef the expected object type. |
105 | * \throw Error if the key exists but the value does not match TObjectRef. |
106 | */ |
107 | template <typename TObjectRef> |
108 | Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value = |
109 | Optional<TObjectRef>(nullptr)) const { |
110 | static_assert(std::is_base_of<ObjectRef, TObjectRef>::value, |
111 | "Can only call GetAttr with ObjectRef types." ); |
112 | if (!config.defined()) return default_value; |
113 | auto it = config.find(key); |
114 | if (it != config.end()) { |
115 | return Downcast<Optional<TObjectRef>>((*it).second); |
116 | } else { |
117 | return default_value; |
118 | } |
119 | } |
120 | // variant that uses TObjectRef to enable implicit conversion to default value. |
121 | template <typename TObjectRef> |
122 | Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const { |
123 | return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value)); |
124 | } |
125 | |
126 | void VisitAttrs(AttrVisitor* v) { |
127 | v->Visit("opt_level" , &opt_level); |
128 | v->Visit("required_pass" , &required_pass); |
129 | v->Visit("disabled_pass" , &disabled_pass); |
130 | v->Visit("instruments" , &instruments); |
131 | v->Visit("config" , &config); |
132 | v->Visit("diag_ctx" , &diag_ctx); |
133 | } |
134 | |
135 | static constexpr const char* _type_key = "transform.PassContext" ; |
136 | static constexpr bool _type_has_method_sequal_reduce = false; |
137 | TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); |
138 | }; |
139 | |
140 | /*! |
141 | * \brief PassContext that is used to configure the pass behavior. |
142 | * |
143 | * \code |
144 | * |
145 | * auto new_ctx = PassContext::Create(); |
146 | * ctx->opt_level = 2; |
147 | * With<PassContext> scope(ctx); |
148 | * // pass context in effect. |
149 | * |
150 | * \endcode |
151 | * \sa PassContextNode |
152 | */ |
153 | class PassContext : public ObjectRef { |
154 | public: |
155 | PassContext() {} |
156 | explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {} |
157 | /*! |
158 | * \brief const accessor. |
159 | * \return const access pointer. |
160 | */ |
161 | const PassContextNode* operator->() const { |
162 | ICHECK(get() != nullptr); |
163 | return static_cast<const PassContextNode*>(get()); |
164 | } |
165 | /*! |
166 | * \brief mutable accessor. |
167 | * \return mutable access pointer. |
168 | */ |
169 | PassContextNode* operator->() { |
170 | ICHECK(get() != nullptr); |
171 | return static_cast<PassContextNode*>(get_mutable()); |
172 | } |
173 | |
174 | /*! |
175 | * \brief Construct a PassContext containing the default configurations. |
176 | * \return The new PassContext. |
177 | */ |
178 | TVM_DLL static PassContext Create(); |
179 | /*! |
180 | * \brief Get the default pass context in the current scope. |
181 | * \return The pass context. |
182 | */ |
183 | TVM_DLL static PassContext Current(); |
184 | |
185 | /*! |
186 | * \brief Get all supported configuration names and metadata, registered within the PassContext. |
187 | * \return Map indexed by the config name, pointing to the metadata map as key-value |
188 | */ |
189 | TVM_DLL static Map<String, Map<String, String>> ListConfigs(); |
190 | |
191 | /*! |
192 | * \brief Call instrument implementations' callbacks when entering PassContext. |
193 | * The callbacks are called in order, and if one raises an exception, the rest will not be |
194 | * called. |
195 | */ |
196 | TVM_DLL void InstrumentEnterPassContext(); |
197 | |
198 | /*! |
199 | * \brief Call instrument implementations' callbacks when exiting PassContext. |
200 | * The callbacks are called in order, and if one raises an exception, the rest will not be |
201 | * called. |
202 | */ |
203 | TVM_DLL void InstrumentExitPassContext(); |
204 | |
205 | /*! |
206 | * \brief Call instrument implementations' callbacks before a pass run. |
207 | * The callbacks are called in order, and if one raises an exception, the rest will not be |
208 | * called. |
209 | * |
210 | * \param mod The module that an optimization pass runs on. |
211 | * \param info The pass information. |
212 | * |
213 | * \return false: the pass is skipped; true: the pass runs. |
214 | */ |
215 | TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; |
216 | |
217 | /*! |
218 | * \brief Call instrument implementations callbacks after a pass run. |
219 | * The callbacks are called in order, and if one raises an exception, the rest will not be |
220 | * called. |
221 | * |
222 | * \param mod The module that an optimization pass runs on. |
223 | * \param info The pass information. |
224 | */ |
225 | TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; |
226 | |
227 | /*! |
228 | * \brief Check whether a pass is enabled. |
229 | * \param info The pass information. |
230 | * \return true if the pass is enabled. Otherwise, false. |
231 | */ |
232 | TVM_DLL bool PassEnabled(const PassInfo& info) const; |
233 | |
234 | /*! |
235 | * \brief Register a valid configuration option and its ValueType for validation. |
236 | * |
237 | * \param key The configuration key. |
238 | * \tparam ValueType The value type to be registered |
239 | */ |
240 | template <typename ValueType> |
241 | static uint32_t RegisterConfigOption(const char* key) { |
242 | using ValueNodeType = typename ValueType::ContainerType; |
243 | // NOTE: we could further update the function later. |
244 | uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); |
245 | RegisterConfigOption(key, tindex); |
246 | return tindex; |
247 | } |
248 | |
249 | // accessor. |
250 | using ContainerType = PassContextNode; |
251 | class Internal; |
252 | |
253 | private: |
254 | // The entry of a pass context scope. |
255 | TVM_DLL void EnterWithScope(); |
256 | // The exit of a pass context scope. |
257 | TVM_DLL void ExitWithScope(); |
258 | // Register configuration key value type. |
259 | TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); |
260 | |
261 | // Classes to get the Python `with` like syntax. |
262 | friend class Internal; |
263 | friend class With<PassContext>; |
264 | }; |
265 | |
// Declaration prefix used by TVM_REGISTER_PASS_CONFIG_OPTION below. The macro
// combines it with __COUNTER__ via TVM_STR_CONCAT, presumably token-pasting the
// counter onto the trailing identifier so that each use defines a distinct
// variable — NOTE(review): confirm against TVM_STR_CONCAT's definition.
#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid

/*!
 * \brief Helper macro to register a pass configuration option.
 *
 * Expands to the definition of a `static` variable whose initializer calls
 * PassContext::RegisterConfigOption<ValueType>(Key), registering \p Key together
 * with the runtime type index of `ValueType::ContainerType`.
 *
 * Use this macro in the cc file, once for each configuration key.
 *
 * \param Key The configuration key (a `const char*`).
 * \param ValueType The value type of the option; must expose a `ContainerType` alias.
 */
#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType)      \
  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
      ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
277 | |
278 | /*! |
279 | * \brief Meta data that will be used to help optimization and analysis. |
280 | * \sa PassInfo |
281 | */ |
282 | class PassInfoNode : public Object { |
283 | public: |
284 | /*! \brief The minimal optimization level that this pass will be enabled. */ |
285 | int opt_level; |
286 | |
287 | /*! \brief The name of an optimization/analysis pass. */ |
288 | String name; |
289 | |
290 | /*! \brief The passes that are required to perform the current pass. */ |
291 | Array<String> required; |
292 | |
293 | PassInfoNode() = default; |
294 | |
295 | void VisitAttrs(AttrVisitor* v) { |
296 | v->Visit("opt_level" , &opt_level); |
297 | v->Visit("name" , &name); |
298 | v->Visit("required" , &required); |
299 | } |
300 | |
301 | static constexpr const char* _type_key = "transform.PassInfo" ; |
302 | static constexpr bool _type_has_method_sequal_reduce = false; |
303 | TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); |
304 | }; |
305 | |
306 | /*! |
307 | * \brief Managed reference class for PassInfoNode |
308 | * \sa PassInfoNode |
309 | */ |
310 | class PassInfo : public ObjectRef { |
311 | public: |
312 | /*! |
313 | * \brief Constructor |
314 | * \param opt_level The optimization level |
315 | * \param name Name of the pass. |
316 | * \param required The passes that are required to perform the current pass. |
317 | */ |
318 | TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required); |
319 | |
320 | TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); |
321 | }; |
322 | |
323 | /*! |
324 | * \brief PassNode is the base type of differnt types of optimization passes. |
325 | * It is designed as a pure class and implemented by different pass subclasses |
326 | * at different granularity of Relay nodes. |
327 | */ |
328 | class PassNode : public Object { |
329 | public: |
330 | virtual ~PassNode() {} |
331 | /*! |
332 | * \brief Get the pass information/meta data. */ |
333 | virtual PassInfo Info() const = 0; |
334 | |
335 | /*! |
336 | * \brief Transform mod using the default PassContext in the current scope. |
337 | * |
338 | * \param mod The module that an optimization pass runs on. |
339 | * |
340 | * \return The transformed module. |
341 | */ |
342 | IRModule operator()(IRModule mod) const { |
343 | return this->operator()(std::move(mod), PassContext::Current()); |
344 | } |
345 | |
346 | /*! |
347 | * \brief Transform mod using a functor under a given pass context. |
348 | * |
349 | * \param mod The module that an optimization pass runs on. |
350 | * \param pass_ctx The pass context that can provide information for the optimization. |
351 | * |
352 | * \return The transformed module. |
353 | */ |
354 | virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; |
355 | |
356 | void VisitAttrs(AttrVisitor* v) {} |
357 | |
358 | static constexpr const char* _type_key = "transform.Pass" ; |
359 | TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); |
360 | }; |
361 | |
362 | class Pass : public ObjectRef { |
363 | public: |
364 | /*! |
365 | * \brief Transform mod using the default PassContext in the current scope. |
366 | * |
367 | * \code |
368 | * |
369 | * // If you do no longer need the input module |
370 | * // it is recommended to use std::move to move your input module. |
371 | * mod = pass(std::move(mod)); |
372 | * |
373 | * \endcode |
374 | * |
375 | * \param mod The module that an optimization pass runs on. |
376 | * |
377 | * \return The transformed module. |
378 | */ |
379 | IRModule operator()(IRModule mod) const; |
380 | |
381 | /*! |
382 | * \brief Transform mod using a functor under a given pass context. |
383 | * |
384 | * \param mod The module that an optimization pass runs on. |
385 | * \param pass_ctx The pass context that can provide information for the optimization. |
386 | * |
387 | * \return The transformed module. |
388 | */ |
389 | IRModule operator()(IRModule mod, const PassContext& pass_ctx) const; |
390 | |
391 | TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); |
392 | |
393 | private: |
394 | IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node, |
395 | const PassContext& pass_ctx); |
396 | }; |
397 | |
398 | /*! |
399 | * \brief The SequentialNode contains a set of passes that transform Relay |
400 | * programs from one AST to another semantically equivalent one. |
401 | * |
402 | * One example of this level of pass is that the pass manager needs to correctly |
403 | * perform a host of optimizations with a given optimization level and disabled |
404 | * passes. |
405 | */ |
406 | class SequentialNode : public PassNode { |
407 | public: |
408 | /* \brief The pass meta data.*/ |
409 | PassInfo pass_info; |
410 | |
411 | /*! \brief A list of passes that used to compose a sequential pass. */ |
412 | tvm::Array<Pass> passes; |
413 | |
414 | void VisitAttrs(tvm::AttrVisitor* v) { |
415 | v->Visit("pass_info" , &pass_info); |
416 | v->Visit("passes" , &passes); |
417 | } |
418 | |
419 | /*! |
420 | * \brief Get the pass information/meta data. |
421 | */ |
422 | PassInfo Info() const override { return pass_info; } |
423 | |
424 | /*! |
425 | * \brief Resolve the pass dependency. It globs all required passes by |
426 | * a given pass and executes them. |
427 | * |
428 | * \param mod The module that an optimization pass runs on. |
429 | * |
430 | * \return The updated module after resolving pass dependencies. |
431 | * |
432 | * TODO(zhiics) Build a dependency graph among the passes using provided |
433 | * metadata, i.e. required_passes. Likely, we can have a data structure, i.e. |
434 | * PassInfo, to store the relevant information including the parent passes. |
435 | */ |
436 | void ResolveDependency(const IRModule& mod); |
437 | |
438 | /*! |
439 | * \brief Perform optimizations on a series of passes. The aforementioned |
440 | * typical pass manager jobs could be done by it. This function could |
441 | * be overloaded to focus on different metrics, i.e. performance, |
442 | * memory footprint, etc. |
443 | * |
444 | * \param mod The module that these passes are applied on. |
445 | * \param pass_ctx The context that these passes execute on. |
446 | * |
447 | * \return Return the updated module. |
448 | */ |
449 | IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; |
450 | |
451 | static constexpr const char* _type_key = "transform.Sequential" ; |
452 | TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); |
453 | }; |
454 | |
455 | class Sequential : public Pass { |
456 | public: |
457 | /*! |
458 | * \brief The constructor of `Sequential`. |
459 | * |
460 | * \param passes The passes to apply. |
461 | * \param pass_info The pass metadata. |
462 | */ |
463 | TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info); |
464 | |
465 | /*! |
466 | * \brief The constructor of `Sequential`. |
467 | * |
468 | * \param passes The passes to apply. |
469 | * \param name The name of a sequential pass. It's defaulted to "sequential". |
470 | * This allows users to only provide a list of passes and execute them |
471 | * under a given context. |
472 | */ |
473 | TVM_DLL Sequential(Array<Pass> passes, String name = "sequential" ); |
474 | |
475 | Sequential() = default; |
476 | explicit Sequential(ObjectPtr<Object> n) : Pass(n) {} |
477 | |
478 | const SequentialNode* operator->() const; |
479 | using ContainerType = SequentialNode; |
480 | }; |
481 | |
482 | /* |
483 | * \brief Create a module pass. |
484 | * |
485 | * \param pass_func The packed function that contains the optimization. |
486 | * \param opt_level The optimization level of the module pass. |
487 | * \param name The name of the module pass. |
488 | * \param required The list of the passes that the module pass is dependent on. |
489 | * |
490 | * \return The created module pass. |
491 | */ |
492 | TVM_DLL Pass |
493 | CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, |
494 | int opt_level, String name, Array<runtime::String> required); |
495 | |
496 | /*! |
497 | * \brief A special trace pass that prints the header and IR to LOG(INFO). |
498 | * \param header The header to be attached to the output. |
499 | * \param show_meta_data Whether should we show meta data. |
500 | * \return The pass. |
501 | */ |
502 | TVM_DLL Pass PrintIR(String = "" , bool show_meta_data = false); |
503 | |
504 | } // namespace transform |
505 | } // namespace tvm |
506 | |
507 | #endif // TVM_IR_TRANSFORM_H_ |
508 | |