1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/expr.h
22 * \brief Relay expression language.
23 */
24#ifndef TVM_RELAY_EXPR_H_
25#define TVM_RELAY_EXPR_H_
26
27#include <tvm/ir/attrs.h>
28#include <tvm/ir/expr.h>
29#include <tvm/ir/module.h>
30#include <tvm/ir/op.h>
31#include <tvm/target/virtual_device.h>
32
33#include <functional>
34#include <stack>
35#include <string>
36#include <utility>
37
38#include "./base.h"
39#include "./type.h"
40
41namespace tvm {
42
43/*!
44 * \brief Returns \p global_var with the given properties. A null property denotes 'no change'.
45 * Returns \p global_var if all properties are unchanged. Otherwise, returns a copy with the new
46 * fields.
47 */
48GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint = {},
49 Optional<Type> opt_type = {}, Optional<VirtualDevice> opt_virtual_device = {},
50 Optional<Span> opt_span = {});
51
52namespace relay {
53
54using Expr = tvm::RelayExpr;
55using ExprNode = tvm::RelayExprNode;
56using BaseFunc = tvm::BaseFunc;
57using BaseFuncNode = tvm::BaseFuncNode;
58using GlobalVar = tvm::GlobalVar;
59using GlobalVarNode = tvm::GlobalVarNode;
60
/*!
 * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
 *
 * \note Scalar constants are represented by rank-0 const tensors, so
 * constant folding is handled uniformly via tensor types.
 */
class Constant;
/*!
 * \brief Container node for a Constant expression.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*!
   * \return Whether it is scalar (rank-0 tensor).
   * \note Dereferences data, so it assumes data is defined.
   */
  bool is_scalar() const { return data->ndim == 0; }

  // Reflection: reports each field to the visitor, including the inherited
  // virtual_device_, span and checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Structural equality looks only at the tensor payload: virtual_device_, span and the
  // checked type are ignored, and the node is not marked as a graph node.
  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  // Hash counterpart of SEqualReduce: only the tensor payload contributes.
  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
98
99class Constant : public Expr {
100 public:
101 /*!
102 * \brief The constructor
103 * \param data The data of the constant tensor.
104 * \param span The source span of the expression.
105 */
106 TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
107
108 TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
109 TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode);
110};
111
112/*!
113 * \brief Returns \p constant with the given properties. A null property denotes 'no change'.
114 * Returns \p constant if all properties are unchanged. Otherwise, returns a copy with the new
115 * fields.
116 */
117Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data = {},
118 Optional<VirtualDevice> opt_virtual_device = {}, Optional<Span> opt_span = {});
119
/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Container node for a Tuple expression. */
class TupleNode : public ExprNode {
 public:
  /*! \brief the fields of the tuple */
  tvm::Array<relay::Expr> fields;

  // Reflection: reports the fields plus the inherited virtual_device_, span and
  // checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("fields", &fields);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
    // Special-case the empty tuple: like a constant, it is not a graph node, so two empty
    // tuples compare equal without MarkGraphNode(). Any other pair is treated as a graph
    // node and compared through its field arrays.
    if (fields.size() == other->fields.size() && fields.size() == 0) {
      return true;
    } else {
      equal->MarkGraphNode();
      return equal(fields, other->fields);
    }
  }

  // Mirrors SEqualReduce: an empty tuple hashes no fields and is not marked as a graph node.
  void SHashReduce(SHashReducer hash_reduce) const {
    if (fields.size() != 0) {
      hash_reduce->MarkGraphNode();
      hash_reduce(fields);
    }
  }

  static constexpr const char* _type_key = "relay.Tuple";
  TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
155
156class Tuple : public Expr {
157 public:
158 /*!
159 * \brief The constructor
160 * \param fields The fields of a tuple.
161 * \param span The source span of the expression.
162 */
163 TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
164
165 TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
166 TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
167};
168
169/*!
170 * \brief Returns \p tuple with the given properties. A null property denotes 'no change'.
171 * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new
172 * fields.
173 */
174Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
175 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
176 Optional<Span> opt_span = Optional<Span>());
177
/*!
 * \brief Local variables used in the let expression.
 *
 * Its semantics are similar to tvm.Var node used in TVM's low level
 * tensor expression language.
 *
 * \note Each Var is bound only once and is immutable.
 */
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief Type annotation of the variable.
   * This field records the user-provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;

  /*! \return The name hint of the variable (stored on vid). */
  const String& name_hint() const { return vid->name_hint; }

  // Reflection: reports vid and type_annotation plus the inherited virtual_device_, span
  // and checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("vid", &vid);
    v->Visit("type_annotation", &type_annotation);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; equality compares the annotation, the identifier and the virtual device.
  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid) &&
           equal(virtual_device_, other->virtual_device_);
  }

  // Hashes the annotation and identifier only. Leaving virtual_device_ out still agrees with
  // SEqualReduce: nodes that compare equal always produce the same hash.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(type_annotation);
    hash_reduce(vid);
  }

  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
233
234class Var : public Expr {
235 public:
236 /*!
237 * \brief The constructor
238 * \param name_hint The name hint of a variable.
239 * \param type_annotation The type annotation of a variable.
240 * \param span The source span of the expression.
241 */
242 TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span())
243 : Var(Id(name_hint), type_annotation, span) {}
244
245 /*!
246 * \brief The constructor
247 * \param vid The unique id of a variable.
248 * \param type_annotation The type annotation of a variable.
249 * \param span The source span of the expression.
250 */
251 TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
252
253 /*!
254 * \brief Return a globally fresh name. Helps with debugging to follow the same
255 * variable between passes and sub-expressions.
256 *
257 * TODO(mbs): Replace with name creation w.r.t. scopes once available as part of
258 * name gen overhaul.
259 */
260 static Var GenSym(Type type_annotation = {}, Span span = {});
261
262 TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
263 TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
264};
265
266/*!
267 * \brief Returns \p var with the given properties. A null property denotes 'no change'.
268 * Returns \p var if all properties are unchanged. Otherwise, returns a copy with the new
269 * fields.
270 */
271Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
272 Optional<Type> opt_type_annotation = Optional<Type>(),
273 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
274 Optional<Span> opt_span = Optional<Span>());
275
/*!
 * \brief Call corresponds to operator invocation.
 * Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 protected:
  // CallNode uses its own deleter to indirectly call a non-recursive destructor.
  // saved_deleter_ holds the allocator's original deleter, while Deleter_ is installed as the
  // object's deleter by the make_object<relay::CallNode> specialization in tvm::runtime.
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);

 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be tvm::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is an advanced feature that is only used when the function is
   * polymorphic. It can safely be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template<typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone<int>(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;

  // Reflection: reports every field, including the inherited virtual_device_, span and
  // checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; virtual_device_ and span are not compared.
  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
    // skip type_args check for primitive ops.
    equal->MarkGraphNode();
    return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
           (IsPrimitiveOp(op) || equal(type_args, other->type_args));
  }

  // Mirrors SEqualReduce: type_args only contribute to the hash when op is not primitive.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(op);
    hash_reduce(args);
    hash_reduce(attrs);
    if (!IsPrimitiveOp(op)) {
      hash_reduce(type_args);
    }
  }

  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
  // The allocator needs access to the protected deleter members (see the
  // make_object<relay::CallNode> specialization in tvm::runtime).
  template <typename>
  friend class runtime::ObjAllocatorBase;
  // NOTE(review): presumably needed by Call's out-of-line destructor; confirm in the .cc file.
  friend class Call;
};
356
357class Call : public Expr {
358 public:
359 /*!
360 * \brief The destructor
361 */
362 ~Call();
363
364 /*!
365 * \brief The constructor
366 * \param op The operator will be invoked.
367 * \param args The arguments of the call.
368 * \param attrs The attributes of the call node.
369 * \param type_args The type arguments passed to a polymorphic function.
370 * \param span The source span of the expression.
371 */
372 TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
373 Array<Type> type_args = Array<Type>(), Span span = Span());
374
375 TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
376 TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
377};
378
379/*!
380 * \brief Returns \p call with the given properties. A null property denotes 'no change'.
381 * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new
382 * fields.
383 */
384Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
385 Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
386 Optional<Attrs> opt_attrs = Optional<Attrs>(),
387 Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
388 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
389 Optional<Span> opt_span = Optional<Span>());
390
/*!
 * \brief Let binding that binds a local var and optionally a type annotation.
 *
 * \note Let is useful to transform the program into A-normal form,
 *  where each expression corresponds to a let binding.
 *
 *  For developers who are familiar with computational graphs:
 *  each let can be viewed as an operator node in the computational graph.
 *  Traversing the list of let bindings is similar to running
 *  PostDFS-order(topo-order) traversal on the computational graph.
 */
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
 protected:
  // LetNode uses its own deleter to indirectly call a non-recursive destructor.
  // saved_deleter_ holds the allocator's original deleter, while Deleter_ is installed as the
  // object's deleter by the make_object<relay::LetNode> specialization in tvm::runtime.
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);

 public:
  /*! \brief The variable we bind to */
  Var var;
  /*! \brief The value we bind var to */
  Expr value;
  /*! \brief The body of the let binding */
  Expr body;

  // Reflection: reports var, value and body plus the inherited virtual_device_, span and
  // checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node. var is compared with DefEqual, i.e. as a definition site.
  // NOTE(review): DefEqual presumably lets the two bound vars be mapped onto each other
  // (alpha-equivalence); confirm against the structural-equality implementation.
  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal.DefEqual(var, other->var) && equal(value, other->value) &&
           equal(body, other->body);
  }

  // Mirrors SEqualReduce: var is hashed with DefHash, value and body normally.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce.DefHash(var);
    hash_reduce(value);
    hash_reduce(body);
  }

  static constexpr const char* _type_key = "relay.Let";
  TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
  // The allocator needs access to the protected deleter members (see the
  // make_object<relay::LetNode> specialization in tvm::runtime).
  template <typename>
  friend class runtime::ObjAllocatorBase;
  // NOTE(review): presumably needed by Let's out-of-line destructor; confirm in the .cc file.
  friend class Let;
};
446
447class Let : public Expr {
448 public:
449 /*!
450 * \brief The destructor
451 */
452 ~Let();
453
454 /*!
455 * \brief The constructor
456 * \param var The variable that is bound to.
457 * \param value The value used to bind to the variable.
458 * \param body The body of the let binding.
459 * \param span The source span of the expression.
460 */
461 TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span());
462
463 TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
464 TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
465};
466
467/*!
468 * \brief Returns \p let with the given properties. A null property denotes 'no change'.
469 * Returns \p let if all properties are unchanged. Otherwise, returns a copy with the new
470 * fields.
471 */
472Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
473 Optional<Expr> opt_value = Optional<Expr>(),
474 Optional<Expr> opt_body = Optional<Expr>(),
475 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
476 Optional<Span> opt_span = Optional<Span>());
477
/*!
 * \brief Condition expression
 *
 * Unlike traditional `if` statements, this `if` evaluates
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;

  // Reflection: reports the three sub-expressions plus the inherited virtual_device_, span
  // and checked type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("cond", &cond);
    v->Visit("true_branch", &true_branch);
    v->Visit("false_branch", &false_branch);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; equality compares the condition and both branches.
  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
           equal(false_branch, other->false_branch);
  }

  // Hash counterpart of SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(cond);
    hash_reduce(true_branch);
    hash_reduce(false_branch);
  }

  static constexpr const char* _type_key = "relay.If";
  TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
525
526class If : public Expr {
527 public:
528 /*!
529 * \brief The constructor
530 * \param cond The condition of a if node.
531 * \param true_branch The fall through branch
532 * \param false_branch The branch for execution when condition is false.
533 * \param span The source span of the expression.
534 */
535 TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
536
537 TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
538 TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
539};
540
541/*!
542 * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'.
543 * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new
544 * fields.
545 */
546If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
547 Optional<Expr> opt_true_branch = Optional<Expr>(),
548 Optional<Expr> opt_false_branch = Optional<Expr>(),
549 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
550 Optional<Span> opt_span = Optional<Span>());
551
/*! \brief Get index-th field out of a tuple. */
class TupleGetItem;
/*! \brief Container node for a TupleGetItem expression. */
class TupleGetItemNode : public ExprNode {
 public:
  /*! \brief The tuple Expression */
  Expr tuple;
  /*! \brief which value to get */
  int index;

  // Reflection: note that the tuple field is reported under the key "tuple_value".
  // The inherited virtual_device_, span and checked type are reported as well.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("tuple_value", &tuple);
    v->Visit("index", &index);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Compares the tuple expression and the index. Unlike most nodes in this file, this node
  // does not call MarkGraphNode().
  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
    return equal(tuple, other->tuple) && equal(index, other->index);
  }

  // Hash counterpart of SEqualReduce (also without MarkGraphNode()).
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(tuple);
    hash_reduce(index);
  }

  static constexpr const char* _type_key = "relay.TupleGetItem";
  TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
581
582class TupleGetItem : public Expr {
583 public:
584 /*!
585 * \brief The constructor
586 * \param tuple The tuple to get an element from.
587 * \param index The index for extracting a value in the tuple.
588 * \param span The source span of the expression.
589 */
590 TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
591
592 TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
593 TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode);
594};
595
596/*!
597 * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'.
598 * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new
599 * fields.
600 */
601TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
602 Optional<Integer> opt_index = Optional<Integer>(),
603 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
604 Optional<Span> opt_span = Optional<Span>());
605
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
/*! \brief Container node for a RefCreate expression. */
class RefCreateNode : public ExprNode {
 public:
  /*! \brief The initial value of the Reference. */
  Expr value;

  // Reflection: reports value plus the inherited virtual_device_, span and checked type
  // (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; equality looks only at the initial value.
  bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal(value, other->value);
  }

  // Hash counterpart of SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(value);
  }

  static constexpr const char* _type_key = "relay.RefCreate";
  TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
633
634class RefCreate : public Expr {
635 public:
636 /*!
637 * \brief The constructor
638 * \param value The initial value of the reference.
639 * \param span The source span of the expression.
640 */
641 TVM_DLL explicit RefCreate(Expr value, Span span = Span());
642
643 TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
644 TVM_DEFINE_OBJECT_REF_COW_METHOD(RefCreateNode);
645};
646
647/*!
648 * \brief Returns \p ref_create with the given properties. A null property denotes 'no change'.
649 * Returns \p ref_crete if all properties are unchanged. Otherwise, returns a copy with the new
650 * fields.
651 */
652RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
653 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
654 Optional<Span> opt_span = Optional<Span>());
655
/*! \brief Get value out of Reference. */
class RefRead;
/*! \brief Container node for a RefRead expression. */
class RefReadNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;

  // Reflection: reports ref plus the inherited virtual_device_, span and checked type
  // (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("ref", &ref);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; equality looks only at the reference expression.
  bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal(ref, other->ref);
  }

  // Hash counterpart of SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(ref);
  }

  static constexpr const char* _type_key = "relay.RefRead";
  TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
683
684class RefRead : public Expr {
685 public:
686 /*!
687 * \brief The constructor
688 * \param ref The reference where to read data.
689 * \param span The source span of the expression.
690 */
691 TVM_DLL explicit RefRead(Expr ref, Span span = Span());
692
693 TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
694 TVM_DEFINE_OBJECT_REF_COW_METHOD(RefReadNode);
695};
696
697/*!
698 * \brief Returns \p ref_read with the given properties. A null property denotes 'no change'.
699 * Returns \p ref_read if all properties are unchanged. Otherwise, returns a copy with the new
700 * fields.
701 */
702RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
703 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
704 Optional<Span> opt_span = Optional<Span>());
705
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
/*! \brief Container node for a RefWrite expression. */
class RefWriteNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;
  /*! \brief The value to write into. */
  Expr value;

  // Reflection: reports ref and value plus the inherited virtual_device_, span and checked
  // type (under the key "_checked_type_").
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("ref", &ref);
    v->Visit("value", &value);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Graph node; equality compares the reference and the written value.
  bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    return equal(ref, other->ref) && equal(value, other->value);
  }

  // Hash counterpart of SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(ref);
    hash_reduce(value);
  }

  static constexpr const char* _type_key = "relay.RefWrite";
  TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode);
};
737
738class RefWrite : public Expr {
739 public:
740 /*!
741 * \brief The constructor
742 * \param ref The reference where data is write to.
743 * \param value The value to write.
744 * \param span The source span of the expression.
745 */
746 TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span());
747
748 TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
749 TVM_DEFINE_OBJECT_REF_COW_METHOD(RefWriteNode);
750};
751
752/*!
753 * \brief Returns \p ref_write with the given properties. A null property denotes 'no change'.
754 * Returns \p ref_write if all properties are unchanged. Otherwise, returns a copy with the new
755 * fields.
756 */
757RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
758 Optional<Expr> opt_value = Optional<Expr>(),
759 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
760 Optional<Span> opt_span = Optional<Span>());
761
762/*!
763 * \brief Base class of the temporary expression.
764 *
765 * TempExprs are pass specific expression that can be
766 * useful to define intermediate result in the
767 * rewriting pass such as layout or type transformation.
768 *
769 * Subclass TempExprNode allows us to pattern match on
770 * specific kind of TempExpr and use them for expression rewriting.
771 *
772 * TempExpr should only be used within a pass,
773 */
774class TempExprNode : public ExprNode {
775 public:
776 /*! \brief virtual destructor */
777 virtual ~TempExprNode() {}
778 /*!
779 * \brief Convert the expression to a normal(non-temp) Expr.
780 * \return The corresponding normal(non-temp) expression.
781 */
782 virtual Expr Realize() const = 0;
783
784 static constexpr const char* _type_key = "relay.TempExpr";
785 static constexpr const bool _type_has_method_sequal_reduce = false;
786 static constexpr const bool _type_has_method_shash_reduce = false;
787 static constexpr const uint32_t _type_child_slots = 0;
788 TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
789};
790
791class TempExpr : public Expr {
792 public:
793 TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
794};
795
796} // namespace relay
797
798namespace runtime {
799
800template <>
801template <>
802inline ObjectPtr<relay::LetNode>
803ObjAllocatorBase<SimpleObjAllocator>::make_object<relay::LetNode>() {
804 using Derived = SimpleObjAllocator;
805 using T = relay::LetNode;
806 using Handler = typename Derived::template Handler<T>;
807 static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
808 T* ptr = Handler::New(static_cast<Derived*>(this));
809 ptr->type_index_ = T::RuntimeTypeIndex();
810 ptr->saved_deleter_ = Handler::Deleter();
811 ptr->deleter_ = relay::LetNode::Deleter_;
812 return ObjectPtr<T>(ptr);
813}
814
815template <>
816template <>
817inline ObjectPtr<relay::CallNode>
818ObjAllocatorBase<SimpleObjAllocator>::make_object<relay::CallNode>() {
819 using Derived = SimpleObjAllocator;
820 using T = relay::CallNode;
821 using Handler = typename Derived::template Handler<T>;
822 static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
823 T* ptr = Handler::New(static_cast<Derived*>(this));
824 ptr->type_index_ = T::RuntimeTypeIndex();
825 ptr->saved_deleter_ = Handler::Deleter();
826 ptr->deleter_ = relay::CallNode::Deleter_;
827 return ObjectPtr<T>(ptr);
828}
829
830} // namespace runtime
831
832} // namespace tvm
833#endif // TVM_RELAY_EXPR_H_
834