1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file to_basic_block_normal_form.cc
23 *
 * \brief Turn an expression into basic block normal form.
25 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>

#include <sstream>

#include "../../support/arena.h"
#include "../analysis/dependency_graph.h"
#include "./pass_utils.h"
34
35namespace tvm {
36namespace relay {
37
38IRModule ToBasicBlockNormalForm(const IRModule& mod) {
39 // Create a new module by shallow copy.
40 IRModule new_mod = mod->ShallowCopy();
41
42 tvm::Map<GlobalVar, Function> updates;
43 auto funcs = new_mod->functions;
44 for (const auto& it : funcs) {
45 ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
46 if (const auto* n = it.second.as<FunctionNode>()) {
47 if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
48 Function func = GetRef<Function>(n);
49 Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(func));
50 VLOG(1) << "rewritten:" << std::endl
51 << PrettyPrint(func) << std::endl
52 << "to BasicBlockANF:" << std::endl
53 << PrettyPrint(ret);
54 updates.Set(it.first, Downcast<Function>(ret));
55 }
56 }
57
58 for (auto pair : updates) {
59 new_mod->Add(pair.first, pair.second, true);
60 }
61
62 return new_mod;
63}
64
65bool BasicBlockNormalFormCheck(const Expr& e) {
66 // calculate all the dependency between nodes.
67 support::Arena arena;
68 DependencyGraph dg = DependencyGraph::Create(&arena, e);
69 std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
70 for (auto expr : scopes.second) {
71 LOG(FATAL) << "The expression below violates the basic block normal form in that "
72 << "its scope should be lifted:\n"
73 << expr;
74 }
75 return scopes.second.size() == 0;
76}
77
// Expose BasicBlockNormalFormCheck through the global function registry under the name
// "relay.analysis.check_basic_block_normal_form".
TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form")
    .set_body_typed(BasicBlockNormalFormCheck);
80
81namespace transform {
82
83Pass ToBasicBlockNormalForm() {
84 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
85 [=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); };
86 return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm", {});
87}
88
89TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm")
90 .set_body_typed(ToBasicBlockNormalForm);
91
92} // namespace transform
93
94} // namespace relay
95} // namespace tvm
96