1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/transform.cc
22 * \brief Infrastructure for transformation passes.
23 */
#include <dmlc/thread_local.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <chrono>
#include <iomanip>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../runtime/object_internal.h"
37
38namespace tvm {
39namespace transform {
40
41using tvm::ReprPrinter;
42using tvm::runtime::TVMArgs;
43using tvm::runtime::TVMRetValue;
44
45TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
46
47struct PassContextThreadLocalEntry {
48 /*! \brief The default pass context. */
49 PassContext default_context;
50
51 /*! \brief The current pass context. */
52 std::stack<PassContext> context_stack;
53
54 PassContextThreadLocalEntry() { default_context = PassContext(make_object<PassContextNode>()); }
55};
56
57/*! \brief Thread local store to hold the pass context. */
58typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry> RelayPassContextThreadLocalStore;
59
60void PassContext::EnterWithScope() {
61 InstrumentEnterPassContext();
62
63 PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
64 entry->context_stack.push(*this);
65}
66
67void PassContext::ExitWithScope() {
68 PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
69 ICHECK(!entry->context_stack.empty());
70 ICHECK(entry->context_stack.top().same_as(*this));
71 entry->context_stack.pop();
72
73 InstrumentExitPassContext();
74}
75
76PassContext PassContext::Current() {
77 PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
78 if (!entry->context_stack.empty()) {
79 return entry->context_stack.top();
80 } else {
81 return entry->default_context;
82 }
83}
84
85// linearly scan the pass array to match pass_name
86bool PassArrayContains(const Array<runtime::String>& pass_array, const std::string& pass_name) {
87 for (auto x : pass_array) {
88 if (x == pass_name) return true;
89 }
90 return false;
91}
92
93bool PassContext::PassEnabled(const PassInfo& info) const {
94 if (PassArrayContains(operator->()->disabled_pass, info->name)) {
95 return false;
96 }
97
98 if (PassArrayContains(operator->()->required_pass, info->name)) {
99 return true;
100 }
101
102 return operator->()->opt_level >= info->opt_level;
103}
104
105class PassConfigManager {
106 public:
107 void Register(std::string key, uint32_t value_type_index) {
108 ICHECK_EQ(key2vtype_.count(key), 0U);
109 ValueTypeInfo info;
110 info.type_index = value_type_index;
111 info.type_key = runtime::Object::TypeIndex2Key(value_type_index);
112 key2vtype_[key] = info;
113 }
114
115 // Trying to validate and legalize a config.
116 void Legalize(Map<String, ObjectRef>* config) {
117 std::vector<std::pair<std::string, ObjectRef>> update;
118 auto* reflection = ReflectionVTable::Global();
119
120 for (auto kv : *config) {
121 auto it = key2vtype_.find(kv.first);
122 if (it == key2vtype_.end()) {
123 std::ostringstream os;
124 os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:";
125 int counter = 0;
126 for (const auto& kv : key2vtype_) {
127 os << ' ';
128 if (counter++ != 0) os << ',';
129 os << kv.first;
130 }
131 LOG(FATAL) << os.str();
132 }
133 const auto& info = it->second;
134 ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None";
135 if (kv.second->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
136 ObjectRef converted =
137 reflection->CreateObject(info.type_key, Downcast<Map<String, ObjectRef>>(kv.second));
138 update.emplace_back(kv.first, converted);
139 } else {
140 if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) {
141 LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type "
142 << info.type_key << " but get " << kv.second->GetTypeKey();
143 }
144 }
145 }
146 for (auto&& kv : update) {
147 config->Set(kv.first, kv.second);
148 }
149 }
150
151 Map<String, Map<String, String>> ListConfigs() {
152 Map<String, Map<String, String>> configs;
153 for (const auto& kv : key2vtype_) {
154 Map<String, String> metadata;
155 metadata.Set("type", kv.second.type_key);
156 configs.Set(kv.first, metadata);
157 }
158 return configs;
159 }
160
161 static PassConfigManager* Global() {
162 static auto* inst = new PassConfigManager();
163 return inst;
164 }
165
166 private:
167 struct ValueTypeInfo {
168 std::string type_key;
169 uint32_t type_index;
170 };
171
172 std::unordered_map<std::string, ValueTypeInfo> key2vtype_;
173};
174
175void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) {
176 PassConfigManager::Global()->Register(key, value_type_index);
177}
178
179Map<String, Map<String, String>> PassContext::ListConfigs() {
180 return PassConfigManager::Global()->ListConfigs();
181}
182
183PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
184
185void PassContext::InstrumentEnterPassContext() {
186 auto pass_ctx_node = this->operator->();
187 if (pass_ctx_node->instruments.defined()) {
188 Array<instrument::PassInstrument> enter_successes;
189 try {
190 for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
191 pi->EnterPassContext();
192 enter_successes.push_back(pi);
193 }
194 } catch (const Error& e) {
195 LOG(INFO) << "Pass instrumentation entering pass context failed.";
196 LOG(INFO) << "Disable pass instrumentation.";
197 pass_ctx_node->instruments.clear();
198
199 for (instrument::PassInstrument pi : enter_successes) {
200 LOG(INFO) << pi->name << " exiting PassContext ...";
201 pi->ExitPassContext();
202 LOG(INFO) << pi->name << " exited PassContext.";
203 }
204 enter_successes.clear();
205
206 throw e;
207 }
208 }
209}
210
211void PassContext::InstrumentExitPassContext() {
212 auto pass_ctx_node = this->operator->();
213 if (pass_ctx_node->instruments.defined()) {
214 try {
215 for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
216 pi->ExitPassContext();
217 }
218 } catch (const Error& e) {
219 LOG(INFO) << "Pass instrumentation exiting pass context failed.";
220 pass_ctx_node->instruments.clear();
221 throw e;
222 }
223 }
224}
225
226bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const {
227 auto pass_ctx_node = this->operator->();
228 if (!pass_ctx_node->instruments.defined()) {
229 return true;
230 }
231
232 const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name);
233 bool should_run = true;
234 if (!pass_required) {
235 for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
236 should_run &= pi->ShouldRun(ir_module, pass_info);
237 }
238 }
239
240 if (should_run) {
241 for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
242 pi->RunBeforePass(ir_module, pass_info);
243 }
244 }
245 return should_run;
246}
247
248void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const {
249 auto pass_ctx_node = this->operator->();
250 if (pass_ctx_node->instruments.defined()) {
251 for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
252 pi->RunAfterPass(ir_module, pass_info);
253 }
254 }
255}
256
257IRModule Pass::operator()(IRModule mod) const {
258 return this->operator()(std::move(mod), PassContext::Current());
259}
260
261IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
262 const PassNode* node = operator->();
263 ICHECK(node != nullptr);
264 const PassInfo& pass_info = node->Info();
265 if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) {
266 DLOG(INFO) << "Skipping pass : " << pass_info->name
267 << " with opt level: " << pass_info->opt_level;
268 return mod;
269 }
270 IRModule ret;
271 if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
272 ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
273 } else {
274 ret = node->operator()(std::move(mod), pass_ctx);
275 }
276 pass_ctx.InstrumentAfterPass(ret, pass_info);
277 return std::move(ret);
278}
279
280IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
281 const PassContext& pass_ctx) {
282 size_t before_pass_hash = tvm::StructuralHash()(mod);
283 ObjectPtr<Object> module_ptr = ObjectRef::GetDataPtr<Object>(mod);
284 IRModule copy_mod = IRModule(module_ptr);
285 IRModule ret = node->operator()(mod, pass_ctx);
286 size_t after_pass_hash = tvm::StructuralHash()(copy_mod);
287 if (before_pass_hash != after_pass_hash) {
288 // The chance of getting a hash conflict between a module and the same module but mutated
289 // must be very low.
290 LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name;
291 }
292 return std::move(ret);
293}
294
295/*!
296 * \brief Module-level passes are designed to implement global
297 * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
298 * at this level have the full control of a given Relay program including
299 * addition and deletion of functions.
300 */
301class ModulePassNode : public PassNode {
302 public:
303 /* \brief The pass meta data.*/
304 PassInfo pass_info;
305
306 /*! \brief The pass function sketches the real optimization. For example,
307 * we may need to perform dead code elimination on the module level. We could
308 * implement the algorithm in the `pass_func` and let it run on a module. It
309 * will then remove the dead code including the unused functions in the module.
310 */
311 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;
312
313 ModulePassNode() = default;
314
315 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
316
317 /*!
318 * \brief Run a module pass on given pass context.
319 *
320 * \param mod The module that an optimization pass is applied on.
321 * \param mod The context that an optimization pass executes on.
322 *
323 * \return Return the updated module.
324 */
325 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
326
327 /*!
328 * \brief Get the pass information/meta data.
329 */
330 PassInfo Info() const override { return pass_info; }
331
332 static constexpr const char* _type_key = "transform.ModulePass";
333 TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
334};
335
336class ModulePass : public Pass {
337 public:
338 ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
339 PassInfo pass_info);
340
341 TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
342};
343
344PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) {
345 auto pass_info = make_object<PassInfoNode>();
346 pass_info->opt_level = opt_level;
347 pass_info->name = std::move(name);
348 pass_info->required = std::move(required);
349 data_ = std::move(pass_info);
350}
351
352ModulePass::ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
353 PassInfo pass_info) {
354 auto n = make_object<ModulePassNode>();
355 n->pass_func = std::move(pass_func);
356 n->pass_info = std::move(pass_info);
357 data_ = std::move(n);
358}
359
360// Module -> Module optimizations.
361IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
362 DiagnosticContext previous = DiagnosticContext::Default(mod);
363
364 if (pass_ctx->diag_ctx) {
365 DiagnosticContext tmp = pass_ctx->diag_ctx.value();
366 pass_ctx->diag_ctx = previous;
367 previous = tmp;
368 } else {
369 pass_ctx->diag_ctx = previous;
370 }
371
372 ICHECK(pass_ctx->diag_ctx)
373 << "The diagnostic context was set at the top of this block this is a bug.";
374
375 const PassInfo& pass_info = Info();
376 ICHECK(mod.defined()) << "The input module must be set.";
377
378 VLOG_CONTEXT << pass_info->name;
379 VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level;
380
381 mod = pass_func(std::move(mod), pass_ctx);
382
383 ICHECK(mod.defined()) << "The return value of a module pass must be set.";
384
385 ICHECK(pass_ctx->diag_ctx)
386 << "The diagnostic context was set at the top of this block this is a bug.";
387
388 pass_ctx->diag_ctx.value().Render();
389 pass_ctx->diag_ctx = previous;
390
391 return mod;
392}
393
394Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
395 auto n = make_object<SequentialNode>();
396 n->passes = std::move(passes);
397 n->pass_info = std::move(pass_info);
398 data_ = std::move(n);
399}
400
401Sequential::Sequential(tvm::Array<Pass> passes, String name) {
402 auto n = make_object<SequentialNode>();
403 n->passes = std::move(passes);
404 PassInfo pass_info = PassInfo(0, std::move(name), {});
405 n->pass_info = std::move(pass_info);
406 data_ = std::move(n);
407}
408
409const SequentialNode* Sequential::operator->() const {
410 return static_cast<const SequentialNode*>(get());
411}
412
413void SequentialNode::ResolveDependency(const IRModule& mod) {
414 // TODO(zhiics) Implement it.
415 // 1. Consider the required passes for each pass.
416 // 2. Only resolve the enabled passes.
417 // 3. Build a dependency graph. Probably we need to update the pass list.
418 LOG(FATAL) << "Pass dependency has not been resolved yet."
419 << "\n";
420}
421
422Pass GetPass(const String& pass_name) {
423 using tvm::runtime::Registry;
424 const runtime::PackedFunc* f = nullptr;
425 if (pass_name.operator std::string().find("transform.") != std::string::npos) {
426 f = Registry::Get(pass_name);
427 } else if ((f = Registry::Get("transform." + pass_name))) {
428 // pass
429 } else if ((f = Registry::Get("relay._transform." + pass_name))) {
430 }
431 ICHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass";
432 return (*f)();
433}
434
435// TODO(zhiics): we currently only sequentially execute each pass in
436// a Sequential without the consideration of their orders. The phase
437// ordering problem needs to be handled in the future.
438IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
439 for (const Pass& pass : passes) {
440 VLOG(0) << "Running pass " << pass->Info()->name;
441 ICHECK(pass.defined()) << "Found undefined pass for optimization.";
442 const PassInfo& pass_info = pass->Info();
443 if (!pass_ctx.PassEnabled(pass_info)) {
444 VLOG(0) << "skipping disabled pass '" << pass_info->name << "'";
445 continue;
446 }
447 // resolve dependencies
448 for (const auto& it : pass_info->required) {
449 mod = GetPass(it)(std::move(mod), pass_ctx);
450 }
451 mod = pass(std::move(mod), pass_ctx);
452 }
453 return mod;
454}
455
456Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
457 int opt_level, String name, tvm::Array<String> required) {
458 PassInfo pass_info = PassInfo(opt_level, name, required);
459 return ModulePass(pass_func, pass_info);
460}
461
462TVM_REGISTER_NODE_TYPE(PassInfoNode);
463
464TVM_REGISTER_GLOBAL("transform.PassInfo")
465 .set_body_typed([](int opt_level, String name, tvm::Array<String> required) {
466 return PassInfo(opt_level, name, required);
467 });
468
469TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) {
470 Pass pass = args[0];
471 *ret = pass->Info();
472});
473
474TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
475 .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
476 auto* node = static_cast<const PassInfoNode*>(ref.get());
477 p->stream << "The meta data of the pass - ";
478 p->stream << "pass name: " << node->name;
479 p->stream << ", opt_level: " << node->opt_level;
480 if (node->required.empty()) {
481 p->stream << ", required passes: []\n";
482 } else {
483 p->stream << ", required passes: ["
484 << "\n";
485 for (const auto& it : node->required) {
486 p->stream << it << ", ";
487 }
488 p->stream << "]\n";
489 }
490 });
491
492TVM_REGISTER_NODE_TYPE(ModulePassNode);
493
494TVM_REGISTER_GLOBAL("transform.MakeModulePass")
495 .set_body_typed([](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
496 PassInfo pass_info) { return ModulePass(pass_func, pass_info); });
497
498TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
499 return pass(std::move(mod));
500});
501
502TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
503 .set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
504 auto* node = static_cast<const ModulePassNode*>(ref.get());
505 const PassInfo info = node->Info();
506 p->stream << "Run Module pass: " << info->name << " at the optimization level "
507 << info->opt_level;
508 });
509
510TVM_REGISTER_NODE_TYPE(SequentialNode);
511
512TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) {
513 tvm::Array<Pass> passes = args[0];
514 int opt_level = args[1];
515 std::string name = args[2];
516 tvm::Array<runtime::String> required = args[3];
517 PassInfo pass_info = PassInfo(opt_level, name, required);
518 *ret = Sequential(passes, pass_info);
519});
520
521TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
522 .set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
523 auto* node = static_cast<const SequentialNode*>(ref.get());
524 const PassInfo info = node->Info();
525 p->stream << "Run Sequential pass: " << info->name << " at the optimization level "
526 << info->opt_level << ". ";
527 p->stream << "The passes will be executed are: [";
528 for (const auto& it : node->passes) {
529 const PassInfo pass_info = it->Info();
530 p->stream << pass_info->name << " ";
531 }
532 p->stream << "]";
533 });
534
535TVM_REGISTER_NODE_TYPE(PassContextNode);
536
537TVM_REGISTER_GLOBAL("transform.PassContext")
538 .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
539 Array<instrument::PassInstrument> instruments,
540 Optional<Map<String, ObjectRef>> config) {
541 auto pctx = PassContext::Create();
542 pctx->opt_level = opt_level;
543
544 pctx->required_pass = std::move(required);
545 pctx->disabled_pass = std::move(disabled);
546 pctx->instruments = std::move(instruments);
547 if (config.defined()) {
548 pctx->config = config.value();
549 }
550 PassConfigManager::Global()->Legalize(&(pctx->config));
551 return pctx;
552 });
553
554TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
555 .set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
556 auto* node = static_cast<const PassContextNode*>(ref.get());
557 p->stream << "Pass context information: "
558 << "\n";
559 p->stream << "\topt_level: " << node->opt_level << "\n";
560
561 p->stream << "\trequired passes: " << node->required_pass << "\n";
562 p->stream << "\tdisabled passes: " << node->disabled_pass << "\n";
563 p->stream << "\tinstruments: " << node->instruments << "\n";
564
565 p->stream << "\tconfig: " << node->config;
566 });
567
568class PassContext::Internal {
569 public:
570 static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); }
571
572 static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); }
573};
574
575TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current);
576
577TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope);
578
579TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
580
581TVM_REGISTER_GLOBAL("transform.OverrideInstruments")
582 .set_body_typed([](PassContext pass_ctx, Array<instrument::PassInstrument> instruments) {
583 pass_ctx.InstrumentExitPassContext();
584 pass_ctx->instruments = instruments;
585 pass_ctx.InstrumentEnterPassContext();
586 });
587
588Pass PrintIR(String header, bool show_meta_data) {
589 auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
590 if (const auto* f = runtime::Registry::Get("relay.ir.PrintIR")) {
591 if ((*f)(mod, header, show_meta_data)) {
592 return mod;
593 }
594 }
595 LOG(INFO) << "PrintIR(" << header << "):\n" << mod;
596 return mod;
597 };
598 return CreateModulePass(pass_func, 0, "PrintIR", {});
599}
600
601TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);
602
603TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs);
604
605} // namespace transform
606} // namespace tvm
607