1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/transform.h
22 *
 * This file implements a pass manager. The pass manager manages a sequence
 * of IRModule -> IRModule transformation passes over a particular unit of AST. The
 * design is largely inspired by LLVM's pass manager and modern deep learning
 * frameworks that perform tensor->tensor transformations.
27 *
 * The responsibilities of a traditional compiler pass manager usually involve:
 * - Organizing the execution order of optimization passes, though not
 * necessarily in the optimal sequence.
 * - Collecting required analysis information and keeping it up-to-date.
 * - Reducing the effort required to implement new passes for compiler
 * developers, etc.
34 *
 * Similar to LLVM's pass manager, we designed the Relay pass manager to work
 * at different granularities, i.e. module level, function level, and even
 * sequential passes that contain a host of passes.
38 *
 * However, we also extend the functionality of the traditional pass manager
 * with the consideration of requirements/conventions from deep learning
 * frameworks, such as PyTorch and Gluon. Each pass in the Relay pass
 * manager performs the IRModule -> IRModule transformation. All
43 * different types of passes, including the sequential-level pass object, are
44 * essentially pass objects. This design, therefore, effectively provides users
45 * a consistent and convenient interface, i.e. Pass, to play with. It offers a
46 * means to ease the development and testing of Relay passes. For example, with
47 * the pass manager, external users will be able to have custom passes correctly
48 * scheduled without having to modify a single handcrafted pass order.
49 *
50 * In the future we need to describe constraints between passes. For example,
51 * we may want to preserve dependencies between different passes and validate
52 * them on the completion of a certain pass.
53 *
54 * We also need to store side information and import the error reporting system.
55 */
56#ifndef TVM_IR_TRANSFORM_H_
57#define TVM_IR_TRANSFORM_H_
58
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/instrument.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/support/with.h>

#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
68
69namespace tvm {
70namespace transform {
71
72/*!
73 * \brief PassContextNode contains the information that a pass can rely on,
74 * such as analysis results.
75 * \sa PassContext
76 */
77class PassContextNode : public Object {
78 public:
79 /*! \brief The default optimization level. */
80 int opt_level{2};
81
82 /*! \brief The list of required passes. */
83 Array<String> required_pass;
84 /*! \brief The list of disabled passes. */
85 Array<String> disabled_pass;
86 /*! \brief The diagnostic context. */
87 mutable Optional<DiagnosticContext> diag_ctx;
88 /*! \brief Pass specific configurations. */
89 Map<String, ObjectRef> config;
90
91 /*! \brief A list of pass instrument implementations. */
92 Array<instrument::PassInstrument> instruments;
93
94 PassContextNode() = default;
95
96 /*!
97 * \brief Get a config value from the pass context.
98 *
99 * \param key The config key.
100 * \param default_value The default value if the key does not exist, defaults to nullptr.
101 *
102 * \return The result
103 *
104 * \tparam TOBjectRef the expected object type.
105 * \throw Error if the key exists but the value does not match TObjectRef.
106 */
107 template <typename TObjectRef>
108 Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
109 Optional<TObjectRef>(nullptr)) const {
110 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
111 "Can only call GetAttr with ObjectRef types.");
112 if (!config.defined()) return default_value;
113 auto it = config.find(key);
114 if (it != config.end()) {
115 return Downcast<Optional<TObjectRef>>((*it).second);
116 } else {
117 return default_value;
118 }
119 }
120 // variant that uses TObjectRef to enable implicit conversion to default value.
121 template <typename TObjectRef>
122 Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
123 return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
124 }
125
126 void VisitAttrs(AttrVisitor* v) {
127 v->Visit("opt_level", &opt_level);
128 v->Visit("required_pass", &required_pass);
129 v->Visit("disabled_pass", &disabled_pass);
130 v->Visit("instruments", &instruments);
131 v->Visit("config", &config);
132 v->Visit("diag_ctx", &diag_ctx);
133 }
134
135 static constexpr const char* _type_key = "transform.PassContext";
136 static constexpr bool _type_has_method_sequal_reduce = false;
137 TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
138};
139
140/*!
141 * \brief PassContext that is used to configure the pass behavior.
142 *
143 * \code
144 *
145 * auto new_ctx = PassContext::Create();
146 * ctx->opt_level = 2;
147 * With<PassContext> scope(ctx);
148 * // pass context in effect.
149 *
150 * \endcode
151 * \sa PassContextNode
152 */
153class PassContext : public ObjectRef {
154 public:
155 PassContext() {}
156 explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
157 /*!
158 * \brief const accessor.
159 * \return const access pointer.
160 */
161 const PassContextNode* operator->() const {
162 ICHECK(get() != nullptr);
163 return static_cast<const PassContextNode*>(get());
164 }
165 /*!
166 * \brief mutable accessor.
167 * \return mutable access pointer.
168 */
169 PassContextNode* operator->() {
170 ICHECK(get() != nullptr);
171 return static_cast<PassContextNode*>(get_mutable());
172 }
173
174 /*!
175 * \brief Construct a PassContext containing the default configurations.
176 * \return The new PassContext.
177 */
178 TVM_DLL static PassContext Create();
179 /*!
180 * \brief Get the default pass context in the current scope.
181 * \return The pass context.
182 */
183 TVM_DLL static PassContext Current();
184
185 /*!
186 * \brief Get all supported configuration names and metadata, registered within the PassContext.
187 * \return Map indexed by the config name, pointing to the metadata map as key-value
188 */
189 TVM_DLL static Map<String, Map<String, String>> ListConfigs();
190
191 /*!
192 * \brief Call instrument implementations' callbacks when entering PassContext.
193 * The callbacks are called in order, and if one raises an exception, the rest will not be
194 * called.
195 */
196 TVM_DLL void InstrumentEnterPassContext();
197
198 /*!
199 * \brief Call instrument implementations' callbacks when exiting PassContext.
200 * The callbacks are called in order, and if one raises an exception, the rest will not be
201 * called.
202 */
203 TVM_DLL void InstrumentExitPassContext();
204
205 /*!
206 * \brief Call instrument implementations' callbacks before a pass run.
207 * The callbacks are called in order, and if one raises an exception, the rest will not be
208 * called.
209 *
210 * \param mod The module that an optimization pass runs on.
211 * \param info The pass information.
212 *
213 * \return false: the pass is skipped; true: the pass runs.
214 */
215 TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
216
217 /*!
218 * \brief Call instrument implementations callbacks after a pass run.
219 * The callbacks are called in order, and if one raises an exception, the rest will not be
220 * called.
221 *
222 * \param mod The module that an optimization pass runs on.
223 * \param info The pass information.
224 */
225 TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
226
227 /*!
228 * \brief Check whether a pass is enabled.
229 * \param info The pass information.
230 * \return true if the pass is enabled. Otherwise, false.
231 */
232 TVM_DLL bool PassEnabled(const PassInfo& info) const;
233
234 /*!
235 * \brief Register a valid configuration option and its ValueType for validation.
236 *
237 * \param key The configuration key.
238 * \tparam ValueType The value type to be registered
239 */
240 template <typename ValueType>
241 static uint32_t RegisterConfigOption(const char* key) {
242 using ValueNodeType = typename ValueType::ContainerType;
243 // NOTE: we could further update the function later.
244 uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
245 RegisterConfigOption(key, tindex);
246 return tindex;
247 }
248
249 // accessor.
250 using ContainerType = PassContextNode;
251 class Internal;
252
253 private:
254 // The entry of a pass context scope.
255 TVM_DLL void EnterWithScope();
256 // The exit of a pass context scope.
257 TVM_DLL void ExitWithScope();
258 // Register configuration key value type.
259 TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
260
261 // Classes to get the Python `with` like syntax.
262 friend class Internal;
263 friend class With<PassContext>;
264};
265
/*!
 * \brief Declaration prefix of the static variable that stores the result of a
 *        pass config registration; the name is completed by TVM_STR_CONCAT.
 */
#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid

/*!
 * \brief Helper macro to register a PassContext configuration option together with
 *        the value type used to validate it.
 *
 * Expands to a uniquely named (via __COUNTER__) static variable initialized with
 * PassContext::RegisterConfigOption<ValueType>(Key). When used at namespace scope,
 * the registration therefore runs during static initialization.
 *
 * Use this macro in a .cc file, e.g.
 * \code
 *  TVM_REGISTER_PASS_CONFIG_OPTION("example.option", Integer);
 * \endcode
 *
 * \param Key The configuration key (a C string).
 * \param ValueType The ObjectRef type of the value; it must expose a ContainerType.
 */
#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType)      \
  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
      ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
277
278/*!
279 * \brief Meta data that will be used to help optimization and analysis.
280 * \sa PassInfo
281 */
282class PassInfoNode : public Object {
283 public:
284 /*! \brief The minimal optimization level that this pass will be enabled. */
285 int opt_level;
286
287 /*! \brief The name of an optimization/analysis pass. */
288 String name;
289
290 /*! \brief The passes that are required to perform the current pass. */
291 Array<String> required;
292
293 PassInfoNode() = default;
294
295 void VisitAttrs(AttrVisitor* v) {
296 v->Visit("opt_level", &opt_level);
297 v->Visit("name", &name);
298 v->Visit("required", &required);
299 }
300
301 static constexpr const char* _type_key = "transform.PassInfo";
302 static constexpr bool _type_has_method_sequal_reduce = false;
303 TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
304};
305
306/*!
307 * \brief Managed reference class for PassInfoNode
308 * \sa PassInfoNode
309 */
310class PassInfo : public ObjectRef {
311 public:
312 /*!
313 * \brief Constructor
314 * \param opt_level The optimization level
315 * \param name Name of the pass.
316 * \param required The passes that are required to perform the current pass.
317 */
318 TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
319
320 TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
321};
322
323/*!
324 * \brief PassNode is the base type of differnt types of optimization passes.
325 * It is designed as a pure class and implemented by different pass subclasses
326 * at different granularity of Relay nodes.
327 */
328class PassNode : public Object {
329 public:
330 virtual ~PassNode() {}
331 /*!
332 * \brief Get the pass information/meta data. */
333 virtual PassInfo Info() const = 0;
334
335 /*!
336 * \brief Transform mod using the default PassContext in the current scope.
337 *
338 * \param mod The module that an optimization pass runs on.
339 *
340 * \return The transformed module.
341 */
342 IRModule operator()(IRModule mod) const {
343 return this->operator()(std::move(mod), PassContext::Current());
344 }
345
346 /*!
347 * \brief Transform mod using a functor under a given pass context.
348 *
349 * \param mod The module that an optimization pass runs on.
350 * \param pass_ctx The pass context that can provide information for the optimization.
351 *
352 * \return The transformed module.
353 */
354 virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
355
356 void VisitAttrs(AttrVisitor* v) {}
357
358 static constexpr const char* _type_key = "transform.Pass";
359 TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
360};
361
362class Pass : public ObjectRef {
363 public:
364 /*!
365 * \brief Transform mod using the default PassContext in the current scope.
366 *
367 * \code
368 *
369 * // If you do no longer need the input module
370 * // it is recommended to use std::move to move your input module.
371 * mod = pass(std::move(mod));
372 *
373 * \endcode
374 *
375 * \param mod The module that an optimization pass runs on.
376 *
377 * \return The transformed module.
378 */
379 IRModule operator()(IRModule mod) const;
380
381 /*!
382 * \brief Transform mod using a functor under a given pass context.
383 *
384 * \param mod The module that an optimization pass runs on.
385 * \param pass_ctx The pass context that can provide information for the optimization.
386 *
387 * \return The transformed module.
388 */
389 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
390
391 TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
392
393 private:
394 IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
395 const PassContext& pass_ctx);
396};
397
398/*!
399 * \brief The SequentialNode contains a set of passes that transform Relay
400 * programs from one AST to another semantically equivalent one.
401 *
402 * One example of this level of pass is that the pass manager needs to correctly
403 * perform a host of optimizations with a given optimization level and disabled
404 * passes.
405 */
406class SequentialNode : public PassNode {
407 public:
408 /* \brief The pass meta data.*/
409 PassInfo pass_info;
410
411 /*! \brief A list of passes that used to compose a sequential pass. */
412 tvm::Array<Pass> passes;
413
414 void VisitAttrs(tvm::AttrVisitor* v) {
415 v->Visit("pass_info", &pass_info);
416 v->Visit("passes", &passes);
417 }
418
419 /*!
420 * \brief Get the pass information/meta data.
421 */
422 PassInfo Info() const override { return pass_info; }
423
424 /*!
425 * \brief Resolve the pass dependency. It globs all required passes by
426 * a given pass and executes them.
427 *
428 * \param mod The module that an optimization pass runs on.
429 *
430 * \return The updated module after resolving pass dependencies.
431 *
432 * TODO(zhiics) Build a dependency graph among the passes using provided
433 * metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
434 * PassInfo, to store the relevant information including the parent passes.
435 */
436 void ResolveDependency(const IRModule& mod);
437
438 /*!
439 * \brief Perform optimizations on a series of passes. The aforementioned
440 * typical pass manager jobs could be done by it. This function could
441 * be overloaded to focus on different metrics, i.e. performance,
442 * memory footprint, etc.
443 *
444 * \param mod The module that these passes are applied on.
445 * \param pass_ctx The context that these passes execute on.
446 *
447 * \return Return the updated module.
448 */
449 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
450
451 static constexpr const char* _type_key = "transform.Sequential";
452 TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
453};
454
455class Sequential : public Pass {
456 public:
457 /*!
458 * \brief The constructor of `Sequential`.
459 *
460 * \param passes The passes to apply.
461 * \param pass_info The pass metadata.
462 */
463 TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
464
465 /*!
466 * \brief The constructor of `Sequential`.
467 *
468 * \param passes The passes to apply.
469 * \param name The name of a sequential pass. It's defaulted to "sequential".
470 * This allows users to only provide a list of passes and execute them
471 * under a given context.
472 */
473 TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
474
475 Sequential() = default;
476 explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
477
478 const SequentialNode* operator->() const;
479 using ContainerType = SequentialNode;
480};
481
482/*
483 * \brief Create a module pass.
484 *
485 * \param pass_func The packed function that contains the optimization.
486 * \param opt_level The optimization level of the module pass.
487 * \param name The name of the module pass.
488 * \param required The list of the passes that the module pass is dependent on.
489 *
490 * \return The created module pass.
491 */
492TVM_DLL Pass
493CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
494 int opt_level, String name, Array<runtime::String> required);
495
496/*!
497 * \brief A special trace pass that prints the header and IR to LOG(INFO).
498 * \param header The header to be attached to the output.
499 * \param show_meta_data Whether should we show meta data.
500 * \return The pass.
501 */
502TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
503
504} // namespace transform
505} // namespace tvm
506
507#endif // TVM_IR_TRANSFORM_H_
508