1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/tvm/relay/dataflow_pattern.cc |
22 | * \brief The dataflow pattern language for Relay. |
23 | */ |
24 | #include <tvm/relay/dataflow_pattern.h> |
25 | #include <tvm/runtime/data_type.h> |
26 | |
27 | namespace tvm { |
28 | namespace relay { |
29 | |
30 | ExprPattern::ExprPattern(Expr expr) { |
31 | ObjectPtr<ExprPatternNode> n = make_object<ExprPatternNode>(); |
32 | n->expr = std::move(expr); |
33 | data_ = std::move(n); |
34 | } |
35 | |
36 | TVM_REGISTER_NODE_TYPE(ExprPatternNode); |
37 | |
38 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern" ).set_body_typed([](Expr e) { |
39 | return ExprPattern(e); |
40 | }); |
41 | |
42 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
43 | .set_dispatch<ExprPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
44 | auto* node = static_cast<const ExprPatternNode*>(ref.get()); |
45 | p->Print(node->expr); |
46 | }); |
47 | |
48 | VarPattern::VarPattern(String name_hint) { |
49 | ObjectPtr<VarPatternNode> n = make_object<VarPatternNode>(); |
50 | n->name = std::move(name_hint); |
51 | data_ = std::move(n); |
52 | } |
53 | |
54 | TVM_REGISTER_NODE_TYPE(VarPatternNode); |
55 | |
56 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern" ).set_body_typed([](String name_hint) { |
57 | return VarPattern(name_hint); |
58 | }); |
59 | |
60 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
61 | .set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
62 | auto* node = static_cast<const VarPatternNode*>(ref.get()); |
63 | p->stream << "VarPattern(" << node->name_hint() << ")" ; |
64 | }); |
65 | |
66 | TVM_REGISTER_NODE_TYPE(ConstantPatternNode); |
67 | |
68 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern" ).set_body_typed([]() { |
69 | auto c = ConstantPattern(make_object<ConstantPatternNode>()); |
70 | return c; |
71 | }); |
72 | |
73 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
74 | .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
75 | p->stream << "ConstantPattern()" ; |
76 | }); |
77 | |
78 | CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) { |
79 | ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>(); |
80 | n->op = std::move(op); |
81 | n->args = std::move(args); |
82 | data_ = std::move(n); |
83 | } |
84 | |
85 | TVM_REGISTER_NODE_TYPE(CallPatternNode); |
86 | |
87 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern" ) |
88 | .set_body_typed([](DFPattern op, Array<DFPattern> args) { return CallPattern(op, args); }); |
89 | |
90 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
91 | .set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
92 | auto* node = static_cast<const CallPatternNode*>(ref.get()); |
93 | p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")" ; |
94 | }); |
95 | |
96 | FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) { |
97 | ObjectPtr<FunctionPatternNode> n = make_object<FunctionPatternNode>(); |
98 | n->params = std::move(params); |
99 | n->body = std::move(body); |
100 | data_ = std::move(n); |
101 | } |
102 | TVM_REGISTER_NODE_TYPE(FunctionPatternNode); |
103 | |
104 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern" ) |
105 | .set_body_typed([](Array<DFPattern> params, DFPattern body) { |
106 | return FunctionPattern(params, body); |
107 | }); |
108 | |
109 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
110 | .set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
111 | auto* node = static_cast<const FunctionPatternNode*>(ref.get()); |
112 | p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")" ; |
113 | }); |
114 | |
115 | LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) { |
116 | ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>(); |
117 | n->var = std::move(var); |
118 | n->value = std::move(value); |
119 | n->body = std::move(body); |
120 | data_ = std::move(n); |
121 | } |
122 | |
123 | TVM_REGISTER_NODE_TYPE(LetPatternNode); |
124 | |
125 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern" ) |
126 | .set_body_typed([](DFPattern var, DFPattern value, DFPattern body) { |
127 | return LetPattern(var, value, body); |
128 | }); |
129 | |
130 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
131 | .set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
132 | auto* node = static_cast<const LetPatternNode*>(ref.get()); |
133 | p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body |
134 | << ")" ; |
135 | }); |
136 | |
137 | IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) { |
138 | ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>(); |
139 | n->cond = std::move(cond); |
140 | n->true_branch = std::move(true_branch); |
141 | n->false_branch = std::move(false_branch); |
142 | data_ = std::move(n); |
143 | } |
144 | |
145 | TVM_REGISTER_NODE_TYPE(IfPatternNode); |
146 | |
147 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern" ) |
148 | .set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) { |
149 | return IfPattern(cond, true_branch, false_branch); |
150 | }); |
151 | |
152 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
153 | .set_dispatch<IfPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
154 | auto* node = static_cast<const IfPatternNode*>(ref.get()); |
155 | p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", " |
156 | << node->false_branch << ")" ; |
157 | }); |
158 | |
159 | TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) { |
160 | ObjectPtr<TuplePatternNode> n = make_object<TuplePatternNode>(); |
161 | n->fields = std::move(fields); |
162 | data_ = std::move(n); |
163 | } |
164 | |
165 | TVM_REGISTER_NODE_TYPE(TuplePatternNode); |
166 | |
167 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern" ) |
168 | .set_body_typed([](tvm::Array<DFPattern> fields) { return TuplePattern(fields); }); |
169 | |
170 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
171 | .set_dispatch<TuplePatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
172 | auto* node = static_cast<const TuplePatternNode*>(ref.get()); |
173 | p->stream << "TuplePattern(" << node->fields << ")" ; |
174 | }); |
175 | |
176 | TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { |
177 | ObjectPtr<TupleGetItemPatternNode> n = make_object<TupleGetItemPatternNode>(); |
178 | n->tuple = std::move(tuple); |
179 | n->index = index; |
180 | data_ = std::move(n); |
181 | } |
182 | |
183 | TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); |
184 | |
185 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern" ) |
186 | .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); |
187 | |
188 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
189 | .set_dispatch<TupleGetItemPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
190 | auto* node = static_cast<const TupleGetItemPatternNode*>(ref.get()); |
191 | p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")" ; |
192 | }); |
193 | |
194 | AltPattern::AltPattern(DFPattern left, DFPattern right) { |
195 | ObjectPtr<AltPatternNode> n = make_object<AltPatternNode>(); |
196 | n->left = std::move(left); |
197 | n->right = std::move(right); |
198 | data_ = std::move(n); |
199 | } |
200 | |
201 | TVM_REGISTER_NODE_TYPE(AltPatternNode); |
202 | |
203 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern" ) |
204 | .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); }); |
205 | |
206 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
207 | .set_dispatch<AltPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
208 | auto* node = static_cast<const AltPatternNode*>(ref.get()); |
209 | p->stream << "AltPattern(" << node->left << " | " << node->right << ")" ; |
210 | }); |
211 | |
212 | TVM_REGISTER_NODE_TYPE(WildcardPatternNode); |
213 | |
214 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern" ).set_body_typed([]() { |
215 | auto w = WildcardPattern(make_object<WildcardPatternNode>()); |
216 | return w; |
217 | }); |
218 | |
219 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
220 | .set_dispatch<WildcardPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
221 | p->stream << "*" ; |
222 | }); |
223 | |
224 | TypePattern::TypePattern(DFPattern pattern, Type type) { |
225 | ObjectPtr<TypePatternNode> n = make_object<TypePatternNode>(); |
226 | n->pattern = std::move(pattern); |
227 | n->type = std::move(type); |
228 | data_ = std::move(n); |
229 | } |
230 | |
231 | TVM_REGISTER_NODE_TYPE(TypePatternNode); |
232 | |
233 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern" ) |
234 | .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); |
235 | |
236 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
237 | .set_dispatch<TypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
238 | auto* node = static_cast<const TypePatternNode*>(ref.get()); |
239 | p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")" ; |
240 | }); |
241 | |
242 | ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) { |
243 | ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>(); |
244 | n->pattern = std::move(pattern); |
245 | n->shape = std::move(shape); |
246 | data_ = std::move(n); |
247 | } |
248 | |
249 | TVM_REGISTER_NODE_TYPE(ShapePatternNode); |
250 | |
251 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern" ) |
252 | .set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) { |
253 | return ShapePattern(pattern, shape); |
254 | }); |
255 | |
256 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
257 | .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
258 | auto* node = static_cast<const ShapePatternNode*>(ref.get()); |
259 | p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")" ; |
260 | }); |
261 | |
262 | DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { |
263 | ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>(); |
264 | n->pattern = std::move(pattern); |
265 | n->dtype = std::move(dtype); |
266 | data_ = std::move(n); |
267 | } |
268 | |
269 | TVM_REGISTER_NODE_TYPE(DataTypePatternNode); |
270 | |
271 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern" ) |
272 | .set_body_typed([](DFPattern pattern, DataType dtype) { |
273 | return DataTypePattern(pattern, dtype); |
274 | }); |
275 | |
276 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
277 | .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
278 | auto* node = static_cast<const DataTypePatternNode*>(ref.get()); |
279 | p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")" ; |
280 | }); |
281 | |
282 | AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { |
283 | ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>(); |
284 | n->pattern = std::move(pattern); |
285 | n->attrs = std::move(attrs); |
286 | data_ = std::move(n); |
287 | } |
288 | |
289 | TVM_REGISTER_NODE_TYPE(AttrPatternNode); |
290 | |
291 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern" ) |
292 | .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); |
293 | |
294 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
295 | .set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
296 | auto* node = static_cast<const AttrPatternNode*>(ref.get()); |
297 | p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")" ; |
298 | }); |
299 | |
300 | DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern child) { |
301 | ObjectPtr<DominatorPatternNode> n = make_object<DominatorPatternNode>(); |
302 | n->parent = std::move(parent); |
303 | n->path = std::move(path); |
304 | |
305 | n->child = std::move(child); |
306 | data_ = std::move(n); |
307 | } |
308 | |
309 | TVM_REGISTER_NODE_TYPE(DominatorPatternNode); |
310 | |
311 | TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern" ) |
312 | .set_body_typed([](DFPattern parent, DFPattern path, DFPattern child) { |
313 | return DominatorPattern(parent, path, child); |
314 | }); |
315 | |
316 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
317 | .set_dispatch<DominatorPatternNode>([](const ObjectRef& ref, ReprPrinter* p) { |
318 | auto* node = static_cast<const DominatorPatternNode*>(ref.get()); |
319 | p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child |
320 | << ")" ; |
321 | }); |
322 | |
// Syntactic Sugar
324 | DFPattern DFPattern::operator()(const std::vector<DFPattern>& args) const { |
325 | return CallPattern(GetRef<DFPattern>(this->get()), Array<DFPattern>(args)); |
326 | } |
327 | DFPattern DFPattern::operator+(const DFPattern& other) const { |
328 | return IsOp("add" )({GetRef<DFPattern>(this->get()), other}); |
329 | } |
330 | DFPattern DFPattern::operator-(const DFPattern& other) const { |
331 | return IsOp("subtract" )({GetRef<DFPattern>(this->get()), other}); |
332 | } |
333 | DFPattern DFPattern::operator*(const DFPattern& other) const { |
334 | return IsOp("multiply" )({GetRef<DFPattern>(this->get()), other}); |
335 | } |
336 | DFPattern DFPattern::operator/(const DFPattern& other) const { |
337 | return IsOp("divide" )({GetRef<DFPattern>(this->get()), other}); |
338 | } |
339 | DFPattern DFPattern::operator||(const DFPattern& other) const { |
340 | return AltPattern(GetRef<DFPattern>(this->get()), other); |
341 | } |
342 | |
343 | DFPattern DFPattern::Optional(const std::function<DFPattern(const DFPattern&)>& func) const { |
344 | DFPattern current = GetRef<DFPattern>(this->get()); |
345 | return current || func(current); |
346 | } |
347 | |
348 | DFPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) const { |
349 | return AttrPattern(GetRef<DFPattern>(this->get()), DictAttrs(attrs)); |
350 | } |
351 | DFPattern DFPattern::HasType(const Type& type) const { |
352 | return TypePattern(GetRef<DFPattern>(this->get()), type); |
353 | } |
354 | DFPattern DFPattern::HasDtype(const DataType& dtype) const { |
355 | return DataTypePattern(GetRef<DFPattern>(this->get()), dtype); |
356 | } |
357 | DFPattern DFPattern::HasDtype(const std::string& dtype) const { |
358 | return HasDtype(DataType(runtime::String2DLDataType(dtype))); |
359 | } |
360 | DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) const { |
361 | return ShapePattern(GetRef<DFPattern>(this->get()), shape); |
362 | } |
363 | DFPattern IsVar(const String& name) { return VarPattern(name); } |
364 | DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); } |
365 | DFPattern IsWildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); } |
366 | DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } |
367 | DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } |
368 | DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); } |
369 | DFPattern IsTupleGetItem(const DFPattern tuple, int index) { |
370 | return TupleGetItemPattern(tuple, index); |
371 | } |
372 | |
373 | } // namespace relay |
374 | } // namespace tvm |
375 | |