1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/transform/capture_index_in_spans.cc |
22 | * \brief A pass to set spans to capture the post-dfs index of every node. |
23 | */ |
24 | |
25 | #include <tvm/relay/expr_functor.h> |
26 | #include <tvm/relay/transform.h> |
27 | |
28 | #include "../ir/indexed_graph.h" |
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | namespace transform { |
33 | |
34 | namespace { |
35 | |
36 | /*! \brief Update all the spans to capture their post-dfs index. */ |
37 | class SpansRewriter : public ExprRewriter { |
38 | public: |
39 | explicit SpansRewriter(const IndexedGraph<Expr>* indexed_graph) |
40 | : source_name_(SourceName::Get("index" )), indexed_graph_(indexed_graph) {} |
41 | |
42 | private: |
43 | Expr Rewrite_(const VarNode* var_node, const Expr& post) final { |
44 | return WithFields(Downcast<Var>(post), {}, {}, {}, MakeSpan(GetRef<Var>(var_node))); |
45 | } |
46 | |
47 | Expr Rewrite_(const GlobalVarNode* global_var_node, const Expr& post) final { |
48 | return WithFields(Downcast<GlobalVar>(post), {}, {}, {}, |
49 | MakeSpan(GetRef<GlobalVar>(global_var_node))); |
50 | } |
51 | |
52 | Expr Rewrite_(const ConstantNode* constant_node, const Expr& post) final { |
53 | return WithFields(Downcast<Constant>(post), {}, {}, MakeSpan(GetRef<Constant>(constant_node))); |
54 | } |
55 | |
56 | Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final { |
57 | return WithFields(Downcast<Tuple>(post), {}, {}, MakeSpan(GetRef<Tuple>(tuple_node))); |
58 | } |
59 | |
60 | Expr Rewrite_(const FunctionNode* function_node, const Expr& post) final { |
61 | return WithFields(Downcast<Function>(post), {}, {}, {}, {}, {}, {}, |
62 | MakeSpan(GetRef<Function>(function_node))); |
63 | } |
64 | |
65 | Expr Rewrite_(const CallNode* call_node, const Expr& post) final { |
66 | return WithFields(Downcast<Call>(post), {}, {}, {}, {}, {}, MakeSpan(GetRef<Call>(call_node))); |
67 | } |
68 | |
69 | Expr Rewrite_(const LetNode* let_node, const Expr& post) final { |
70 | return WithFields(Downcast<Let>(post), {}, {}, {}, {}, MakeSpan(GetRef<Let>(let_node))); |
71 | } |
72 | |
73 | Expr Rewrite_(const IfNode* if_node, const Expr& post) final { |
74 | return WithFields(Downcast<If>(post), {}, {}, {}, {}, MakeSpan(GetRef<If>(if_node))); |
75 | } |
76 | |
77 | // OpNodes are not rewritten. |
78 | |
79 | Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, const Expr& post) final { |
80 | return WithFields(Downcast<TupleGetItem>(post), {}, {}, {}, |
81 | MakeSpan(GetRef<TupleGetItem>(tuple_get_item_node))); |
82 | } |
83 | |
84 | Expr Rewrite_(const RefCreateNode* ref_create_node, const Expr& post) final { |
85 | return WithFields(Downcast<RefCreate>(post), {}, {}, |
86 | MakeSpan(GetRef<RefCreate>(ref_create_node))); |
87 | } |
88 | |
89 | Expr Rewrite_(const RefReadNode* ref_read_node, const Expr& post) final { |
90 | return WithFields(Downcast<RefRead>(post), {}, {}, MakeSpan(GetRef<RefRead>(ref_read_node))); |
91 | } |
92 | |
93 | Expr Rewrite_(const RefWriteNode* ref_write_node, const Expr& post) final { |
94 | return WithFields(Downcast<RefWrite>(post), {}, {}, {}, |
95 | MakeSpan(GetRef<RefWrite>(ref_write_node))); |
96 | } |
97 | |
98 | // ConstructorNodes are not rewritten. |
99 | |
100 | Expr Rewrite_(const MatchNode* match_node, const Expr& post) final { |
101 | return WithFields(Downcast<Match>(post), {}, {}, {}, MakeSpan(GetRef<Match>(match_node))); |
102 | } |
103 | |
104 | Span MakeSpan(const Expr& expr) { |
105 | auto node = indexed_graph_->item_to_node(expr); |
106 | int node_index = static_cast<int>(node->index_); |
107 | int dominator_index = |
108 | node->dominator_parent_ ? static_cast<int>(node->dominator_parent_->index_) : -1; |
109 | Span span(source_name_, /*line=*/node_index, /*end_line=*/node_index, |
110 | /*column=*/dominator_index, /*end_column=*/dominator_index); |
111 | return span; |
112 | } |
113 | |
114 | SourceName source_name_; |
115 | const IndexedGraph<Expr>* indexed_graph_; |
116 | }; |
117 | |
118 | } // namespace |
119 | |
120 | tvm::transform::Pass CapturePostDfsIndexInSpans() { |
121 | auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { |
122 | std::unique_ptr<IndexedGraph<Expr>> indexed_graph = CreateIndexedGraph(f); |
123 | SpansRewriter rewriter(indexed_graph.get()); |
124 | return Downcast<Function>(PostOrderRewrite(f, &rewriter)); |
125 | }; |
126 | return CreateFunctionPass(pass_func, 0, "CapturePostDfsIndexInSpans" , {}); |
127 | } |
128 | |
129 | TVM_REGISTER_GLOBAL("relay._transform.CapturePostDfsIndexInSpans" ) |
130 | .set_body_typed(CapturePostDfsIndexInSpans); |
131 | |
132 | } // namespace transform |
133 | } // namespace relay |
134 | } // namespace tvm |
135 | |