1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/adt.cc |
22 | * \brief AST nodes for Relay algebraic data types (ADTs). |
23 | */ |
24 | #include <tvm/relay/adt.h> |
25 | #include <tvm/relay/type.h> |
26 | |
27 | namespace tvm { |
28 | namespace relay { |
29 | |
30 | PatternWildcard::PatternWildcard() { |
31 | ObjectPtr<PatternWildcardNode> n = make_object<PatternWildcardNode>(); |
32 | data_ = std::move(n); |
33 | } |
34 | |
35 | TVM_REGISTER_NODE_TYPE(PatternWildcardNode); |
36 | |
37 | TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard" ).set_body_typed([]() { return PatternWildcard(); }); |
38 | |
39 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
40 | .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) { |
41 | p->stream << "PatternWildcardNode()" ; |
42 | }); |
43 | |
44 | PatternVar::PatternVar(tvm::relay::Var var) { |
45 | ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>(); |
46 | n->var = std::move(var); |
47 | data_ = std::move(n); |
48 | } |
49 | |
50 | TVM_REGISTER_NODE_TYPE(PatternVarNode); |
51 | |
52 | TVM_REGISTER_GLOBAL("relay.ir.PatternVar" ).set_body_typed([](tvm::relay::Var var) { |
53 | return PatternVar(var); |
54 | }); |
55 | |
56 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
57 | .set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) { |
58 | auto* node = static_cast<const PatternVarNode*>(ref.get()); |
59 | p->stream << "PatternVarNode(" << node->var << ")" ; |
60 | }); |
61 | |
62 | PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns) { |
63 | ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>(); |
64 | n->constructor = std::move(constructor); |
65 | n->patterns = std::move(patterns); |
66 | data_ = std::move(n); |
67 | } |
68 | |
69 | TVM_REGISTER_NODE_TYPE(PatternConstructorNode); |
70 | |
71 | TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor" ) |
72 | .set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) { |
73 | return PatternConstructor(constructor, patterns); |
74 | }); |
75 | |
76 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
77 | .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) { |
78 | auto* node = static_cast<const PatternConstructorNode*>(ref.get()); |
79 | p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")" ; |
80 | }); |
81 | |
82 | PatternTuple::PatternTuple(tvm::Array<Pattern> patterns) { |
83 | ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>(); |
84 | n->patterns = std::move(patterns); |
85 | data_ = std::move(n); |
86 | } |
87 | |
88 | TVM_REGISTER_NODE_TYPE(PatternTupleNode); |
89 | |
90 | TVM_REGISTER_GLOBAL("relay.ir.PatternTuple" ).set_body_typed([](tvm::Array<Pattern> patterns) { |
91 | return PatternTuple(patterns); |
92 | }); |
93 | |
94 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
95 | .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) { |
96 | auto* node = static_cast<const PatternTupleNode*>(ref.get()); |
97 | p->stream << "PatternTupleNode(" << node->patterns << ")" ; |
98 | }); |
99 | |
100 | Clause::Clause(Pattern lhs, Expr rhs) { |
101 | ObjectPtr<ClauseNode> n = make_object<ClauseNode>(); |
102 | n->lhs = std::move(lhs); |
103 | n->rhs = std::move(rhs); |
104 | data_ = std::move(n); |
105 | } |
106 | |
107 | Clause WithFields(Clause clause, Optional<Pattern> opt_lhs, Optional<Expr> opt_rhs) { |
108 | Pattern lhs = opt_lhs.value_or(clause->lhs); |
109 | Expr rhs = opt_rhs.value_or(clause->rhs); |
110 | |
111 | bool unchanged = lhs.same_as(clause->lhs) && rhs.same_as(clause->rhs); |
112 | |
113 | if (!unchanged) { |
114 | ClauseNode* cow_clause_node = clause.CopyOnWrite(); |
115 | cow_clause_node->lhs = lhs; |
116 | cow_clause_node->rhs = rhs; |
117 | } |
118 | return clause; |
119 | } |
120 | |
121 | TVM_REGISTER_NODE_TYPE(ClauseNode); |
122 | |
123 | TVM_REGISTER_GLOBAL("relay.ir.Clause" ).set_body_typed([](Pattern lhs, Expr rhs) { |
124 | return Clause(lhs, rhs); |
125 | }); |
126 | |
127 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
128 | .set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) { |
129 | auto* node = static_cast<const ClauseNode*>(ref.get()); |
130 | p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")" ; |
131 | }); |
132 | |
133 | Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete, Span span) { |
134 | ObjectPtr<MatchNode> n = make_object<MatchNode>(); |
135 | n->data = std::move(data); |
136 | n->clauses = std::move(clauses); |
137 | n->complete = complete; |
138 | n->span = std::move(span); |
139 | data_ = std::move(n); |
140 | } |
141 | |
142 | Match WithFields(Match match, Optional<Expr> opt_data, Optional<Array<Clause>> opt_clauses, |
143 | Optional<Bool> opt_complete, Optional<Span> opt_span) { |
144 | Expr data = opt_data.value_or(match->data); |
145 | Array<Clause> clauses = opt_clauses.value_or(match->clauses); |
146 | Bool complete = opt_complete.value_or(Bool(match->complete)); |
147 | Span span = opt_span.value_or(match->span); |
148 | |
149 | bool unchanged = |
150 | data.same_as(match->data) && (complete == match->complete) && span.same_as(match->span); |
151 | |
152 | // Check that all clauses are unchanged |
153 | if (unchanged) { |
154 | bool all_clauses_unchanged = true; |
155 | if (clauses.size() == match->clauses.size()) { |
156 | for (size_t i = 0; i < clauses.size(); i++) { |
157 | all_clauses_unchanged &= clauses[i].same_as(match->clauses[i]); |
158 | } |
159 | } else { |
160 | all_clauses_unchanged = false; |
161 | } |
162 | unchanged &= all_clauses_unchanged; |
163 | } |
164 | if (!unchanged) { |
165 | MatchNode* cow_match_node = match.CopyOnWrite(); |
166 | cow_match_node->data = data; |
167 | cow_match_node->clauses = clauses; |
168 | cow_match_node->complete = complete; |
169 | cow_match_node->span = span; |
170 | } |
171 | return match; |
172 | } |
173 | |
174 | TVM_REGISTER_NODE_TYPE(MatchNode); |
175 | |
176 | TVM_REGISTER_GLOBAL("relay.ir.Match" ) |
177 | .set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) { |
178 | return Match(data, clauses, complete); |
179 | }); |
180 | |
181 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
182 | .set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) { |
183 | auto* node = static_cast<const MatchNode*>(ref.get()); |
184 | p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete |
185 | << ")" ; |
186 | }); |
187 | |
188 | } // namespace relay |
189 | } // namespace tvm |
190 | |