1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/tvm/relay/dataflow_pattern.cc
22 * \brief The dataflow pattern language for Relay.
23 */
24#include <tvm/relay/dataflow_pattern.h>
25#include <tvm/runtime/data_type.h>
26
27namespace tvm {
28namespace relay {
29
30ExprPattern::ExprPattern(Expr expr) {
31 ObjectPtr<ExprPatternNode> n = make_object<ExprPatternNode>();
32 n->expr = std::move(expr);
33 data_ = std::move(n);
34}
35
36TVM_REGISTER_NODE_TYPE(ExprPatternNode);
37
38TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr e) {
39 return ExprPattern(e);
40});
41
42TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
43 .set_dispatch<ExprPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
44 auto* node = static_cast<const ExprPatternNode*>(ref.get());
45 p->Print(node->expr);
46 });
47
48VarPattern::VarPattern(String name_hint) {
49 ObjectPtr<VarPatternNode> n = make_object<VarPatternNode>();
50 n->name = std::move(name_hint);
51 data_ = std::move(n);
52}
53
54TVM_REGISTER_NODE_TYPE(VarPatternNode);
55
56TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern").set_body_typed([](String name_hint) {
57 return VarPattern(name_hint);
58});
59
60TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
61 .set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
62 auto* node = static_cast<const VarPatternNode*>(ref.get());
63 p->stream << "VarPattern(" << node->name_hint() << ")";
64 });
65
66TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
67
68TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() {
69 auto c = ConstantPattern(make_object<ConstantPatternNode>());
70 return c;
71});
72
73TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
74 .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
75 p->stream << "ConstantPattern()";
76 });
77
78CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) {
79 ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
80 n->op = std::move(op);
81 n->args = std::move(args);
82 data_ = std::move(n);
83}
84
85TVM_REGISTER_NODE_TYPE(CallPatternNode);
86
87TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
88 .set_body_typed([](DFPattern op, Array<DFPattern> args) { return CallPattern(op, args); });
89
90TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
91 .set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
92 auto* node = static_cast<const CallPatternNode*>(ref.get());
93 p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")";
94 });
95
96FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) {
97 ObjectPtr<FunctionPatternNode> n = make_object<FunctionPatternNode>();
98 n->params = std::move(params);
99 n->body = std::move(body);
100 data_ = std::move(n);
101}
102TVM_REGISTER_NODE_TYPE(FunctionPatternNode);
103
104TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern")
105 .set_body_typed([](Array<DFPattern> params, DFPattern body) {
106 return FunctionPattern(params, body);
107 });
108
109TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
110 .set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
111 auto* node = static_cast<const FunctionPatternNode*>(ref.get());
112 p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
113 });
114
115LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
116 ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();
117 n->var = std::move(var);
118 n->value = std::move(value);
119 n->body = std::move(body);
120 data_ = std::move(n);
121}
122
123TVM_REGISTER_NODE_TYPE(LetPatternNode);
124
125TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")
126 .set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
127 return LetPattern(var, value, body);
128 });
129
130TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
131 .set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
132 auto* node = static_cast<const LetPatternNode*>(ref.get());
133 p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body
134 << ")";
135 });
136
137IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
138 ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
139 n->cond = std::move(cond);
140 n->true_branch = std::move(true_branch);
141 n->false_branch = std::move(false_branch);
142 data_ = std::move(n);
143}
144
145TVM_REGISTER_NODE_TYPE(IfPatternNode);
146
147TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern")
148 .set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
149 return IfPattern(cond, true_branch, false_branch);
150 });
151
152TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
153 .set_dispatch<IfPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
154 auto* node = static_cast<const IfPatternNode*>(ref.get());
155 p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", "
156 << node->false_branch << ")";
157 });
158
159TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
160 ObjectPtr<TuplePatternNode> n = make_object<TuplePatternNode>();
161 n->fields = std::move(fields);
162 data_ = std::move(n);
163}
164
165TVM_REGISTER_NODE_TYPE(TuplePatternNode);
166
167TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern")
168 .set_body_typed([](tvm::Array<DFPattern> fields) { return TuplePattern(fields); });
169
170TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
171 .set_dispatch<TuplePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
172 auto* node = static_cast<const TuplePatternNode*>(ref.get());
173 p->stream << "TuplePattern(" << node->fields << ")";
174 });
175
176TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) {
177 ObjectPtr<TupleGetItemPatternNode> n = make_object<TupleGetItemPatternNode>();
178 n->tuple = std::move(tuple);
179 n->index = index;
180 data_ = std::move(n);
181}
182
183TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode);
184
185TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern")
186 .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); });
187
188TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
189 .set_dispatch<TupleGetItemPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
190 auto* node = static_cast<const TupleGetItemPatternNode*>(ref.get());
191 p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")";
192 });
193
194AltPattern::AltPattern(DFPattern left, DFPattern right) {
195 ObjectPtr<AltPatternNode> n = make_object<AltPatternNode>();
196 n->left = std::move(left);
197 n->right = std::move(right);
198 data_ = std::move(n);
199}
200
201TVM_REGISTER_NODE_TYPE(AltPatternNode);
202
203TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern")
204 .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); });
205
206TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
207 .set_dispatch<AltPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
208 auto* node = static_cast<const AltPatternNode*>(ref.get());
209 p->stream << "AltPattern(" << node->left << " | " << node->right << ")";
210 });
211
212TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
213
214TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
215 auto w = WildcardPattern(make_object<WildcardPatternNode>());
216 return w;
217});
218
219TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
220 .set_dispatch<WildcardPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
221 p->stream << "*";
222 });
223
224TypePattern::TypePattern(DFPattern pattern, Type type) {
225 ObjectPtr<TypePatternNode> n = make_object<TypePatternNode>();
226 n->pattern = std::move(pattern);
227 n->type = std::move(type);
228 data_ = std::move(n);
229}
230
231TVM_REGISTER_NODE_TYPE(TypePatternNode);
232
233TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern")
234 .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); });
235
236TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
237 .set_dispatch<TypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
238 auto* node = static_cast<const TypePatternNode*>(ref.get());
239 p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
240 });
241
242ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
243 ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
244 n->pattern = std::move(pattern);
245 n->shape = std::move(shape);
246 data_ = std::move(n);
247}
248
249TVM_REGISTER_NODE_TYPE(ShapePatternNode);
250
251TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
252 .set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) {
253 return ShapePattern(pattern, shape);
254 });
255
256TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
257 .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
258 auto* node = static_cast<const ShapePatternNode*>(ref.get());
259 p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
260 });
261
262DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
263 ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
264 n->pattern = std::move(pattern);
265 n->dtype = std::move(dtype);
266 data_ = std::move(n);
267}
268
269TVM_REGISTER_NODE_TYPE(DataTypePatternNode);
270
271TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
272 .set_body_typed([](DFPattern pattern, DataType dtype) {
273 return DataTypePattern(pattern, dtype);
274 });
275
276TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
277 .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
278 auto* node = static_cast<const DataTypePatternNode*>(ref.get());
279 p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
280 });
281
282AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) {
283 ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
284 n->pattern = std::move(pattern);
285 n->attrs = std::move(attrs);
286 data_ = std::move(n);
287}
288
289TVM_REGISTER_NODE_TYPE(AttrPatternNode);
290
291TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern")
292 .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); });
293
294TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
295 .set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
296 auto* node = static_cast<const AttrPatternNode*>(ref.get());
297 p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")";
298 });
299
300DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern child) {
301 ObjectPtr<DominatorPatternNode> n = make_object<DominatorPatternNode>();
302 n->parent = std::move(parent);
303 n->path = std::move(path);
304
305 n->child = std::move(child);
306 data_ = std::move(n);
307}
308
309TVM_REGISTER_NODE_TYPE(DominatorPatternNode);
310
311TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern")
312 .set_body_typed([](DFPattern parent, DFPattern path, DFPattern child) {
313 return DominatorPattern(parent, path, child);
314 });
315
316TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
317 .set_dispatch<DominatorPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
318 auto* node = static_cast<const DominatorPatternNode*>(ref.get());
319 p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child
320 << ")";
321 });
322
323// Syntatic Sugar
324DFPattern DFPattern::operator()(const std::vector<DFPattern>& args) const {
325 return CallPattern(GetRef<DFPattern>(this->get()), Array<DFPattern>(args));
326}
327DFPattern DFPattern::operator+(const DFPattern& other) const {
328 return IsOp("add")({GetRef<DFPattern>(this->get()), other});
329}
330DFPattern DFPattern::operator-(const DFPattern& other) const {
331 return IsOp("subtract")({GetRef<DFPattern>(this->get()), other});
332}
333DFPattern DFPattern::operator*(const DFPattern& other) const {
334 return IsOp("multiply")({GetRef<DFPattern>(this->get()), other});
335}
336DFPattern DFPattern::operator/(const DFPattern& other) const {
337 return IsOp("divide")({GetRef<DFPattern>(this->get()), other});
338}
339DFPattern DFPattern::operator||(const DFPattern& other) const {
340 return AltPattern(GetRef<DFPattern>(this->get()), other);
341}
342
343DFPattern DFPattern::Optional(const std::function<DFPattern(const DFPattern&)>& func) const {
344 DFPattern current = GetRef<DFPattern>(this->get());
345 return current || func(current);
346}
347
348DFPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) const {
349 return AttrPattern(GetRef<DFPattern>(this->get()), DictAttrs(attrs));
350}
351DFPattern DFPattern::HasType(const Type& type) const {
352 return TypePattern(GetRef<DFPattern>(this->get()), type);
353}
354DFPattern DFPattern::HasDtype(const DataType& dtype) const {
355 return DataTypePattern(GetRef<DFPattern>(this->get()), dtype);
356}
357DFPattern DFPattern::HasDtype(const std::string& dtype) const {
358 return HasDtype(DataType(runtime::String2DLDataType(dtype)));
359}
360DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) const {
361 return ShapePattern(GetRef<DFPattern>(this->get()), shape);
362}
363DFPattern IsVar(const String& name) { return VarPattern(name); }
364DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
365DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
366DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
367DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
368DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
369DFPattern IsTupleGetItem(const DFPattern tuple, int index) {
370 return TupleGetItemPattern(tuple, index);
371}
372
373} // namespace relay
374} // namespace tvm
375