1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/expr.h |
22 | * \brief Relay expression language. |
23 | */ |
24 | #ifndef TVM_RELAY_EXPR_H_ |
25 | #define TVM_RELAY_EXPR_H_ |
26 | |
27 | #include <tvm/ir/attrs.h> |
28 | #include <tvm/ir/expr.h> |
29 | #include <tvm/ir/module.h> |
30 | #include <tvm/ir/op.h> |
31 | #include <tvm/target/virtual_device.h> |
32 | |
33 | #include <functional> |
34 | #include <stack> |
35 | #include <string> |
36 | #include <utility> |
37 | |
38 | #include "./base.h" |
39 | #include "./type.h" |
40 | |
41 | namespace tvm { |
42 | |
43 | /*! |
44 | * \brief Returns \p global_var with the given properties. A null property denotes 'no change'. |
45 | * Returns \p global_var if all properties are unchanged. Otherwise, returns a copy with the new |
46 | * fields. |
47 | */ |
48 | GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint = {}, |
49 | Optional<Type> opt_type = {}, Optional<VirtualDevice> opt_virtual_device = {}, |
50 | Optional<Span> opt_span = {}); |
51 | |
52 | namespace relay { |
53 | |
54 | using Expr = tvm::RelayExpr; |
55 | using ExprNode = tvm::RelayExprNode; |
56 | using BaseFunc = tvm::BaseFunc; |
57 | using BaseFuncNode = tvm::BaseFuncNode; |
58 | using GlobalVar = tvm::GlobalVar; |
59 | using GlobalVarNode = tvm::GlobalVarNode; |
60 | |
61 | /*! |
62 | * \brief Constant tensor, backed by an NDArray on the cpu(0) device. |
63 | * |
64 | * \note Scalar constants are represented by rank-0 const tensor. |
65 | * Constant folding are handled uniformly via Tensor types. |
66 | */ |
67 | class Constant; |
68 | /*! |
69 | * \brief Constant tensor type. |
70 | */ |
71 | class ConstantNode : public ExprNode { |
72 | public: |
73 | /*! \brief The data of the tensor */ |
74 | runtime::NDArray data; |
75 | |
76 | /*! \return The corresponding tensor type of the data */ |
77 | TensorType tensor_type() const; |
78 | |
79 | /*! \return Whether it is scalar(rank-0 tensor) */ |
80 | bool is_scalar() const { return data->ndim == 0; } |
81 | |
82 | void VisitAttrs(tvm::AttrVisitor* v) { |
83 | v->Visit("data" , &data); |
84 | v->Visit("virtual_device_" , &virtual_device_); |
85 | v->Visit("span" , &span); |
86 | v->Visit("_checked_type_" , &checked_type_); |
87 | } |
88 | |
89 | bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { |
90 | return equal(data, other->data); |
91 | } |
92 | |
93 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } |
94 | |
95 | static constexpr const char* _type_key = "relay.Constant" ; |
96 | TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); |
97 | }; |
98 | |
99 | class Constant : public Expr { |
100 | public: |
101 | /*! |
102 | * \brief The constructor |
103 | * \param data The data of the constant tensor. |
104 | * \param span The source span of the expression. |
105 | */ |
106 | TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); |
107 | |
108 | TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); |
109 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); |
110 | }; |
111 | |
112 | /*! |
113 | * \brief Returns \p constant with the given properties. A null property denotes 'no change'. |
114 | * Returns \p constant if all properties are unchanged. Otherwise, returns a copy with the new |
115 | * fields. |
116 | */ |
117 | Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data = {}, |
118 | Optional<VirtualDevice> opt_virtual_device = {}, Optional<Span> opt_span = {}); |
119 | |
120 | /*! \brief Tuple of multiple Exprs */ |
121 | class Tuple; |
122 | /*! \brief Tuple container */ |
123 | class TupleNode : public ExprNode { |
124 | public: |
125 | /*! \brief the fields of the tuple */ |
126 | tvm::Array<relay::Expr> fields; |
127 | |
128 | void VisitAttrs(tvm::AttrVisitor* v) { |
129 | v->Visit("fields" , &fields); |
130 | v->Visit("virtual_device_" , &virtual_device_); |
131 | v->Visit("span" , &span); |
132 | v->Visit("_checked_type_" , &checked_type_); |
133 | } |
134 | |
135 | bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { |
136 | // specially handle empty tuple as a constant is not a graph node. |
137 | if (fields.size() == other->fields.size() && fields.size() == 0) { |
138 | return true; |
139 | } else { |
140 | equal->MarkGraphNode(); |
141 | return equal(fields, other->fields); |
142 | } |
143 | } |
144 | |
145 | void SHashReduce(SHashReducer hash_reduce) const { |
146 | if (fields.size() != 0) { |
147 | hash_reduce->MarkGraphNode(); |
148 | hash_reduce(fields); |
149 | } |
150 | } |
151 | |
152 | static constexpr const char* _type_key = "relay.Tuple" ; |
153 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); |
154 | }; |
155 | |
156 | class Tuple : public Expr { |
157 | public: |
158 | /*! |
159 | * \brief The constructor |
160 | * \param fields The fields of a tuple. |
161 | * \param span The source span of the expression. |
162 | */ |
163 | TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span()); |
164 | |
165 | TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); |
166 | TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); |
167 | }; |
168 | |
169 | /*! |
170 | * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. |
171 | * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new |
172 | * fields. |
173 | */ |
174 | Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(), |
175 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
176 | Optional<Span> opt_span = Optional<Span>()); |
177 | |
178 | /*! |
179 | * \brief Local variables used in the let expression. |
180 | * |
181 | * Its semantics are similar to tvm.Var node used in TVM's low level |
182 | * tensor expression language. |
183 | * |
184 | * \note Each Var is bind only once and is immutable. |
185 | */ |
186 | class Var; |
187 | /*! \brief Container for Var */ |
188 | class VarNode : public ExprNode { |
189 | public: |
190 | /*! |
191 | * \brief The unique identifier of the Var. |
192 | * |
193 | * vid will be preserved for the same Var during type inference |
194 | * and other rewritings, while the VarNode might be recreated |
195 | * to attach additional information. |
196 | * This property can be used to keep track of parameter Var |
197 | * information across passes. |
198 | */ |
199 | Id vid; |
200 | /*! |
201 | * \brief type annotaion of the variable. |
202 | * This field records user provided type annotation of the Var. |
203 | * This field is optional and can be None. |
204 | */ |
205 | Type type_annotation; |
206 | |
207 | /*! \return The name hint of the variable */ |
208 | const String& name_hint() const { return vid->name_hint; } |
209 | |
210 | void VisitAttrs(tvm::AttrVisitor* v) { |
211 | v->Visit("vid" , &vid); |
212 | v->Visit("type_annotation" , &type_annotation); |
213 | v->Visit("virtual_device_" , &virtual_device_); |
214 | v->Visit("span" , &span); |
215 | v->Visit("_checked_type_" , &checked_type_); |
216 | } |
217 | |
218 | bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { |
219 | equal->MarkGraphNode(); |
220 | return equal(type_annotation, other->type_annotation) && equal(vid, other->vid) && |
221 | equal(virtual_device_, other->virtual_device_); |
222 | } |
223 | |
224 | void SHashReduce(SHashReducer hash_reduce) const { |
225 | hash_reduce->MarkGraphNode(); |
226 | hash_reduce(type_annotation); |
227 | hash_reduce(vid); |
228 | } |
229 | |
230 | static constexpr const char* _type_key = "relay.Var" ; |
231 | TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); |
232 | }; |
233 | |
234 | class Var : public Expr { |
235 | public: |
236 | /*! |
237 | * \brief The constructor |
238 | * \param name_hint The name hint of a variable. |
239 | * \param type_annotation The type annotation of a variable. |
240 | * \param span The source span of the expression. |
241 | */ |
242 | TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span()) |
243 | : Var(Id(name_hint), type_annotation, span) {} |
244 | |
245 | /*! |
246 | * \brief The constructor |
247 | * \param vid The unique id of a variable. |
248 | * \param type_annotation The type annotation of a variable. |
249 | * \param span The source span of the expression. |
250 | */ |
251 | TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); |
252 | |
253 | /*! |
254 | * \brief Return a globally fresh name. Helps with debugging to follow the same |
255 | * variable between passes and sub-expressions. |
256 | * |
257 | * TODO(mbs): Replace with name creation w.r.t. scopes once available as part of |
258 | * name gen overhaul. |
259 | */ |
260 | static Var GenSym(Type type_annotation = {}, Span span = {}); |
261 | |
262 | TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); |
263 | TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); |
264 | }; |
265 | |
266 | /*! |
267 | * \brief Returns \p var with the given properties. A null property denotes 'no change'. |
268 | * Returns \p var if all properties are unchanged. Otherwise, returns a copy with the new |
269 | * fields. |
270 | */ |
271 | Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(), |
272 | Optional<Type> opt_type_annotation = Optional<Type>(), |
273 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
274 | Optional<Span> opt_span = Optional<Span>()); |
275 | |
276 | /*! |
277 | * \brief Call corresponds to operator invocation. |
278 | * Corresponds to the operator in computational graph terminology. |
279 | */ |
280 | class Call; |
281 | /*! \brief Call container. */ |
282 | class CallNode : public ExprNode { |
283 | protected: |
284 | // CallNode uses own deleter to indirectly call non-recursive destructor |
285 | Object::FDeleter saved_deleter_; |
286 | static void Deleter_(Object* ptr); |
287 | |
288 | public: |
289 | /*! |
290 | * \brief The operator(function) being invoked |
291 | * |
292 | * - It can be tvm::Op which corresponds to the primitive operators. |
293 | * - It can also be user defined functions (Function, GlobalVar, Var). |
294 | */ |
295 | Expr op; |
296 | |
297 | /*! \brief The arguments(inputs) of the call */ |
298 | tvm::Array<relay::Expr> args; |
299 | |
300 | /*! \brief The additional attributes */ |
301 | Attrs attrs; |
302 | |
303 | /*! |
304 | * \brief The type arguments passed to polymorphic(template) function. |
305 | * |
306 | * This is the advance feature that is only used when the function is |
307 | * polymorphic. It is safe to be ignored in most cases. For example, in the |
308 | * following code, the type_args of addone call is [int]. |
309 | * |
310 | * \code |
311 | * |
312 | * template<typename T> |
313 | * T addone(T a) { return a + 1; } |
314 | * |
315 | * void main() { |
316 | * int x = addone<int>(10); |
317 | * } |
318 | * |
319 | * \endcode |
320 | */ |
321 | tvm::Array<Type> type_args; |
322 | |
323 | void VisitAttrs(tvm::AttrVisitor* v) { |
324 | v->Visit("op" , &op); |
325 | v->Visit("args" , &args); |
326 | v->Visit("attrs" , &attrs); |
327 | v->Visit("type_args" , &type_args); |
328 | v->Visit("virtual_device_" , &virtual_device_); |
329 | v->Visit("span" , &span); |
330 | v->Visit("_checked_type_" , &checked_type_); |
331 | } |
332 | |
333 | bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { |
334 | // skip type_args check for primitive ops. |
335 | equal->MarkGraphNode(); |
336 | return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && |
337 | (IsPrimitiveOp(op) || equal(type_args, other->type_args)); |
338 | } |
339 | |
340 | void SHashReduce(SHashReducer hash_reduce) const { |
341 | hash_reduce->MarkGraphNode(); |
342 | hash_reduce(op); |
343 | hash_reduce(args); |
344 | hash_reduce(attrs); |
345 | if (!IsPrimitiveOp(op)) { |
346 | hash_reduce(type_args); |
347 | } |
348 | } |
349 | |
350 | static constexpr const char* _type_key = "relay.Call" ; |
351 | TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); |
352 | template <typename> |
353 | friend class runtime::ObjAllocatorBase; |
354 | friend class Call; |
355 | }; |
356 | |
357 | class Call : public Expr { |
358 | public: |
359 | /*! |
360 | * \brief The destructor |
361 | */ |
362 | ~Call(); |
363 | |
364 | /*! |
365 | * \brief The constructor |
366 | * \param op The operator will be invoked. |
367 | * \param args The arguments of the call. |
368 | * \param attrs The attributes of the call node. |
369 | * \param type_args The type arguments passed to a polymorphic function. |
370 | * \param span The source span of the expression. |
371 | */ |
372 | TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(), |
373 | Array<Type> type_args = Array<Type>(), Span span = Span()); |
374 | |
375 | TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); |
376 | TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); |
377 | }; |
378 | |
379 | /*! |
380 | * \brief Returns \p call with the given properties. A null property denotes 'no change'. |
381 | * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new |
382 | * fields. |
383 | */ |
384 | Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(), |
385 | Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(), |
386 | Optional<Attrs> opt_attrs = Optional<Attrs>(), |
387 | Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(), |
388 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
389 | Optional<Span> opt_span = Optional<Span>()); |
390 | |
391 | /*! |
392 | * \brief Let binding that binds a local var and optionally a type annotation. |
393 | * |
394 | * \note Let is useful to transform the program to be A-normal form. |
395 | * where each of the expression corresponds to a let binding. |
396 | * |
397 | * For developers who are familar with the computational graph. |
398 | * Each of the let can be viewed as a operator node in the computational graph. |
399 | * Traversing the list of let bindings is similar to running |
400 | * PostDFS-order(topo-order) traversal on the computational graph. |
401 | */ |
402 | class Let; |
403 | /*! \brief A binding of a sub-network. */ |
404 | class LetNode : public ExprNode { |
405 | protected: |
406 | // LetNode uses own deleter to indirectly call non-recursive destructor |
407 | Object::FDeleter saved_deleter_; |
408 | static void Deleter_(Object* ptr); |
409 | |
410 | public: |
411 | /*! \brief The variable we bind to */ |
412 | Var var; |
413 | /*! \brief The value we bind var to */ |
414 | Expr value; |
415 | /*! \brief The body of the let binding */ |
416 | Expr body; |
417 | |
418 | void VisitAttrs(tvm::AttrVisitor* v) { |
419 | v->Visit("var" , &var); |
420 | v->Visit("value" , &value); |
421 | v->Visit("body" , &body); |
422 | v->Visit("virtual_device_" , &virtual_device_); |
423 | v->Visit("span" , &span); |
424 | v->Visit("_checked_type_" , &checked_type_); |
425 | } |
426 | |
427 | bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { |
428 | equal->MarkGraphNode(); |
429 | return equal.DefEqual(var, other->var) && equal(value, other->value) && |
430 | equal(body, other->body); |
431 | } |
432 | |
433 | void SHashReduce(SHashReducer hash_reduce) const { |
434 | hash_reduce->MarkGraphNode(); |
435 | hash_reduce.DefHash(var); |
436 | hash_reduce(value); |
437 | hash_reduce(body); |
438 | } |
439 | |
440 | static constexpr const char* _type_key = "relay.Let" ; |
441 | TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); |
442 | template <typename> |
443 | friend class runtime::ObjAllocatorBase; |
444 | friend class Let; |
445 | }; |
446 | |
447 | class Let : public Expr { |
448 | public: |
449 | /*! |
450 | * \brief The destructor |
451 | */ |
452 | ~Let(); |
453 | |
454 | /*! |
455 | * \brief The constructor |
456 | * \param var The variable that is bound to. |
457 | * \param value The value used to bind to the variable. |
458 | * \param body The body of the let binding. |
459 | * \param span The source span of the expression. |
460 | */ |
461 | TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span()); |
462 | |
463 | TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); |
464 | TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); |
465 | }; |
466 | |
467 | /*! |
468 | * \brief Returns \p let with the given properties. A null property denotes 'no change'. |
469 | * Returns \p let if all properties are unchanged. Otherwise, returns a copy with the new |
470 | * fields. |
471 | */ |
472 | Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(), |
473 | Optional<Expr> opt_value = Optional<Expr>(), |
474 | Optional<Expr> opt_body = Optional<Expr>(), |
475 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
476 | Optional<Span> opt_span = Optional<Span>()); |
477 | |
478 | /*! |
479 | * \brief Condition expression |
480 | * |
481 | * Unlike traditional statement `if`s, the if evalutes |
482 | * to the result of the branch taken. |
483 | * |
484 | * let x = if (true) { 1 } else { 0 }; // x is 1 |
485 | * let y = if (false) { 1 } else { 0 }; // y is 0 |
486 | * |
487 | * \note This is similar to C's ternary operator. |
488 | */ |
489 | class If; |
490 | /*! \brief container of If */ |
491 | class IfNode : public ExprNode { |
492 | public: |
493 | /*! \brief The condition */ |
494 | Expr cond; |
495 | /*! \brief The expression evaluated when condition is true. */ |
496 | Expr true_branch; |
497 | /*! \brief The expression evaluated when condition is false */ |
498 | Expr false_branch; |
499 | |
500 | void VisitAttrs(tvm::AttrVisitor* v) { |
501 | v->Visit("cond" , &cond); |
502 | v->Visit("true_branch" , &true_branch); |
503 | v->Visit("false_branch" , &false_branch); |
504 | v->Visit("virtual_device_" , &virtual_device_); |
505 | v->Visit("span" , &span); |
506 | v->Visit("_checked_type_" , &checked_type_); |
507 | } |
508 | |
509 | bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { |
510 | equal->MarkGraphNode(); |
511 | return equal(cond, other->cond) && equal(true_branch, other->true_branch) && |
512 | equal(false_branch, other->false_branch); |
513 | } |
514 | |
515 | void SHashReduce(SHashReducer hash_reduce) const { |
516 | hash_reduce->MarkGraphNode(); |
517 | hash_reduce(cond); |
518 | hash_reduce(true_branch); |
519 | hash_reduce(false_branch); |
520 | } |
521 | |
522 | static constexpr const char* _type_key = "relay.If" ; |
523 | TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); |
524 | }; |
525 | |
526 | class If : public Expr { |
527 | public: |
528 | /*! |
529 | * \brief The constructor |
530 | * \param cond The condition of a if node. |
531 | * \param true_branch The fall through branch |
532 | * \param false_branch The branch for execution when condition is false. |
533 | * \param span The source span of the expression. |
534 | */ |
535 | TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); |
536 | |
537 | TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); |
538 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); |
539 | }; |
540 | |
541 | /*! |
542 | * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. |
543 | * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new |
544 | * fields. |
545 | */ |
546 | If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(), |
547 | Optional<Expr> opt_true_branch = Optional<Expr>(), |
548 | Optional<Expr> opt_false_branch = Optional<Expr>(), |
549 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
550 | Optional<Span> opt_span = Optional<Span>()); |
551 | |
552 | /*! \brief Get index-th field out of a tuple. */ |
553 | class TupleGetItem; |
554 | class TupleGetItemNode : public ExprNode { |
555 | public: |
556 | /*! \brief The tuple Expression */ |
557 | Expr tuple; |
558 | /*! \brief which value to get */ |
559 | int index; |
560 | |
561 | void VisitAttrs(tvm::AttrVisitor* v) { |
562 | v->Visit("tuple_value" , &tuple); |
563 | v->Visit("index" , &index); |
564 | v->Visit("virtual_device_" , &virtual_device_); |
565 | v->Visit("span" , &span); |
566 | v->Visit("_checked_type_" , &checked_type_); |
567 | } |
568 | |
569 | bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { |
570 | return equal(tuple, other->tuple) && equal(index, other->index); |
571 | } |
572 | |
573 | void SHashReduce(SHashReducer hash_reduce) const { |
574 | hash_reduce(tuple); |
575 | hash_reduce(index); |
576 | } |
577 | |
578 | static constexpr const char* _type_key = "relay.TupleGetItem" ; |
579 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); |
580 | }; |
581 | |
582 | class TupleGetItem : public Expr { |
583 | public: |
584 | /*! |
585 | * \brief The constructor |
586 | * \param tuple The tuple to get an element from. |
587 | * \param index The index for extracting a value in the tuple. |
588 | * \param span The source span of the expression. |
589 | */ |
590 | TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); |
591 | |
592 | TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); |
593 | TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); |
594 | }; |
595 | |
596 | /*! |
597 | * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. |
598 | * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new |
599 | * fields. |
600 | */ |
601 | TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(), |
602 | Optional<Integer> opt_index = Optional<Integer>(), |
603 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
604 | Optional<Span> opt_span = Optional<Span>()); |
605 | |
606 | /*! \brief Create a new Reference out of initial value. */ |
607 | class RefCreate; |
608 | class RefCreateNode : public ExprNode { |
609 | public: |
610 | /*! \brief The initial value of the Reference. */ |
611 | Expr value; |
612 | |
613 | void VisitAttrs(tvm::AttrVisitor* v) { |
614 | v->Visit("value" , &value); |
615 | v->Visit("virtual_device_" , &virtual_device_); |
616 | v->Visit("span" , &span); |
617 | v->Visit("_checked_type_" , &checked_type_); |
618 | } |
619 | |
620 | bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const { |
621 | equal->MarkGraphNode(); |
622 | return equal(value, other->value); |
623 | } |
624 | |
625 | void SHashReduce(SHashReducer hash_reduce) const { |
626 | hash_reduce->MarkGraphNode(); |
627 | hash_reduce(value); |
628 | } |
629 | |
630 | static constexpr const char* _type_key = "relay.RefCreate" ; |
631 | TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); |
632 | }; |
633 | |
634 | class RefCreate : public Expr { |
635 | public: |
636 | /*! |
637 | * \brief The constructor |
638 | * \param value The initial value of the reference. |
639 | * \param span The source span of the expression. |
640 | */ |
641 | TVM_DLL explicit RefCreate(Expr value, Span span = Span()); |
642 | |
643 | TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); |
644 | TVM_DEFINE_OBJECT_REF_COW_METHOD(RefCreateNode); |
645 | }; |
646 | |
647 | /*! |
648 | * \brief Returns \p ref_create with the given properties. A null property denotes 'no change'. |
649 | * Returns \p ref_crete if all properties are unchanged. Otherwise, returns a copy with the new |
650 | * fields. |
651 | */ |
652 | RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(), |
653 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
654 | Optional<Span> opt_span = Optional<Span>()); |
655 | |
656 | /*! \brief Get value out of Reference. */ |
657 | class RefRead; |
658 | class RefReadNode : public ExprNode { |
659 | public: |
660 | /*! \brief The Reference Expression. */ |
661 | Expr ref; |
662 | |
663 | void VisitAttrs(tvm::AttrVisitor* v) { |
664 | v->Visit("ref" , &ref); |
665 | v->Visit("virtual_device_" , &virtual_device_); |
666 | v->Visit("span" , &span); |
667 | v->Visit("_checked_type_" , &checked_type_); |
668 | } |
669 | |
670 | bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const { |
671 | equal->MarkGraphNode(); |
672 | return equal(ref, other->ref); |
673 | } |
674 | |
675 | void SHashReduce(SHashReducer hash_reduce) const { |
676 | hash_reduce->MarkGraphNode(); |
677 | hash_reduce(ref); |
678 | } |
679 | |
680 | static constexpr const char* _type_key = "relay.RefRead" ; |
681 | TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); |
682 | }; |
683 | |
684 | class RefRead : public Expr { |
685 | public: |
686 | /*! |
687 | * \brief The constructor |
688 | * \param ref The reference where to read data. |
689 | * \param span The source span of the expression. |
690 | */ |
691 | TVM_DLL explicit RefRead(Expr ref, Span span = Span()); |
692 | |
693 | TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); |
694 | TVM_DEFINE_OBJECT_REF_COW_METHOD(RefReadNode); |
695 | }; |
696 | |
697 | /*! |
698 | * \brief Returns \p ref_read with the given properties. A null property denotes 'no change'. |
699 | * Returns \p ref_read if all properties are unchanged. Otherwise, returns a copy with the new |
700 | * fields. |
701 | */ |
702 | RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(), |
703 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
704 | Optional<Span> opt_span = Optional<Span>()); |
705 | |
706 | /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ |
707 | class RefWrite; |
708 | class RefWriteNode : public ExprNode { |
709 | public: |
710 | /*! \brief The Reference Expression. */ |
711 | Expr ref; |
712 | /*! \brief The value to write into. */ |
713 | Expr value; |
714 | |
715 | void VisitAttrs(tvm::AttrVisitor* v) { |
716 | v->Visit("ref" , &ref); |
717 | v->Visit("value" , &value); |
718 | v->Visit("virtual_device_" , &virtual_device_); |
719 | v->Visit("span" , &span); |
720 | v->Visit("_checked_type_" , &checked_type_); |
721 | } |
722 | |
723 | bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const { |
724 | equal->MarkGraphNode(); |
725 | return equal(ref, other->ref) && equal(value, other->value); |
726 | } |
727 | |
728 | void SHashReduce(SHashReducer hash_reduce) const { |
729 | hash_reduce->MarkGraphNode(); |
730 | hash_reduce(ref); |
731 | hash_reduce(value); |
732 | } |
733 | |
734 | static constexpr const char* _type_key = "relay.RefWrite" ; |
735 | TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode); |
736 | }; |
737 | |
738 | class RefWrite : public Expr { |
739 | public: |
740 | /*! |
741 | * \brief The constructor |
742 | * \param ref The reference where data is write to. |
743 | * \param value The value to write. |
744 | * \param span The source span of the expression. |
745 | */ |
746 | TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span()); |
747 | |
748 | TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); |
749 | TVM_DEFINE_OBJECT_REF_COW_METHOD(RefWriteNode); |
750 | }; |
751 | |
752 | /*! |
753 | * \brief Returns \p ref_write with the given properties. A null property denotes 'no change'. |
754 | * Returns \p ref_write if all properties are unchanged. Otherwise, returns a copy with the new |
755 | * fields. |
756 | */ |
757 | RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(), |
758 | Optional<Expr> opt_value = Optional<Expr>(), |
759 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
760 | Optional<Span> opt_span = Optional<Span>()); |
761 | |
762 | /*! |
763 | * \brief Base class of the temporary expression. |
764 | * |
765 | * TempExprs are pass specific expression that can be |
766 | * useful to define intermediate result in the |
767 | * rewriting pass such as layout or type transformation. |
768 | * |
769 | * Subclass TempExprNode allows us to pattern match on |
770 | * specific kind of TempExpr and use them for expression rewriting. |
771 | * |
772 | * TempExpr should only be used within a pass, |
773 | */ |
774 | class TempExprNode : public ExprNode { |
775 | public: |
776 | /*! \brief virtual destructor */ |
777 | virtual ~TempExprNode() {} |
778 | /*! |
779 | * \brief Convert the expression to a normal(non-temp) Expr. |
780 | * \return The corresponding normal(non-temp) expression. |
781 | */ |
782 | virtual Expr Realize() const = 0; |
783 | |
784 | static constexpr const char* _type_key = "relay.TempExpr" ; |
785 | static constexpr const bool _type_has_method_sequal_reduce = false; |
786 | static constexpr const bool _type_has_method_shash_reduce = false; |
787 | static constexpr const uint32_t _type_child_slots = 0; |
788 | TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); |
789 | }; |
790 | |
791 | class TempExpr : public Expr { |
792 | public: |
793 | TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode); |
794 | }; |
795 | |
796 | } // namespace relay |
797 | |
798 | namespace runtime { |
799 | |
800 | template <> |
801 | template <> |
802 | inline ObjectPtr<relay::LetNode> |
803 | ObjAllocatorBase<SimpleObjAllocator>::make_object<relay::LetNode>() { |
804 | using Derived = SimpleObjAllocator; |
805 | using T = relay::LetNode; |
806 | using Handler = typename Derived::template Handler<T>; |
807 | static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object" ); |
808 | T* ptr = Handler::New(static_cast<Derived*>(this)); |
809 | ptr->type_index_ = T::RuntimeTypeIndex(); |
810 | ptr->saved_deleter_ = Handler::Deleter(); |
811 | ptr->deleter_ = relay::LetNode::Deleter_; |
812 | return ObjectPtr<T>(ptr); |
813 | } |
814 | |
815 | template <> |
816 | template <> |
817 | inline ObjectPtr<relay::CallNode> |
818 | ObjAllocatorBase<SimpleObjAllocator>::make_object<relay::CallNode>() { |
819 | using Derived = SimpleObjAllocator; |
820 | using T = relay::CallNode; |
821 | using Handler = typename Derived::template Handler<T>; |
822 | static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object" ); |
823 | T* ptr = Handler::New(static_cast<Derived*>(this)); |
824 | ptr->type_index_ = T::RuntimeTypeIndex(); |
825 | ptr->saved_deleter_ = Handler::Deleter(); |
826 | ptr->deleter_ = relay::CallNode::Deleter_; |
827 | return ObjectPtr<T>(ptr); |
828 | } |
829 | |
830 | } // namespace runtime |
831 | |
832 | } // namespace tvm |
833 | #endif // TVM_RELAY_EXPR_H_ |
834 | |