1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/transform.cc |
22 | * \brief Infrastructure for transformation passes. |
23 | */ |
24 | #include <dmlc/thread_local.h> |
25 | #include <tvm/ir/transform.h> |
26 | #include <tvm/node/repr_printer.h> |
27 | #include <tvm/node/structural_hash.h> |
28 | #include <tvm/runtime/device_api.h> |
29 | #include <tvm/runtime/registry.h> |
30 | |
31 | #include <chrono> |
32 | #include <iomanip> |
33 | #include <stack> |
34 | #include <unordered_set> |
35 | |
36 | #include "../runtime/object_internal.h" |
37 | |
38 | namespace tvm { |
39 | namespace transform { |
40 | |
41 | using tvm::ReprPrinter; |
42 | using tvm::runtime::TVMArgs; |
43 | using tvm::runtime::TVMRetValue; |
44 | |
// Register the "testing.immutable_module" pass-config key with value type Bool.
// Pass::operator() reads this key to decide whether a pass is run through
// Pass::AssertImmutableModule.
TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module" , Bool);
46 | |
47 | struct PassContextThreadLocalEntry { |
48 | /*! \brief The default pass context. */ |
49 | PassContext default_context; |
50 | |
51 | /*! \brief The current pass context. */ |
52 | std::stack<PassContext> context_stack; |
53 | |
54 | PassContextThreadLocalEntry() { default_context = PassContext(make_object<PassContextNode>()); } |
55 | }; |
56 | |
57 | /*! \brief Thread local store to hold the pass context. */ |
58 | typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry> RelayPassContextThreadLocalStore; |
59 | |
60 | void PassContext::EnterWithScope() { |
61 | InstrumentEnterPassContext(); |
62 | |
63 | PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); |
64 | entry->context_stack.push(*this); |
65 | } |
66 | |
/*!
 * \brief Leave the scope opened by EnterWithScope.
 *
 * Checks that this context is the one on top of the thread-local stack, pops
 * it, and only then notifies the instruments via InstrumentExitPassContext.
 */
void PassContext::ExitWithScope() {
  PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
  // The context being exited must be the most recently entered one.
  ICHECK(!entry->context_stack.empty());
  ICHECK(entry->context_stack.top().same_as(*this));
  entry->context_stack.pop();

  InstrumentExitPassContext();
}
75 | |
76 | PassContext PassContext::Current() { |
77 | PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); |
78 | if (!entry->context_stack.empty()) { |
79 | return entry->context_stack.top(); |
80 | } else { |
81 | return entry->default_context; |
82 | } |
83 | } |
84 | |
85 | // linearly scan the pass array to match pass_name |
86 | bool PassArrayContains(const Array<runtime::String>& pass_array, const std::string& pass_name) { |
87 | for (auto x : pass_array) { |
88 | if (x == pass_name) return true; |
89 | } |
90 | return false; |
91 | } |
92 | |
93 | bool PassContext::PassEnabled(const PassInfo& info) const { |
94 | if (PassArrayContains(operator->()->disabled_pass, info->name)) { |
95 | return false; |
96 | } |
97 | |
98 | if (PassArrayContains(operator->()->required_pass, info->name)) { |
99 | return true; |
100 | } |
101 | |
102 | return operator->()->opt_level >= info->opt_level; |
103 | } |
104 | |
105 | class PassConfigManager { |
106 | public: |
107 | void Register(std::string key, uint32_t value_type_index) { |
108 | ICHECK_EQ(key2vtype_.count(key), 0U); |
109 | ValueTypeInfo info; |
110 | info.type_index = value_type_index; |
111 | info.type_key = runtime::Object::TypeIndex2Key(value_type_index); |
112 | key2vtype_[key] = info; |
113 | } |
114 | |
115 | // Trying to validate and legalize a config. |
116 | void Legalize(Map<String, ObjectRef>* config) { |
117 | std::vector<std::pair<std::string, ObjectRef>> update; |
118 | auto* reflection = ReflectionVTable::Global(); |
119 | |
120 | for (auto kv : *config) { |
121 | auto it = key2vtype_.find(kv.first); |
122 | if (it == key2vtype_.end()) { |
123 | std::ostringstream os; |
124 | os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:" ; |
125 | int counter = 0; |
126 | for (const auto& kv : key2vtype_) { |
127 | os << ' '; |
128 | if (counter++ != 0) os << ','; |
129 | os << kv.first; |
130 | } |
131 | LOG(FATAL) << os.str(); |
132 | } |
133 | const auto& info = it->second; |
134 | ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None" ; |
135 | if (kv.second->IsInstance<Map<String, ObjectRef>::ContainerType>()) { |
136 | ObjectRef converted = |
137 | reflection->CreateObject(info.type_key, Downcast<Map<String, ObjectRef>>(kv.second)); |
138 | update.emplace_back(kv.first, converted); |
139 | } else { |
140 | if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { |
141 | LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " |
142 | << info.type_key << " but get " << kv.second->GetTypeKey(); |
143 | } |
144 | } |
145 | } |
146 | for (auto&& kv : update) { |
147 | config->Set(kv.first, kv.second); |
148 | } |
149 | } |
150 | |
151 | Map<String, Map<String, String>> ListConfigs() { |
152 | Map<String, Map<String, String>> configs; |
153 | for (const auto& kv : key2vtype_) { |
154 | Map<String, String> metadata; |
155 | metadata.Set("type" , kv.second.type_key); |
156 | configs.Set(kv.first, metadata); |
157 | } |
158 | return configs; |
159 | } |
160 | |
161 | static PassConfigManager* Global() { |
162 | static auto* inst = new PassConfigManager(); |
163 | return inst; |
164 | } |
165 | |
166 | private: |
167 | struct ValueTypeInfo { |
168 | std::string type_key; |
169 | uint32_t type_index; |
170 | }; |
171 | |
172 | std::unordered_map<std::string, ValueTypeInfo> key2vtype_; |
173 | }; |
174 | |
175 | void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { |
176 | PassConfigManager::Global()->Register(key, value_type_index); |
177 | } |
178 | |
179 | Map<String, Map<String, String>> PassContext::ListConfigs() { |
180 | return PassConfigManager::Global()->ListConfigs(); |
181 | } |
182 | |
183 | PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); } |
184 | |
185 | void PassContext::InstrumentEnterPassContext() { |
186 | auto pass_ctx_node = this->operator->(); |
187 | if (pass_ctx_node->instruments.defined()) { |
188 | Array<instrument::PassInstrument> enter_successes; |
189 | try { |
190 | for (instrument::PassInstrument pi : pass_ctx_node->instruments) { |
191 | pi->EnterPassContext(); |
192 | enter_successes.push_back(pi); |
193 | } |
194 | } catch (const Error& e) { |
195 | LOG(INFO) << "Pass instrumentation entering pass context failed." ; |
196 | LOG(INFO) << "Disable pass instrumentation." ; |
197 | pass_ctx_node->instruments.clear(); |
198 | |
199 | for (instrument::PassInstrument pi : enter_successes) { |
200 | LOG(INFO) << pi->name << " exiting PassContext ..." ; |
201 | pi->ExitPassContext(); |
202 | LOG(INFO) << pi->name << " exited PassContext." ; |
203 | } |
204 | enter_successes.clear(); |
205 | |
206 | throw e; |
207 | } |
208 | } |
209 | } |
210 | |
211 | void PassContext::InstrumentExitPassContext() { |
212 | auto pass_ctx_node = this->operator->(); |
213 | if (pass_ctx_node->instruments.defined()) { |
214 | try { |
215 | for (instrument::PassInstrument pi : pass_ctx_node->instruments) { |
216 | pi->ExitPassContext(); |
217 | } |
218 | } catch (const Error& e) { |
219 | LOG(INFO) << "Pass instrumentation exiting pass context failed." ; |
220 | pass_ctx_node->instruments.clear(); |
221 | throw e; |
222 | } |
223 | } |
224 | } |
225 | |
226 | bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { |
227 | auto pass_ctx_node = this->operator->(); |
228 | if (!pass_ctx_node->instruments.defined()) { |
229 | return true; |
230 | } |
231 | |
232 | const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); |
233 | bool should_run = true; |
234 | if (!pass_required) { |
235 | for (instrument::PassInstrument pi : pass_ctx_node->instruments) { |
236 | should_run &= pi->ShouldRun(ir_module, pass_info); |
237 | } |
238 | } |
239 | |
240 | if (should_run) { |
241 | for (instrument::PassInstrument pi : pass_ctx_node->instruments) { |
242 | pi->RunBeforePass(ir_module, pass_info); |
243 | } |
244 | } |
245 | return should_run; |
246 | } |
247 | |
248 | void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { |
249 | auto pass_ctx_node = this->operator->(); |
250 | if (pass_ctx_node->instruments.defined()) { |
251 | for (instrument::PassInstrument pi : pass_ctx_node->instruments) { |
252 | pi->RunAfterPass(ir_module, pass_info); |
253 | } |
254 | } |
255 | } |
256 | |
257 | IRModule Pass::operator()(IRModule mod) const { |
258 | return this->operator()(std::move(mod), PassContext::Current()); |
259 | } |
260 | |
261 | IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { |
262 | const PassNode* node = operator->(); |
263 | ICHECK(node != nullptr); |
264 | const PassInfo& pass_info = node->Info(); |
265 | if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { |
266 | DLOG(INFO) << "Skipping pass : " << pass_info->name |
267 | << " with opt level: " << pass_info->opt_level; |
268 | return mod; |
269 | } |
270 | IRModule ret; |
271 | if (pass_ctx->GetConfig<Bool>("testing.immutable_module" , Bool(false)).value()) { |
272 | ret = Pass::AssertImmutableModule(mod, node, pass_ctx); |
273 | } else { |
274 | ret = node->operator()(std::move(mod), pass_ctx); |
275 | } |
276 | pass_ctx.InstrumentAfterPass(ret, pass_info); |
277 | return std::move(ret); |
278 | } |
279 | |
280 | IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node, |
281 | const PassContext& pass_ctx) { |
282 | size_t before_pass_hash = tvm::StructuralHash()(mod); |
283 | ObjectPtr<Object> module_ptr = ObjectRef::GetDataPtr<Object>(mod); |
284 | IRModule copy_mod = IRModule(module_ptr); |
285 | IRModule ret = node->operator()(mod, pass_ctx); |
286 | size_t after_pass_hash = tvm::StructuralHash()(copy_mod); |
287 | if (before_pass_hash != after_pass_hash) { |
288 | // The chance of getting a hash conflict between a module and the same module but mutated |
289 | // must be very low. |
290 | LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name; |
291 | } |
292 | return std::move(ret); |
293 | } |
294 | |
/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
 * at this level have the full control of a given Relay program including
 * addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*! \brief The pass function sketches the real optimization. For example,
   * we may need to perform dead code elimination on the module level. We could
   * implement the algorithm in the `pass_func` and let it run on a module. It
   * will then remove the dead code including the unused functions in the module.
   */
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;

  ModulePassNode() = default;

  // Only `pass_info` is exposed to the attribute visitor; `pass_func` is not visited.
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info" , &pass_info); }

  /*!
   * \brief Run a module pass on given pass context.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module.
   */
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const override { return pass_info; }

  static constexpr const char* _type_key = "transform.ModulePass" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
335 | |
/*!
 * \brief Reference class for ModulePassNode.
 */
class ModulePass : public Pass {
 public:
  /*!
   * \brief Construct a module pass.
   * \param pass_func The function stored as the node's `pass_func`.
   * \param pass_info The meta data stored as the node's `pass_info`.
   */
  ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
             PassInfo pass_info);

  TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};
343 | |
344 | PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) { |
345 | auto pass_info = make_object<PassInfoNode>(); |
346 | pass_info->opt_level = opt_level; |
347 | pass_info->name = std::move(name); |
348 | pass_info->required = std::move(required); |
349 | data_ = std::move(pass_info); |
350 | } |
351 | |
352 | ModulePass::ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func, |
353 | PassInfo pass_info) { |
354 | auto n = make_object<ModulePassNode>(); |
355 | n->pass_func = std::move(pass_func); |
356 | n->pass_info = std::move(pass_info); |
357 | data_ = std::move(n); |
358 | } |
359 | |
360 | // Module -> Module optimizations. |
361 | IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { |
362 | DiagnosticContext previous = DiagnosticContext::Default(mod); |
363 | |
364 | if (pass_ctx->diag_ctx) { |
365 | DiagnosticContext tmp = pass_ctx->diag_ctx.value(); |
366 | pass_ctx->diag_ctx = previous; |
367 | previous = tmp; |
368 | } else { |
369 | pass_ctx->diag_ctx = previous; |
370 | } |
371 | |
372 | ICHECK(pass_ctx->diag_ctx) |
373 | << "The diagnostic context was set at the top of this block this is a bug." ; |
374 | |
375 | const PassInfo& pass_info = Info(); |
376 | ICHECK(mod.defined()) << "The input module must be set." ; |
377 | |
378 | VLOG_CONTEXT << pass_info->name; |
379 | VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; |
380 | |
381 | mod = pass_func(std::move(mod), pass_ctx); |
382 | |
383 | ICHECK(mod.defined()) << "The return value of a module pass must be set." ; |
384 | |
385 | ICHECK(pass_ctx->diag_ctx) |
386 | << "The diagnostic context was set at the top of this block this is a bug." ; |
387 | |
388 | pass_ctx->diag_ctx.value().Render(); |
389 | pass_ctx->diag_ctx = previous; |
390 | |
391 | return mod; |
392 | } |
393 | |
394 | Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) { |
395 | auto n = make_object<SequentialNode>(); |
396 | n->passes = std::move(passes); |
397 | n->pass_info = std::move(pass_info); |
398 | data_ = std::move(n); |
399 | } |
400 | |
401 | Sequential::Sequential(tvm::Array<Pass> passes, String name) { |
402 | auto n = make_object<SequentialNode>(); |
403 | n->passes = std::move(passes); |
404 | PassInfo pass_info = PassInfo(0, std::move(name), {}); |
405 | n->pass_info = std::move(pass_info); |
406 | data_ = std::move(n); |
407 | } |
408 | |
409 | const SequentialNode* Sequential::operator->() const { |
410 | return static_cast<const SequentialNode*>(get()); |
411 | } |
412 | |
/*!
 * \brief Placeholder for dependency resolution between passes.
 *
 * Not implemented: the body unconditionally reports a fatal error and `mod`
 * is not used.
 *
 * \param mod The module the passes would be applied to.
 */
void SequentialNode::ResolveDependency(const IRModule& mod) {
  // TODO(zhiics) Implement it.
  // 1. Consider the required passes for each pass.
  // 2. Only resolve the enabled passes.
  // 3. Build a dependency graph. Probably we need to update the pass list.
  LOG(FATAL) << "Pass dependency has not been resolved yet."
             << "\n" ;
}
421 | |
422 | Pass GetPass(const String& pass_name) { |
423 | using tvm::runtime::Registry; |
424 | const runtime::PackedFunc* f = nullptr; |
425 | if (pass_name.operator std::string().find("transform." ) != std::string::npos) { |
426 | f = Registry::Get(pass_name); |
427 | } else if ((f = Registry::Get("transform." + pass_name))) { |
428 | // pass |
429 | } else if ((f = Registry::Get("relay._transform." + pass_name))) { |
430 | } |
431 | ICHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass" ; |
432 | return (*f)(); |
433 | } |
434 | |
435 | // TODO(zhiics): we currently only sequentially execute each pass in |
436 | // a Sequential without the consideration of their orders. The phase |
437 | // ordering problem needs to be handled in the future. |
438 | IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { |
439 | for (const Pass& pass : passes) { |
440 | VLOG(0) << "Running pass " << pass->Info()->name; |
441 | ICHECK(pass.defined()) << "Found undefined pass for optimization." ; |
442 | const PassInfo& pass_info = pass->Info(); |
443 | if (!pass_ctx.PassEnabled(pass_info)) { |
444 | VLOG(0) << "skipping disabled pass '" << pass_info->name << "'" ; |
445 | continue; |
446 | } |
447 | // resolve dependencies |
448 | for (const auto& it : pass_info->required) { |
449 | mod = GetPass(it)(std::move(mod), pass_ctx); |
450 | } |
451 | mod = pass(std::move(mod), pass_ctx); |
452 | } |
453 | return mod; |
454 | } |
455 | |
456 | Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, |
457 | int opt_level, String name, tvm::Array<String> required) { |
458 | PassInfo pass_info = PassInfo(opt_level, name, required); |
459 | return ModulePass(pass_func, pass_info); |
460 | } |
461 | |
462 | TVM_REGISTER_NODE_TYPE(PassInfoNode); |
463 | |
464 | TVM_REGISTER_GLOBAL("transform.PassInfo" ) |
465 | .set_body_typed([](int opt_level, String name, tvm::Array<String> required) { |
466 | return PassInfo(opt_level, name, required); |
467 | }); |
468 | |
469 | TVM_REGISTER_GLOBAL("transform.Info" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
470 | Pass pass = args[0]; |
471 | *ret = pass->Info(); |
472 | }); |
473 | |
474 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
475 | .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) { |
476 | auto* node = static_cast<const PassInfoNode*>(ref.get()); |
477 | p->stream << "The meta data of the pass - " ; |
478 | p->stream << "pass name: " << node->name; |
479 | p->stream << ", opt_level: " << node->opt_level; |
480 | if (node->required.empty()) { |
481 | p->stream << ", required passes: []\n" ; |
482 | } else { |
483 | p->stream << ", required passes: [" |
484 | << "\n" ; |
485 | for (const auto& it : node->required) { |
486 | p->stream << it << ", " ; |
487 | } |
488 | p->stream << "]\n" ; |
489 | } |
490 | }); |
491 | |
492 | TVM_REGISTER_NODE_TYPE(ModulePassNode); |
493 | |
494 | TVM_REGISTER_GLOBAL("transform.MakeModulePass" ) |
495 | .set_body_typed([](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func, |
496 | PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); |
497 | |
498 | TVM_REGISTER_GLOBAL("transform.RunPass" ).set_body_typed([](Pass pass, IRModule mod) { |
499 | return pass(std::move(mod)); |
500 | }); |
501 | |
502 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
503 | .set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) { |
504 | auto* node = static_cast<const ModulePassNode*>(ref.get()); |
505 | const PassInfo info = node->Info(); |
506 | p->stream << "Run Module pass: " << info->name << " at the optimization level " |
507 | << info->opt_level; |
508 | }); |
509 | |
510 | TVM_REGISTER_NODE_TYPE(SequentialNode); |
511 | |
512 | TVM_REGISTER_GLOBAL("transform.Sequential" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
513 | tvm::Array<Pass> passes = args[0]; |
514 | int opt_level = args[1]; |
515 | std::string name = args[2]; |
516 | tvm::Array<runtime::String> required = args[3]; |
517 | PassInfo pass_info = PassInfo(opt_level, name, required); |
518 | *ret = Sequential(passes, pass_info); |
519 | }); |
520 | |
521 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
522 | .set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) { |
523 | auto* node = static_cast<const SequentialNode*>(ref.get()); |
524 | const PassInfo info = node->Info(); |
525 | p->stream << "Run Sequential pass: " << info->name << " at the optimization level " |
526 | << info->opt_level << ". " ; |
527 | p->stream << "The passes will be executed are: [" ; |
528 | for (const auto& it : node->passes) { |
529 | const PassInfo pass_info = it->Info(); |
530 | p->stream << pass_info->name << " " ; |
531 | } |
532 | p->stream << "]" ; |
533 | }); |
534 | |
535 | TVM_REGISTER_NODE_TYPE(PassContextNode); |
536 | |
537 | TVM_REGISTER_GLOBAL("transform.PassContext" ) |
538 | .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled, |
539 | Array<instrument::PassInstrument> instruments, |
540 | Optional<Map<String, ObjectRef>> config) { |
541 | auto pctx = PassContext::Create(); |
542 | pctx->opt_level = opt_level; |
543 | |
544 | pctx->required_pass = std::move(required); |
545 | pctx->disabled_pass = std::move(disabled); |
546 | pctx->instruments = std::move(instruments); |
547 | if (config.defined()) { |
548 | pctx->config = config.value(); |
549 | } |
550 | PassConfigManager::Global()->Legalize(&(pctx->config)); |
551 | return pctx; |
552 | }); |
553 | |
554 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
555 | .set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) { |
556 | auto* node = static_cast<const PassContextNode*>(ref.get()); |
557 | p->stream << "Pass context information: " |
558 | << "\n" ; |
559 | p->stream << "\topt_level: " << node->opt_level << "\n" ; |
560 | |
561 | p->stream << "\trequired passes: " << node->required_pass << "\n" ; |
562 | p->stream << "\tdisabled passes: " << node->disabled_pass << "\n" ; |
563 | p->stream << "\tinstruments: " << node->instruments << "\n" ; |
564 | |
565 | p->stream << "\tconfig: " << node->config; |
566 | }); |
567 | |
/*!
 * \brief Helper that wraps PassContext::EnterWithScope / ExitWithScope in
 * static functions; these are registered as global functions further down
 * in this file.
 */
class PassContext::Internal {
 public:
  // Forwards to EnterWithScope on the given context.
  static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); }

  // Forwards to ExitWithScope on the given context.
  static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); }
};
574 | |
// Global function returning PassContext::Current().
TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext" ).set_body_typed(PassContext::Current);

// Global function that enters the scope of the given pass context.
TVM_REGISTER_GLOBAL("transform.EnterPassContext" ).set_body_typed(PassContext::Internal::EnterScope);

// Global function that exits the scope of the given pass context.
TVM_REGISTER_GLOBAL("transform.ExitPassContext" ).set_body_typed(PassContext::Internal::ExitScope);
580 | |
581 | TVM_REGISTER_GLOBAL("transform.OverrideInstruments" ) |
582 | .set_body_typed([](PassContext pass_ctx, Array<instrument::PassInstrument> instruments) { |
583 | pass_ctx.InstrumentExitPassContext(); |
584 | pass_ctx->instruments = instruments; |
585 | pass_ctx.InstrumentEnterPassContext(); |
586 | }); |
587 | |
588 | Pass PrintIR(String , bool show_meta_data) { |
589 | auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { |
590 | if (const auto* f = runtime::Registry::Get("relay.ir.PrintIR" )) { |
591 | if ((*f)(mod, header, show_meta_data)) { |
592 | return mod; |
593 | } |
594 | } |
595 | LOG(INFO) << "PrintIR(" << header << "):\n" << mod; |
596 | return mod; |
597 | }; |
598 | return CreateModulePass(pass_func, 0, "PrintIR" , {}); |
599 | } |
600 | |
// Global function that creates the PrintIR pass.
TVM_REGISTER_GLOBAL("transform.PrintIR" ).set_body_typed(PrintIR);

// Global function that lists the registered pass-config options.
TVM_REGISTER_GLOBAL("transform.ListConfigs" ).set_body_typed(PassContext::ListConfigs);
604 | |
605 | } // namespace transform |
606 | } // namespace tvm |
607 | |