1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/adt.cc
22 * \brief AST nodes for Relay algebraic data types (ADTs).
23 */
24#include <tvm/relay/adt.h>
25#include <tvm/relay/type.h>
26
27namespace tvm {
28namespace relay {
29
30PatternWildcard::PatternWildcard() {
31 ObjectPtr<PatternWildcardNode> n = make_object<PatternWildcardNode>();
32 data_ = std::move(n);
33}
34
35TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
36
37TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard").set_body_typed([]() { return PatternWildcard(); });
38
39TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
40 .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
41 p->stream << "PatternWildcardNode()";
42 });
43
44PatternVar::PatternVar(tvm::relay::Var var) {
45 ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>();
46 n->var = std::move(var);
47 data_ = std::move(n);
48}
49
50TVM_REGISTER_NODE_TYPE(PatternVarNode);
51
52TVM_REGISTER_GLOBAL("relay.ir.PatternVar").set_body_typed([](tvm::relay::Var var) {
53 return PatternVar(var);
54});
55
56TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
57 .set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
58 auto* node = static_cast<const PatternVarNode*>(ref.get());
59 p->stream << "PatternVarNode(" << node->var << ")";
60 });
61
62PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns) {
63 ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>();
64 n->constructor = std::move(constructor);
65 n->patterns = std::move(patterns);
66 data_ = std::move(n);
67}
68
69TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
70
71TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
72 .set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) {
73 return PatternConstructor(constructor, patterns);
74 });
75
76TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
77 .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
78 auto* node = static_cast<const PatternConstructorNode*>(ref.get());
79 p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")";
80 });
81
82PatternTuple::PatternTuple(tvm::Array<Pattern> patterns) {
83 ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>();
84 n->patterns = std::move(patterns);
85 data_ = std::move(n);
86}
87
88TVM_REGISTER_NODE_TYPE(PatternTupleNode);
89
90TVM_REGISTER_GLOBAL("relay.ir.PatternTuple").set_body_typed([](tvm::Array<Pattern> patterns) {
91 return PatternTuple(patterns);
92});
93
94TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
95 .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
96 auto* node = static_cast<const PatternTupleNode*>(ref.get());
97 p->stream << "PatternTupleNode(" << node->patterns << ")";
98 });
99
100Clause::Clause(Pattern lhs, Expr rhs) {
101 ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
102 n->lhs = std::move(lhs);
103 n->rhs = std::move(rhs);
104 data_ = std::move(n);
105}
106
107Clause WithFields(Clause clause, Optional<Pattern> opt_lhs, Optional<Expr> opt_rhs) {
108 Pattern lhs = opt_lhs.value_or(clause->lhs);
109 Expr rhs = opt_rhs.value_or(clause->rhs);
110
111 bool unchanged = lhs.same_as(clause->lhs) && rhs.same_as(clause->rhs);
112
113 if (!unchanged) {
114 ClauseNode* cow_clause_node = clause.CopyOnWrite();
115 cow_clause_node->lhs = lhs;
116 cow_clause_node->rhs = rhs;
117 }
118 return clause;
119}
120
121TVM_REGISTER_NODE_TYPE(ClauseNode);
122
123TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) {
124 return Clause(lhs, rhs);
125});
126
127TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
128 .set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
129 auto* node = static_cast<const ClauseNode*>(ref.get());
130 p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")";
131 });
132
133Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete, Span span) {
134 ObjectPtr<MatchNode> n = make_object<MatchNode>();
135 n->data = std::move(data);
136 n->clauses = std::move(clauses);
137 n->complete = complete;
138 n->span = std::move(span);
139 data_ = std::move(n);
140}
141
142Match WithFields(Match match, Optional<Expr> opt_data, Optional<Array<Clause>> opt_clauses,
143 Optional<Bool> opt_complete, Optional<Span> opt_span) {
144 Expr data = opt_data.value_or(match->data);
145 Array<Clause> clauses = opt_clauses.value_or(match->clauses);
146 Bool complete = opt_complete.value_or(Bool(match->complete));
147 Span span = opt_span.value_or(match->span);
148
149 bool unchanged =
150 data.same_as(match->data) && (complete == match->complete) && span.same_as(match->span);
151
152 // Check that all clauses are unchanged
153 if (unchanged) {
154 bool all_clauses_unchanged = true;
155 if (clauses.size() == match->clauses.size()) {
156 for (size_t i = 0; i < clauses.size(); i++) {
157 all_clauses_unchanged &= clauses[i].same_as(match->clauses[i]);
158 }
159 } else {
160 all_clauses_unchanged = false;
161 }
162 unchanged &= all_clauses_unchanged;
163 }
164 if (!unchanged) {
165 MatchNode* cow_match_node = match.CopyOnWrite();
166 cow_match_node->data = data;
167 cow_match_node->clauses = clauses;
168 cow_match_node->complete = complete;
169 cow_match_node->span = span;
170 }
171 return match;
172}
173
174TVM_REGISTER_NODE_TYPE(MatchNode);
175
176TVM_REGISTER_GLOBAL("relay.ir.Match")
177 .set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) {
178 return Match(data, clauses, complete);
179 });
180
181TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
182 .set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
183 auto* node = static_cast<const MatchNode*>(ref.get());
184 p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete
185 << ")";
186 });
187
188} // namespace relay
189} // namespace tvm
190