1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/transform/capture_index_in_spans.cc
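 * \brief A pass to set spans to capture the post-dfs index of every node (other than Op and
 * Constructor nodes), together with the post-dfs index of the node's dominator parent.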
23 */
24
25#include <tvm/relay/expr_functor.h>
26#include <tvm/relay/transform.h>
27
28#include "../ir/indexed_graph.h"
29
30namespace tvm {
31namespace relay {
32namespace transform {
33
34namespace {
35
36/*! \brief Update all the spans to capture their post-dfs index. */
37class SpansRewriter : public ExprRewriter {
38 public:
39 explicit SpansRewriter(const IndexedGraph<Expr>* indexed_graph)
40 : source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {}
41
42 private:
43 Expr Rewrite_(const VarNode* var_node, const Expr& post) final {
44 return WithFields(Downcast<Var>(post), {}, {}, {}, MakeSpan(GetRef<Var>(var_node)));
45 }
46
47 Expr Rewrite_(const GlobalVarNode* global_var_node, const Expr& post) final {
48 return WithFields(Downcast<GlobalVar>(post), {}, {}, {},
49 MakeSpan(GetRef<GlobalVar>(global_var_node)));
50 }
51
52 Expr Rewrite_(const ConstantNode* constant_node, const Expr& post) final {
53 return WithFields(Downcast<Constant>(post), {}, {}, MakeSpan(GetRef<Constant>(constant_node)));
54 }
55
56 Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
57 return WithFields(Downcast<Tuple>(post), {}, {}, MakeSpan(GetRef<Tuple>(tuple_node)));
58 }
59
60 Expr Rewrite_(const FunctionNode* function_node, const Expr& post) final {
61 return WithFields(Downcast<Function>(post), {}, {}, {}, {}, {}, {},
62 MakeSpan(GetRef<Function>(function_node)));
63 }
64
65 Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
66 return WithFields(Downcast<Call>(post), {}, {}, {}, {}, {}, MakeSpan(GetRef<Call>(call_node)));
67 }
68
69 Expr Rewrite_(const LetNode* let_node, const Expr& post) final {
70 return WithFields(Downcast<Let>(post), {}, {}, {}, {}, MakeSpan(GetRef<Let>(let_node)));
71 }
72
73 Expr Rewrite_(const IfNode* if_node, const Expr& post) final {
74 return WithFields(Downcast<If>(post), {}, {}, {}, {}, MakeSpan(GetRef<If>(if_node)));
75 }
76
77 // OpNodes are not rewritten.
78
79 Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, const Expr& post) final {
80 return WithFields(Downcast<TupleGetItem>(post), {}, {}, {},
81 MakeSpan(GetRef<TupleGetItem>(tuple_get_item_node)));
82 }
83
84 Expr Rewrite_(const RefCreateNode* ref_create_node, const Expr& post) final {
85 return WithFields(Downcast<RefCreate>(post), {}, {},
86 MakeSpan(GetRef<RefCreate>(ref_create_node)));
87 }
88
89 Expr Rewrite_(const RefReadNode* ref_read_node, const Expr& post) final {
90 return WithFields(Downcast<RefRead>(post), {}, {}, MakeSpan(GetRef<RefRead>(ref_read_node)));
91 }
92
93 Expr Rewrite_(const RefWriteNode* ref_write_node, const Expr& post) final {
94 return WithFields(Downcast<RefWrite>(post), {}, {}, {},
95 MakeSpan(GetRef<RefWrite>(ref_write_node)));
96 }
97
98 // ConstructorNodes are not rewritten.
99
100 Expr Rewrite_(const MatchNode* match_node, const Expr& post) final {
101 return WithFields(Downcast<Match>(post), {}, {}, {}, MakeSpan(GetRef<Match>(match_node)));
102 }
103
104 Span MakeSpan(const Expr& expr) {
105 auto node = indexed_graph_->item_to_node(expr);
106 int node_index = static_cast<int>(node->index_);
107 int dominator_index =
108 node->dominator_parent_ ? static_cast<int>(node->dominator_parent_->index_) : -1;
109 Span span(source_name_, /*line=*/node_index, /*end_line=*/node_index,
110 /*column=*/dominator_index, /*end_column=*/dominator_index);
111 return span;
112 }
113
114 SourceName source_name_;
115 const IndexedGraph<Expr>* indexed_graph_;
116};
117
118} // namespace
119
120tvm::transform::Pass CapturePostDfsIndexInSpans() {
121 auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
122 std::unique_ptr<IndexedGraph<Expr>> indexed_graph = CreateIndexedGraph(f);
123 SpansRewriter rewriter(indexed_graph.get());
124 return Downcast<Function>(PostOrderRewrite(f, &rewriter));
125 };
126 return CreateFunctionPass(pass_func, 0, "CapturePostDfsIndexInSpans", {});
127}
128
129TVM_REGISTER_GLOBAL("relay._transform.CapturePostDfsIndexInSpans")
130 .set_body_typed(CapturePostDfsIndexInSpans);
131
132} // namespace transform
133} // namespace relay
134} // namespace tvm
135