1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/ir/expr.cc
22 * \brief The expression AST nodes of Relay.
23 */
24#include <tvm/ir/module.h>
25#include <tvm/relay/expr.h>
26#include <tvm/target/virtual_device.h>
27
28namespace tvm {
29
30GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint, Optional<Type> opt_type,
31 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
32 String name_hint = opt_name_hint.value_or(global_var->name_hint);
33 Type type = opt_type.value_or(global_var->checked_type());
34 VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device());
35 Span span = opt_span.value_or(global_var->span);
36 bool all_fields_unchanged =
37 name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) &&
38 virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span);
39 if (!all_fields_unchanged) {
40 GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite();
41 cow_global_var_node->name_hint = name_hint;
42 cow_global_var_node->checked_type_ = type;
43 cow_global_var_node->virtual_device_ = virtual_device;
44 cow_global_var_node->span = span;
45 }
46
47 return global_var;
48}
49
50VirtualDevice RelayExprNode::virtual_device() const {
51 if (!this->virtual_device_.defined()) {
52 // virtual_device_ should always be defined, unless we imported this node from JSON using an old
53 // version of TVM, in which case we want to set it to the default, which is
54 // VirtualDevice::FullyUnconstrained().
55 return VirtualDevice::FullyUnconstrained();
56 }
57 return Downcast<VirtualDevice>(this->virtual_device_);
58}
59
60namespace relay {
61
62using tvm::ReprPrinter;
63using namespace tvm::runtime;
64
65Constant::Constant(runtime::NDArray data, Span span) {
66 ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
67 n->data = std::move(data);
68 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
69 n->span = std::move(span);
70 data_ = std::move(n);
71}
72
73TVM_REGISTER_NODE_TYPE(ConstantNode);
74
75TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data, Span span) {
76 return Constant(data, span);
77});
78TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields")
79 .set_body_typed([](Constant constant, Optional<runtime::NDArray> opt_data,
80 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
81 return WithFields(constant, opt_data, opt_virtual_device, opt_span);
82 });
83
84TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
85 .set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
86 auto* node = static_cast<const ConstantNode*>(ref.get());
87 const PackedFunc* fprint = Registry::Get("relay._constant_repr");
88 ICHECK(fprint) << "unable to find printing function for constants";
89 std::string data = (*fprint)(GetRef<Constant>(node));
90 p->stream << "Constant(" << data << ")";
91 });
92
93TensorType ConstantNode::tensor_type() const {
94 auto dtype = DataType(data->dtype);
95 Array<tvm::PrimExpr> shape;
96 for (int i = 0; i < data->ndim; i++) {
97 ICHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
98 ICHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
99 shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i]));
100 }
101
102 return TensorType(shape, dtype);
103}
104
105Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data,
106 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
107 runtime::NDArray data = opt_data.value_or(constant->data);
108 VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device());
109 Span span = opt_span.value_or(constant->span);
110
111 bool all_fields_unchanged = data.same_as(constant->data) &&
112 virtual_device.same_as(constant->virtual_device()) &&
113 span.same_as(constant->span);
114
115 if (!all_fields_unchanged) {
116 ConstantNode* cow_constant_node = constant.CopyOnWrite();
117 cow_constant_node->data = data;
118 cow_constant_node->virtual_device_ = virtual_device;
119 cow_constant_node->span = span;
120 }
121 return constant;
122}
123
124Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
125 ObjectPtr<TupleNode> n = make_object<TupleNode>();
126 n->fields = std::move(fields);
127 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
128 n->span = std::move(span);
129 data_ = std::move(n);
130}
131
// Register TupleNode as a node type.
TVM_REGISTER_NODE_TYPE(TupleNode);

// Global "relay.ir.Tuple": forwards to the Tuple constructor.
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
  return Tuple(fields, span);
});
// Global "relay.ir.TupleWithFields": forwards to WithFields for Tuple.
TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields")
    .set_body_typed([](Tuple tuple, Optional<Array<Expr>> opt_fields,
                       Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
      return WithFields(tuple, opt_fields, opt_virtual_device, opt_span);
    });
142
143Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
144 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
145 Array<Expr> fields = opt_fields.value_or(tuple->fields);
146 VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
147 Span span = opt_span.value_or(tuple->span);
148
149 bool all_fields_unchanged = true;
150 if (fields.size() == tuple->fields.size()) {
151 for (size_t i = 0; i < fields.size(); i++) {
152 all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
153 }
154 } else {
155 all_fields_unchanged = false;
156 }
157
158 all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device()) &&
159 span.same_as(tuple->span);
160 if (!all_fields_unchanged) {
161 TupleNode* cow_tuple_node = tuple.CopyOnWrite();
162 cow_tuple_node->fields = fields;
163 cow_tuple_node->virtual_device_ = virtual_device;
164 cow_tuple_node->span = span;
165 }
166 return tuple;
167}
168
169TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
170 .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
171 auto* node = static_cast<const TupleNode*>(ref.get());
172 p->stream << "Tuple(" << node->fields << ")";
173 });
174
175Var::Var(Id vid, Type type_annotation, Span span) {
176 ObjectPtr<VarNode> n = make_object<VarNode>();
177 n->vid = std::move(vid);
178 n->type_annotation = std::move(type_annotation);
179 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
180 n->span = std::move(span);
181 data_ = std::move(n);
182}
183
184/* static */ Var Var::GenSym(Type type_annotation, Span span) {
185 static size_t next_id = std::atomic<size_t>(0);
186 std::ostringstream os;
187 os << "x_" << next_id++;
188 return Var(os.str(), std::move(type_annotation), std::move(span));
189}
190
191Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
192 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
193 Id vid = opt_vid.value_or(var->vid);
194 Type type_annotation = opt_type_annotation.value_or(var->type_annotation);
195 VirtualDevice virtual_device = opt_virtual_device.value_or(var->virtual_device());
196 Span span = opt_span.value_or(var->span);
197
198 bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) &&
199 virtual_device.same_as(var->virtual_device()) && span.same_as(var->span);
200
201 if (!unchanged) {
202 VarNode* cow_var_node = var.CopyOnWrite();
203 cow_var_node->vid = vid;
204 cow_var_node->type_annotation = type_annotation;
205 cow_var_node->virtual_device_ = virtual_device;
206 cow_var_node->span = span;
207 }
208 return var;
209}
210
211TVM_REGISTER_NODE_TYPE(VarNode);
212
213TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation, Span span) {
214 return Var(str, type_annotation, span);
215});
216TVM_REGISTER_GLOBAL("relay.ir.VarWithFields")
217 .set_body_typed([](Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
218 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
219 return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device, opt_span);
220 });
221
222TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
223 .set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
224 auto* node = static_cast<const VarNode*>(ref.get());
225 p->stream << "Var(" << node->name_hint();
226 if (node->type_annotation.defined()) {
227 p->stream << ", ty=";
228 p->Print(node->type_annotation);
229 }
230 p->stream << ")";
231 });
232
233Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
234 ObjectPtr<CallNode> n = make_object<CallNode>();
235 n->op = std::move(op);
236 n->args = std::move(args);
237 n->attrs = std::move(attrs);
238 n->type_args = std::move(type_args);
239 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
240 n->span = std::move(span);
241 data_ = std::move(n);
242}
243
244Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args,
245 Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args,
246 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
247 // Collect new values for fields.
248 Expr op = opt_op.value_or(call->op);
249 Array<Expr> args = opt_args.value_or(call->args);
250 Attrs attrs = opt_attrs.value_or(call->attrs);
251 Array<Type> type_args = opt_type_args.value_or(call->type_args);
252 VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device());
253 Span span = opt_span.value_or(call->span);
254
255 // Check if anything changed.
256 bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) &&
257 virtual_device.same_as(call->virtual_device()) && span.same_as(call->span);
258 if (unchanged) {
259 if (args.size() == call->args.size()) {
260 for (size_t i = 0; i < args.size(); i++) {
261 unchanged &= args[i].same_as(call->args[i]);
262 }
263 } else {
264 unchanged = false;
265 }
266 }
267 if (unchanged) {
268 if (type_args.size() == call->type_args.size()) {
269 for (size_t i = 0; i < type_args.size(); i++) {
270 unchanged &= type_args[i].same_as(call->type_args[i]);
271 }
272 } else {
273 unchanged = false;
274 }
275 }
276
277 if (!unchanged) {
278 // If call is only references, update it in place. Otherwise copy and update.
279 CallNode* cow_call_node = call.CopyOnWrite();
280 cow_call_node->op = op;
281 cow_call_node->args = args;
282 cow_call_node->attrs = attrs;
283 cow_call_node->type_args = type_args;
284 cow_call_node->virtual_device_ = virtual_device;
285 cow_call_node->span = span;
286 }
287 return call;
288}
289
290TVM_REGISTER_NODE_TYPE(CallNode);
291
292TVM_REGISTER_GLOBAL("relay.ir.Call")
293 .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
294 return Call(op, args, attrs, type_args, span);
295 });
296TVM_REGISTER_GLOBAL("relay.ir.CallWithFields")
297 .set_body_typed([](Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args,
298 Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args,
299 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
300 return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args, opt_virtual_device,
301 opt_span);
302 });
303
304TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
305 .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
306 auto* node = static_cast<const CallNode*>(ref.get());
307 p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", "
308 << node->type_args << ")";
309 });
310
311Let::Let(Var var, Expr value, Expr body, Span span) {
312 ObjectPtr<LetNode> n = make_object<LetNode>();
313 n->var = std::move(var);
314 n->value = std::move(value);
315 n->body = std::move(body);
316 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
317 n->span = std::move(span);
318 data_ = std::move(n);
319}
320
321Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optional<Expr> opt_body,
322 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
323 Var var = opt_var.value_or(let->var);
324 Expr value = opt_value.value_or(let->value);
325 Expr body = opt_body.value_or(let->body);
326 VirtualDevice virtual_device = opt_virtual_device.value_or(let->virtual_device());
327 Span span = opt_span.value_or(let->span);
328
329 bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) &&
330 virtual_device.same_as(let->virtual_device()) && span.same_as(let->span);
331
332 if (!unchanged) {
333 LetNode* cow_let_node = let.CopyOnWrite();
334 cow_let_node->var = var;
335 cow_let_node->value = value;
336 cow_let_node->body = body;
337 cow_let_node->virtual_device_ = virtual_device;
338 cow_let_node->span = span;
339 }
340 return let;
341}
342
343TVM_REGISTER_NODE_TYPE(LetNode);
344
345TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body, Span span) {
346 return Let(var, value, body, span);
347});
348TVM_REGISTER_GLOBAL("relay.ir.LetWithFields")
349 .set_body_typed([](Let let, Optional<Var> opt_var, Optional<Expr> opt_value,
350 Optional<Expr> opt_body, Optional<VirtualDevice> opt_virtual_device,
351 Optional<Span> opt_span) {
352 return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device, opt_span);
353 });
354
355TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
356 .set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
357 auto* node = static_cast<const LetNode*>(ref.get());
358 p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")";
359 });
360
361If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
362 ObjectPtr<IfNode> n = make_object<IfNode>();
363 n->cond = std::move(cond);
364 n->true_branch = std::move(true_branch);
365 n->false_branch = std::move(false_branch);
366 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
367 n->span = std::move(span);
368 data_ = std::move(n);
369}
370
371If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch,
372 Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device,
373 Optional<Span> opt_span) {
374 Expr cond = opt_cond.value_or(if_expr->cond);
375 Expr true_branch = opt_true_branch.value_or(if_expr->true_branch);
376 Expr false_branch = opt_false_branch.value_or(if_expr->false_branch);
377 VirtualDevice virtual_device = opt_virtual_device.value_or(if_expr->virtual_device());
378 Span span = opt_span.value_or(if_expr->span);
379
380 bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) &&
381 false_branch.same_as(if_expr->false_branch) &&
382 virtual_device.same_as(if_expr->virtual_device()) && span.same_as(if_expr->span);
383
384 if (!unchanged) {
385 IfNode* cow_if_node = if_expr.CopyOnWrite();
386 cow_if_node->cond = cond;
387 cow_if_node->true_branch = true_branch;
388 cow_if_node->false_branch = false_branch;
389 cow_if_node->virtual_device_ = virtual_device;
390 cow_if_node->span = span;
391 }
392 return if_expr;
393}
394
395TVM_REGISTER_NODE_TYPE(IfNode);
396
397TVM_REGISTER_GLOBAL("relay.ir.If")
398 .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) {
399 return If(cond, true_branch, false_branch, span);
400 });
401TVM_REGISTER_GLOBAL("relay.ir.IfWithFields")
402 .set_body_typed([](If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch,
403 Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device,
404 Optional<Span> opt_span) {
405 return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch, opt_virtual_device,
406 opt_span);
407 });
408
409TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
410 .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
411 auto* node = static_cast<const IfNode*>(ref.get());
412 p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", "
413 << node->false_branch << ")";
414 });
415
416TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
417 ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
418 n->tuple = std::move(tuple);
419 n->index = index;
420 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
421 n->span = std::move(span);
422 data_ = std::move(n);
423}
424
425TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
426 Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device,
427 Optional<Span> opt_span) {
428 Expr tuple = opt_tuple.value_or(tuple_get_item->tuple);
429 Integer index = opt_index.value_or(tuple_get_item->index);
430 VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
431 Span span = opt_span.value_or(tuple_get_item->span);
432
433 bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) &&
434 virtual_device.same_as(tuple_get_item->virtual_device()) &&
435 span.same_as(tuple_get_item->span);
436 if (!unchanged) {
437 TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite();
438 cow_tuple_get_item_node->tuple = tuple;
439 cow_tuple_get_item_node->index = index.IntValue();
440 cow_tuple_get_item_node->span = span;
441 cow_tuple_get_item_node->virtual_device_ = virtual_device;
442 }
443 return tuple_get_item;
444}
445
446TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
447
448TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
449 return TupleGetItem(tuple, index, span);
450});
451TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields")
452 .set_body_typed([](TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
453 Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device,
454 Optional<Span> opt_span) {
455 return WithFields(tuple_get_item, opt_tuple, opt_index, opt_virtual_device, opt_span);
456 });
457
458TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
459 .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
460 auto* node = static_cast<const TupleGetItemNode*>(ref.get());
461 p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
462 });
463
464RefCreate::RefCreate(Expr value, Span span) {
465 ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
466 n->value = std::move(value);
467 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
468 n->span = std::move(span);
469 data_ = std::move(n);
470}
471
472RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value,
473 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
474 Expr value = opt_value.value_or(ref_create->value);
475 VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device());
476 Span span = opt_span.value_or(ref_create->span);
477
478 bool unchanged = value.same_as(ref_create->value) &&
479 virtual_device.same_as(ref_create->virtual_device()) &&
480 span.same_as(ref_create->span);
481 if (!unchanged) {
482 RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite();
483 cow_ref_create_node->value = value;
484 cow_ref_create_node->virtual_device_ = virtual_device;
485 cow_ref_create_node->span = span;
486 }
487 return ref_create;
488}
489
490TVM_REGISTER_NODE_TYPE(RefCreateNode);
491
492TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span span) {
493 return RefCreate(value, span);
494});
495TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields")
496 .set_body_typed([](RefCreate ref_create, Optional<Expr> opt_value,
497 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
498 return WithFields(ref_create, opt_value, opt_virtual_device, opt_span);
499 });
500
501TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
502 .set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
503 auto* node = static_cast<const RefCreateNode*>(ref.get());
504 p->stream << "RefCreateNode(" << node->value << ")";
505 });
506
507RefRead::RefRead(Expr ref, Span span) {
508 ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
509 n->ref = std::move(ref);
510 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
511 n->span = std::move(span);
512 data_ = std::move(n);
513}
514
515RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref,
516 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
517 Expr ref = opt_ref.value_or(ref_read->ref);
518 VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device());
519 Span span = opt_span.value_or(ref_read->span);
520
521 bool unchanged = ref.same_as(ref_read->ref) &&
522 virtual_device.same_as(ref_read->virtual_device()) &&
523 span.same_as(ref_read->span);
524 if (!unchanged) {
525 RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite();
526 cow_ref_read_node->ref = ref;
527 cow_ref_read_node->virtual_device_ = virtual_device;
528 cow_ref_read_node->span = span;
529 }
530 return ref_read;
531}
532
533TVM_REGISTER_NODE_TYPE(RefReadNode);
534
535TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span) {
536 return RefRead(ref, span);
537});
538TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields")
539 .set_body_typed([](RefRead ref_read, Optional<Expr> opt_ref,
540 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
541 return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span);
542 });
543
544TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
545 .set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
546 auto* node = static_cast<const RefReadNode*>(ref.get());
547 p->stream << "RefReadNode(" << node->ref << ")";
548 });
549
550RefWrite::RefWrite(Expr ref, Expr value, Span span) {
551 ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
552 n->ref = std::move(ref);
553 n->value = std::move(value);
554 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
555 n->span = std::move(span);
556 data_ = std::move(n);
557}
558
559RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value,
560 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
561 Expr ref = opt_ref.value_or(ref_write->ref);
562 Expr value = opt_value.value_or(ref_write->value);
563 VirtualDevice virtual_device = opt_virtual_device.value_or(ref_write->virtual_device());
564 Span span = opt_span.value_or(ref_write->span);
565
566 bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) &&
567 virtual_device.same_as(ref_write->virtual_device()) &&
568 span.same_as(ref_write->span);
569 if (!unchanged) {
570 RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite();
571 cow_ref_write_node->ref = ref;
572 cow_ref_write_node->value = value;
573 cow_ref_write_node->virtual_device_ = virtual_device;
574 cow_ref_write_node->span = span;
575 }
576 return ref_write;
577}
578
579TVM_REGISTER_NODE_TYPE(RefWriteNode);
580
581TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value, Span span) {
582 return RefWrite(ref, value, span);
583});
584TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields")
585 .set_body_typed([](RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value,
586 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
587 return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device, opt_span);
588 });
589
590TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
591 .set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
592 auto* node = static_cast<const RefWriteNode*>(ref.get());
593 p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
594 });
595
// Global "relay.ir.TempExprRealize": returns the result of calling Realize() on the given
// temporary expression.
TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) {
  return temp->Realize();
});

// Global "relay.ir.Any": returns a default-constructed Any.
TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any(); });
601
/*!
 * \brief Non-recursive traversal that dismantles unused call nodes;
 *        a derivative of the ExpandDataflow method.
 *
 * The traversal keeps an explicit stack of (expression, visited) pairs. On the first visit
 * of an entry its children are pushed; on the second visit the entry is dismantled (the
 * args of a CallNode are replaced by an empty array, the body of a LetNode by an empty
 * Expr) and popped. Only expressions whose use counts fall below the thresholds checked in
 * the code are pushed or dismantled.
 *
 * \param expr The root expression to dismantle.
 */
inline void Dismantle(const Expr& expr) {
  // Each entry pairs an expression with a flag that is set once its children were pushed.
  std::stack<std::pair<Expr, bool>> stack;
  auto fpush_to_stack = [&stack](const Expr& expr) {
    // do not visit nodes with more than 2 refs (one can be in stack)
    if (expr.use_count() < 3) {
      stack.push({expr, false});
    }
  };
  fpush_to_stack(expr);
  while (stack.size() > 0) {
    // Reference into the top stack entry; it is not used after the pop() below.
    const auto& node = stack.top().first;
    if (stack.top().second) {
      // Second visit: the children were pushed earlier, so this entry can be dismantled.
      // dismantle node
      // +1 ref in stack/deque;
      if (node.use_count() < 3) {
        if (auto* op = const_cast<CallNode*>(node.as<CallNode>())) {
          // Replace the argument list with an empty array.
          op->args = Array<Expr>();
        }
        if (auto* op = const_cast<LetNode*>(node.as<LetNode>())) {
          // Replace the body with an undefined expression.
          op->body = Expr();
        }
      }
      // eject
      stack.pop();
    } else {
      // First visit: mark the entry and push its children.
      stack.top().second = true;

      // special handling
      if (const auto* call_node = node.as<CallNode>()) {
        // do not process args if used elsewhere
        if (call_node->args.use_count() < 2) {
          // Pushed in reverse order, so the first argument ends up on top of the stack.
          for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
            fpush_to_stack(*it);
          }
        }
      } else if (const auto* tuple_node = node.as<TupleNode>()) {
        // do not process fields if used elsewhere
        if (tuple_node->fields.use_count() < 2) {
          for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) {
            fpush_to_stack(*it);
          }
        }
      } else if (const auto* tuple_get_item_node = node.as<TupleGetItemNode>()) {
        // do not process tuple if used elsewhere
        if (tuple_get_item_node->tuple.use_count() < 2) {
          fpush_to_stack(tuple_get_item_node->tuple);
        }
      } else if (const auto* let_node = node.as<LetNode>()) {
        // do not process the let body if used elsewhere
        if (let_node->body.use_count() < 2) {
          fpush_to_stack(let_node->body);
        }
      }
    }
  }
}
662
/*!
 * \brief Non-recursive destructor.
 *
 * When this reference's use count is below 2 and it holds a CallNode with a non-empty
 * argument list, the expression is handed to Dismantle() before the reference goes away.
 */
Call::~Call() {
  // attempt to dismantle if referenced one or zero times
  if (this->use_count() < 2) {
    // Only a CallNode with at least one argument is dismantled.
    if (this->as<CallNode>() && this->as<CallNode>()->args.size()) {
      Dismantle(*this);
    }
  }
}
674
/*!
 * \brief CallNode's deleter.
 * \param ptr The object handed to the deleter; it is reinterpreted as a CallNode.
 */
void CallNode::Deleter_(Object* ptr) {
  auto p = reinterpret_cast<CallNode*>(ptr);
  // restore original deleter
  p->deleter_ = p->saved_deleter_;
  // create Call reference in order to invoke ~Call
  // (the local reference is destroyed when this function returns)
  auto c = GetRef<Call>(p);
}
685
/*!
 * \brief Non-recursive destructor.
 *
 * When this reference's use count is below 2 and it holds a LetNode whose body is defined,
 * the expression is handed to Dismantle() before the reference goes away.
 */
Let::~Let() {
  // attempt to dismantle if referenced one or zero times
  if (this->use_count() < 2) {
    // Only a LetNode that still has a defined body is dismantled.
    if (this->as<LetNode>() && this->as<LetNode>()->body.defined()) {
      Dismantle(*this);
    }
  }
}
697
/*!
 * \brief LetNode's deleter.
 * \param ptr The object handed to the deleter; it is reinterpreted as a LetNode.
 */
void LetNode::Deleter_(Object* ptr) {
  auto p = reinterpret_cast<LetNode*>(ptr);
  // restore original deleter
  p->deleter_ = p->saved_deleter_;
  // create Let reference in order to invoke ~Let
  // (the local reference is destroyed when this function returns)
  auto c = GetRef<Let>(p);
}
708
709} // namespace relay
710} // namespace tvm
711