1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 *
 * \file to_basic_block_normal_form.cc
 *
 * \brief Convert Relay expressions and modules to basic block normal form.
 */
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/runtime/logging.h> |
30 | |
31 | #include "../../support/arena.h" |
32 | #include "../analysis/dependency_graph.h" |
33 | #include "./pass_utils.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
38 | IRModule ToBasicBlockNormalForm(const IRModule& mod) { |
39 | // Create a new module by shallow copy. |
40 | IRModule new_mod = mod->ShallowCopy(); |
41 | |
42 | tvm::Map<GlobalVar, Function> updates; |
43 | auto funcs = new_mod->functions; |
44 | for (const auto& it : funcs) { |
45 | ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables" ; |
46 | if (const auto* n = it.second.as<FunctionNode>()) { |
47 | if (n->GetAttr<String>(attr::kCompiler).defined()) continue; |
48 | Function func = GetRef<Function>(n); |
49 | Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(func)); |
50 | VLOG(1) << "rewritten:" << std::endl |
51 | << PrettyPrint(func) << std::endl |
52 | << "to BasicBlockANF:" << std::endl |
53 | << PrettyPrint(ret); |
54 | updates.Set(it.first, Downcast<Function>(ret)); |
55 | } |
56 | } |
57 | |
58 | for (auto pair : updates) { |
59 | new_mod->Add(pair.first, pair.second, true); |
60 | } |
61 | |
62 | return new_mod; |
63 | } |
64 | |
65 | bool BasicBlockNormalFormCheck(const Expr& e) { |
66 | // calculate all the dependency between nodes. |
67 | support::Arena arena; |
68 | DependencyGraph dg = DependencyGraph::Create(&arena, e); |
69 | std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg); |
70 | for (auto expr : scopes.second) { |
71 | LOG(FATAL) << "The expression below violates the basic block normal form in that " |
72 | << "its scope should be lifted:\n" |
73 | << expr; |
74 | } |
75 | return scopes.second.size() == 0; |
76 | } |
77 | |
// Register BasicBlockNormalFormCheck in TVM's global function registry under
// the name below so it can be invoked through the FFI.
TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form")
    .set_body_typed(BasicBlockNormalFormCheck);
80 | |
81 | namespace transform { |
82 | |
83 | Pass ToBasicBlockNormalForm() { |
84 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = |
85 | [=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); }; |
86 | return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm" , {}); |
87 | } |
88 | |
89 | TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm" ) |
90 | .set_body_typed(ToBasicBlockNormalForm); |
91 | |
92 | } // namespace transform |
93 | |
94 | } // namespace relay |
95 | } // namespace tvm |
96 | |