1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/ir/expr.cc |
22 | * \brief The expression AST nodes of Relay. |
23 | */ |
24 | #include <tvm/ir/module.h> |
25 | #include <tvm/relay/expr.h> |
26 | #include <tvm/target/virtual_device.h> |
27 | |
28 | namespace tvm { |
29 | |
30 | GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint, Optional<Type> opt_type, |
31 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
32 | String name_hint = opt_name_hint.value_or(global_var->name_hint); |
33 | Type type = opt_type.value_or(global_var->checked_type()); |
34 | VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); |
35 | Span span = opt_span.value_or(global_var->span); |
36 | bool all_fields_unchanged = |
37 | name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && |
38 | virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); |
39 | if (!all_fields_unchanged) { |
40 | GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); |
41 | cow_global_var_node->name_hint = name_hint; |
42 | cow_global_var_node->checked_type_ = type; |
43 | cow_global_var_node->virtual_device_ = virtual_device; |
44 | cow_global_var_node->span = span; |
45 | } |
46 | |
47 | return global_var; |
48 | } |
49 | |
50 | VirtualDevice RelayExprNode::virtual_device() const { |
51 | if (!this->virtual_device_.defined()) { |
52 | // virtual_device_ should always be defined, unless we imported this node from JSON using an old |
53 | // version of TVM, in which case we want to set it to the default, which is |
54 | // VirtualDevice::FullyUnconstrained(). |
55 | return VirtualDevice::FullyUnconstrained(); |
56 | } |
57 | return Downcast<VirtualDevice>(this->virtual_device_); |
58 | } |
59 | |
60 | namespace relay { |
61 | |
62 | using tvm::ReprPrinter; |
63 | using namespace tvm::runtime; |
64 | |
65 | Constant::Constant(runtime::NDArray data, Span span) { |
66 | ObjectPtr<ConstantNode> n = make_object<ConstantNode>(); |
67 | n->data = std::move(data); |
68 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
69 | n->span = std::move(span); |
70 | data_ = std::move(n); |
71 | } |
72 | |
73 | TVM_REGISTER_NODE_TYPE(ConstantNode); |
74 | |
75 | TVM_REGISTER_GLOBAL("relay.ir.Constant" ).set_body_typed([](runtime::NDArray data, Span span) { |
76 | return Constant(data, span); |
77 | }); |
78 | TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields" ) |
79 | .set_body_typed([](Constant constant, Optional<runtime::NDArray> opt_data, |
80 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
81 | return WithFields(constant, opt_data, opt_virtual_device, opt_span); |
82 | }); |
83 | |
84 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
85 | .set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) { |
86 | auto* node = static_cast<const ConstantNode*>(ref.get()); |
87 | const PackedFunc* fprint = Registry::Get("relay._constant_repr" ); |
88 | ICHECK(fprint) << "unable to find printing function for constants" ; |
89 | std::string data = (*fprint)(GetRef<Constant>(node)); |
90 | p->stream << "Constant(" << data << ")" ; |
91 | }); |
92 | |
93 | TensorType ConstantNode::tensor_type() const { |
94 | auto dtype = DataType(data->dtype); |
95 | Array<tvm::PrimExpr> shape; |
96 | for (int i = 0; i < data->ndim; i++) { |
97 | ICHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max()); |
98 | ICHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min()); |
99 | shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i])); |
100 | } |
101 | |
102 | return TensorType(shape, dtype); |
103 | } |
104 | |
105 | Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data, |
106 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
107 | runtime::NDArray data = opt_data.value_or(constant->data); |
108 | VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); |
109 | Span span = opt_span.value_or(constant->span); |
110 | |
111 | bool all_fields_unchanged = data.same_as(constant->data) && |
112 | virtual_device.same_as(constant->virtual_device()) && |
113 | span.same_as(constant->span); |
114 | |
115 | if (!all_fields_unchanged) { |
116 | ConstantNode* cow_constant_node = constant.CopyOnWrite(); |
117 | cow_constant_node->data = data; |
118 | cow_constant_node->virtual_device_ = virtual_device; |
119 | cow_constant_node->span = span; |
120 | } |
121 | return constant; |
122 | } |
123 | |
124 | Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) { |
125 | ObjectPtr<TupleNode> n = make_object<TupleNode>(); |
126 | n->fields = std::move(fields); |
127 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
128 | n->span = std::move(span); |
129 | data_ = std::move(n); |
130 | } |
131 | |
132 | TVM_REGISTER_NODE_TYPE(TupleNode); |
133 | |
134 | TVM_REGISTER_GLOBAL("relay.ir.Tuple" ).set_body_typed([](tvm::Array<relay::Expr> fields, Span span) { |
135 | return Tuple(fields, span); |
136 | }); |
137 | TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields" ) |
138 | .set_body_typed([](Tuple tuple, Optional<Array<Expr>> opt_fields, |
139 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
140 | return WithFields(tuple, opt_fields, opt_virtual_device, opt_span); |
141 | }); |
142 | |
143 | Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, |
144 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
145 | Array<Expr> fields = opt_fields.value_or(tuple->fields); |
146 | VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); |
147 | Span span = opt_span.value_or(tuple->span); |
148 | |
149 | bool all_fields_unchanged = true; |
150 | if (fields.size() == tuple->fields.size()) { |
151 | for (size_t i = 0; i < fields.size(); i++) { |
152 | all_fields_unchanged &= fields[i].same_as(tuple->fields[i]); |
153 | } |
154 | } else { |
155 | all_fields_unchanged = false; |
156 | } |
157 | |
158 | all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device()) && |
159 | span.same_as(tuple->span); |
160 | if (!all_fields_unchanged) { |
161 | TupleNode* cow_tuple_node = tuple.CopyOnWrite(); |
162 | cow_tuple_node->fields = fields; |
163 | cow_tuple_node->virtual_device_ = virtual_device; |
164 | cow_tuple_node->span = span; |
165 | } |
166 | return tuple; |
167 | } |
168 | |
169 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
170 | .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) { |
171 | auto* node = static_cast<const TupleNode*>(ref.get()); |
172 | p->stream << "Tuple(" << node->fields << ")" ; |
173 | }); |
174 | |
175 | Var::Var(Id vid, Type type_annotation, Span span) { |
176 | ObjectPtr<VarNode> n = make_object<VarNode>(); |
177 | n->vid = std::move(vid); |
178 | n->type_annotation = std::move(type_annotation); |
179 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
180 | n->span = std::move(span); |
181 | data_ = std::move(n); |
182 | } |
183 | |
184 | /* static */ Var Var::GenSym(Type type_annotation, Span span) { |
185 | static size_t next_id = std::atomic<size_t>(0); |
186 | std::ostringstream os; |
187 | os << "x_" << next_id++; |
188 | return Var(os.str(), std::move(type_annotation), std::move(span)); |
189 | } |
190 | |
191 | Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation, |
192 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
193 | Id vid = opt_vid.value_or(var->vid); |
194 | Type type_annotation = opt_type_annotation.value_or(var->type_annotation); |
195 | VirtualDevice virtual_device = opt_virtual_device.value_or(var->virtual_device()); |
196 | Span span = opt_span.value_or(var->span); |
197 | |
198 | bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && |
199 | virtual_device.same_as(var->virtual_device()) && span.same_as(var->span); |
200 | |
201 | if (!unchanged) { |
202 | VarNode* cow_var_node = var.CopyOnWrite(); |
203 | cow_var_node->vid = vid; |
204 | cow_var_node->type_annotation = type_annotation; |
205 | cow_var_node->virtual_device_ = virtual_device; |
206 | cow_var_node->span = span; |
207 | } |
208 | return var; |
209 | } |
210 | |
211 | TVM_REGISTER_NODE_TYPE(VarNode); |
212 | |
213 | TVM_REGISTER_GLOBAL("relay.ir.Var" ).set_body_typed([](String str, Type type_annotation, Span span) { |
214 | return Var(str, type_annotation, span); |
215 | }); |
216 | TVM_REGISTER_GLOBAL("relay.ir.VarWithFields" ) |
217 | .set_body_typed([](Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation, |
218 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
219 | return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device, opt_span); |
220 | }); |
221 | |
222 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
223 | .set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) { |
224 | auto* node = static_cast<const VarNode*>(ref.get()); |
225 | p->stream << "Var(" << node->name_hint(); |
226 | if (node->type_annotation.defined()) { |
227 | p->stream << ", ty=" ; |
228 | p->Print(node->type_annotation); |
229 | } |
230 | p->stream << ")" ; |
231 | }); |
232 | |
233 | Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) { |
234 | ObjectPtr<CallNode> n = make_object<CallNode>(); |
235 | n->op = std::move(op); |
236 | n->args = std::move(args); |
237 | n->attrs = std::move(attrs); |
238 | n->type_args = std::move(type_args); |
239 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
240 | n->span = std::move(span); |
241 | data_ = std::move(n); |
242 | } |
243 | |
244 | Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args, |
245 | Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args, |
246 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
247 | // Collect new values for fields. |
248 | Expr op = opt_op.value_or(call->op); |
249 | Array<Expr> args = opt_args.value_or(call->args); |
250 | Attrs attrs = opt_attrs.value_or(call->attrs); |
251 | Array<Type> type_args = opt_type_args.value_or(call->type_args); |
252 | VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); |
253 | Span span = opt_span.value_or(call->span); |
254 | |
255 | // Check if anything changed. |
256 | bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && |
257 | virtual_device.same_as(call->virtual_device()) && span.same_as(call->span); |
258 | if (unchanged) { |
259 | if (args.size() == call->args.size()) { |
260 | for (size_t i = 0; i < args.size(); i++) { |
261 | unchanged &= args[i].same_as(call->args[i]); |
262 | } |
263 | } else { |
264 | unchanged = false; |
265 | } |
266 | } |
267 | if (unchanged) { |
268 | if (type_args.size() == call->type_args.size()) { |
269 | for (size_t i = 0; i < type_args.size(); i++) { |
270 | unchanged &= type_args[i].same_as(call->type_args[i]); |
271 | } |
272 | } else { |
273 | unchanged = false; |
274 | } |
275 | } |
276 | |
277 | if (!unchanged) { |
278 | // If call is only references, update it in place. Otherwise copy and update. |
279 | CallNode* cow_call_node = call.CopyOnWrite(); |
280 | cow_call_node->op = op; |
281 | cow_call_node->args = args; |
282 | cow_call_node->attrs = attrs; |
283 | cow_call_node->type_args = type_args; |
284 | cow_call_node->virtual_device_ = virtual_device; |
285 | cow_call_node->span = span; |
286 | } |
287 | return call; |
288 | } |
289 | |
290 | TVM_REGISTER_NODE_TYPE(CallNode); |
291 | |
292 | TVM_REGISTER_GLOBAL("relay.ir.Call" ) |
293 | .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) { |
294 | return Call(op, args, attrs, type_args, span); |
295 | }); |
296 | TVM_REGISTER_GLOBAL("relay.ir.CallWithFields" ) |
297 | .set_body_typed([](Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args, |
298 | Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args, |
299 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
300 | return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args, opt_virtual_device, |
301 | opt_span); |
302 | }); |
303 | |
304 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
305 | .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) { |
306 | auto* node = static_cast<const CallNode*>(ref.get()); |
307 | p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " |
308 | << node->type_args << ")" ; |
309 | }); |
310 | |
311 | Let::Let(Var var, Expr value, Expr body, Span span) { |
312 | ObjectPtr<LetNode> n = make_object<LetNode>(); |
313 | n->var = std::move(var); |
314 | n->value = std::move(value); |
315 | n->body = std::move(body); |
316 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
317 | n->span = std::move(span); |
318 | data_ = std::move(n); |
319 | } |
320 | |
321 | Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optional<Expr> opt_body, |
322 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
323 | Var var = opt_var.value_or(let->var); |
324 | Expr value = opt_value.value_or(let->value); |
325 | Expr body = opt_body.value_or(let->body); |
326 | VirtualDevice virtual_device = opt_virtual_device.value_or(let->virtual_device()); |
327 | Span span = opt_span.value_or(let->span); |
328 | |
329 | bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && |
330 | virtual_device.same_as(let->virtual_device()) && span.same_as(let->span); |
331 | |
332 | if (!unchanged) { |
333 | LetNode* cow_let_node = let.CopyOnWrite(); |
334 | cow_let_node->var = var; |
335 | cow_let_node->value = value; |
336 | cow_let_node->body = body; |
337 | cow_let_node->virtual_device_ = virtual_device; |
338 | cow_let_node->span = span; |
339 | } |
340 | return let; |
341 | } |
342 | |
343 | TVM_REGISTER_NODE_TYPE(LetNode); |
344 | |
345 | TVM_REGISTER_GLOBAL("relay.ir.Let" ).set_body_typed([](Var var, Expr value, Expr body, Span span) { |
346 | return Let(var, value, body, span); |
347 | }); |
348 | TVM_REGISTER_GLOBAL("relay.ir.LetWithFields" ) |
349 | .set_body_typed([](Let let, Optional<Var> opt_var, Optional<Expr> opt_value, |
350 | Optional<Expr> opt_body, Optional<VirtualDevice> opt_virtual_device, |
351 | Optional<Span> opt_span) { |
352 | return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device, opt_span); |
353 | }); |
354 | |
355 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
356 | .set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) { |
357 | auto* node = static_cast<const LetNode*>(ref.get()); |
358 | p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")" ; |
359 | }); |
360 | |
361 | If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { |
362 | ObjectPtr<IfNode> n = make_object<IfNode>(); |
363 | n->cond = std::move(cond); |
364 | n->true_branch = std::move(true_branch); |
365 | n->false_branch = std::move(false_branch); |
366 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
367 | n->span = std::move(span); |
368 | data_ = std::move(n); |
369 | } |
370 | |
371 | If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch, |
372 | Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device, |
373 | Optional<Span> opt_span) { |
374 | Expr cond = opt_cond.value_or(if_expr->cond); |
375 | Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); |
376 | Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); |
377 | VirtualDevice virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); |
378 | Span span = opt_span.value_or(if_expr->span); |
379 | |
380 | bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && |
381 | false_branch.same_as(if_expr->false_branch) && |
382 | virtual_device.same_as(if_expr->virtual_device()) && span.same_as(if_expr->span); |
383 | |
384 | if (!unchanged) { |
385 | IfNode* cow_if_node = if_expr.CopyOnWrite(); |
386 | cow_if_node->cond = cond; |
387 | cow_if_node->true_branch = true_branch; |
388 | cow_if_node->false_branch = false_branch; |
389 | cow_if_node->virtual_device_ = virtual_device; |
390 | cow_if_node->span = span; |
391 | } |
392 | return if_expr; |
393 | } |
394 | |
395 | TVM_REGISTER_NODE_TYPE(IfNode); |
396 | |
397 | TVM_REGISTER_GLOBAL("relay.ir.If" ) |
398 | .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { |
399 | return If(cond, true_branch, false_branch, span); |
400 | }); |
401 | TVM_REGISTER_GLOBAL("relay.ir.IfWithFields" ) |
402 | .set_body_typed([](If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch, |
403 | Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device, |
404 | Optional<Span> opt_span) { |
405 | return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch, opt_virtual_device, |
406 | opt_span); |
407 | }); |
408 | |
409 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
410 | .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) { |
411 | auto* node = static_cast<const IfNode*>(ref.get()); |
412 | p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " |
413 | << node->false_branch << ")" ; |
414 | }); |
415 | |
416 | TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { |
417 | ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>(); |
418 | n->tuple = std::move(tuple); |
419 | n->index = index; |
420 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
421 | n->span = std::move(span); |
422 | data_ = std::move(n); |
423 | } |
424 | |
425 | TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple, |
426 | Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device, |
427 | Optional<Span> opt_span) { |
428 | Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); |
429 | Integer index = opt_index.value_or(tuple_get_item->index); |
430 | VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); |
431 | Span span = opt_span.value_or(tuple_get_item->span); |
432 | |
433 | bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && |
434 | virtual_device.same_as(tuple_get_item->virtual_device()) && |
435 | span.same_as(tuple_get_item->span); |
436 | if (!unchanged) { |
437 | TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); |
438 | cow_tuple_get_item_node->tuple = tuple; |
439 | cow_tuple_get_item_node->index = index.IntValue(); |
440 | cow_tuple_get_item_node->span = span; |
441 | cow_tuple_get_item_node->virtual_device_ = virtual_device; |
442 | } |
443 | return tuple_get_item; |
444 | } |
445 | |
446 | TVM_REGISTER_NODE_TYPE(TupleGetItemNode); |
447 | |
448 | TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem" ).set_body_typed([](Expr tuple, int index, Span span) { |
449 | return TupleGetItem(tuple, index, span); |
450 | }); |
451 | TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields" ) |
452 | .set_body_typed([](TupleGetItem tuple_get_item, Optional<Expr> opt_tuple, |
453 | Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device, |
454 | Optional<Span> opt_span) { |
455 | return WithFields(tuple_get_item, opt_tuple, opt_index, opt_virtual_device, opt_span); |
456 | }); |
457 | |
458 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
459 | .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) { |
460 | auto* node = static_cast<const TupleGetItemNode*>(ref.get()); |
461 | p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")" ; |
462 | }); |
463 | |
464 | RefCreate::RefCreate(Expr value, Span span) { |
465 | ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>(); |
466 | n->value = std::move(value); |
467 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
468 | n->span = std::move(span); |
469 | data_ = std::move(n); |
470 | } |
471 | |
472 | RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value, |
473 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
474 | Expr value = opt_value.value_or(ref_create->value); |
475 | VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); |
476 | Span span = opt_span.value_or(ref_create->span); |
477 | |
478 | bool unchanged = value.same_as(ref_create->value) && |
479 | virtual_device.same_as(ref_create->virtual_device()) && |
480 | span.same_as(ref_create->span); |
481 | if (!unchanged) { |
482 | RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); |
483 | cow_ref_create_node->value = value; |
484 | cow_ref_create_node->virtual_device_ = virtual_device; |
485 | cow_ref_create_node->span = span; |
486 | } |
487 | return ref_create; |
488 | } |
489 | |
490 | TVM_REGISTER_NODE_TYPE(RefCreateNode); |
491 | |
492 | TVM_REGISTER_GLOBAL("relay.ir.RefCreate" ).set_body_typed([](Expr value, Span span) { |
493 | return RefCreate(value, span); |
494 | }); |
495 | TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields" ) |
496 | .set_body_typed([](RefCreate ref_create, Optional<Expr> opt_value, |
497 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
498 | return WithFields(ref_create, opt_value, opt_virtual_device, opt_span); |
499 | }); |
500 | |
501 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
502 | .set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) { |
503 | auto* node = static_cast<const RefCreateNode*>(ref.get()); |
504 | p->stream << "RefCreateNode(" << node->value << ")" ; |
505 | }); |
506 | |
507 | RefRead::RefRead(Expr ref, Span span) { |
508 | ObjectPtr<RefReadNode> n = make_object<RefReadNode>(); |
509 | n->ref = std::move(ref); |
510 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
511 | n->span = std::move(span); |
512 | data_ = std::move(n); |
513 | } |
514 | |
515 | RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref, |
516 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
517 | Expr ref = opt_ref.value_or(ref_read->ref); |
518 | VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); |
519 | Span span = opt_span.value_or(ref_read->span); |
520 | |
521 | bool unchanged = ref.same_as(ref_read->ref) && |
522 | virtual_device.same_as(ref_read->virtual_device()) && |
523 | span.same_as(ref_read->span); |
524 | if (!unchanged) { |
525 | RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); |
526 | cow_ref_read_node->ref = ref; |
527 | cow_ref_read_node->virtual_device_ = virtual_device; |
528 | cow_ref_read_node->span = span; |
529 | } |
530 | return ref_read; |
531 | } |
532 | |
533 | TVM_REGISTER_NODE_TYPE(RefReadNode); |
534 | |
535 | TVM_REGISTER_GLOBAL("relay.ir.RefRead" ).set_body_typed([](Expr ref, Span span) { |
536 | return RefRead(ref, span); |
537 | }); |
538 | TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields" ) |
539 | .set_body_typed([](RefRead ref_read, Optional<Expr> opt_ref, |
540 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
541 | return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span); |
542 | }); |
543 | |
544 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
545 | .set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) { |
546 | auto* node = static_cast<const RefReadNode*>(ref.get()); |
547 | p->stream << "RefReadNode(" << node->ref << ")" ; |
548 | }); |
549 | |
550 | RefWrite::RefWrite(Expr ref, Expr value, Span span) { |
551 | ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>(); |
552 | n->ref = std::move(ref); |
553 | n->value = std::move(value); |
554 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
555 | n->span = std::move(span); |
556 | data_ = std::move(n); |
557 | } |
558 | |
559 | RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value, |
560 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
561 | Expr ref = opt_ref.value_or(ref_write->ref); |
562 | Expr value = opt_value.value_or(ref_write->value); |
563 | VirtualDevice virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); |
564 | Span span = opt_span.value_or(ref_write->span); |
565 | |
566 | bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && |
567 | virtual_device.same_as(ref_write->virtual_device()) && |
568 | span.same_as(ref_write->span); |
569 | if (!unchanged) { |
570 | RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); |
571 | cow_ref_write_node->ref = ref; |
572 | cow_ref_write_node->value = value; |
573 | cow_ref_write_node->virtual_device_ = virtual_device; |
574 | cow_ref_write_node->span = span; |
575 | } |
576 | return ref_write; |
577 | } |
578 | |
579 | TVM_REGISTER_NODE_TYPE(RefWriteNode); |
580 | |
581 | TVM_REGISTER_GLOBAL("relay.ir.RefWrite" ).set_body_typed([](Expr ref, Expr value, Span span) { |
582 | return RefWrite(ref, value, span); |
583 | }); |
584 | TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields" ) |
585 | .set_body_typed([](RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value, |
586 | Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
587 | return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device, opt_span); |
588 | }); |
589 | |
590 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
591 | .set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) { |
592 | auto* node = static_cast<const RefWriteNode*>(ref.get()); |
593 | p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")" ; |
594 | }); |
595 | |
596 | TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize" ).set_body_typed([](TempExpr temp) { |
597 | return temp->Realize(); |
598 | }); |
599 | |
600 | TVM_REGISTER_GLOBAL("relay.ir.Any" ).set_body_typed([]() { return Any(); }); |
601 | |
/*
 * Non-recursive traversal with dismantling unused call nodes,
 * a derivative from ExpandDataflow method
 *
 * Walks the Call / Tuple / TupleGetItem / Let structure reachable from `expr` using an explicit
 * stack instead of C++ recursion. In post-order, it clears `args` of CallNodes and `body` of
 * LetNodes whose use counts indicate they are not referenced elsewhere. Because children are
 * cleared before their parents, dropping the last reference to a node afterwards no longer
 * drags a deep sub-graph through nested destructor calls.
 *
 * Nodes and child containers whose use counts exceed the thresholds below are treated as shared
 * and are neither visited nor cleared.
 *
 * \param expr Root of the traversal.
 */
inline void Dismantle(const Expr& expr) {
  // Each entry is (node, expanded): `expanded` becomes true once the node's children were pushed.
  std::stack<std::pair<Expr, bool>> stack;
  auto fpush_to_stack = [&stack](const Expr& expr) {
    // do not visit nodes with more than 2 refs (one can be in stack)
    if (expr.use_count() < 3) {
      stack.push({expr, false});
    }
  };
  fpush_to_stack(expr);
  while (stack.size() > 0) {
    // std::stack defaults to std::deque, whose push() does not invalidate references to existing
    // elements, so `node` stays valid while children are pushed below; it is not used after pop().
    const auto& node = stack.top().first;
    if (stack.top().second) {
      // Second visit: every child pushed after this node has already been popped, so the node
      // itself can now be dismantled.
      // dismantle node
      // +1 ref in stack/deque;
      if (node.use_count() < 3) {
        // as<T>() yields a const pointer; const_cast is used to clear the field in place.
        if (auto* op = const_cast<CallNode*>(node.as<CallNode>())) {
          op->args = Array<Expr>();
        }
        if (auto* op = const_cast<LetNode*>(node.as<LetNode>())) {
          op->body = Expr();
        }
      }
      // eject
      stack.pop();
    } else {
      // First visit: mark as expanded, then push eligible children.
      stack.top().second = true;

      // special handling
      // Children are pushed in reverse so the first child ends up on top and is processed first.
      if (const auto* call_node = node.as<CallNode>()) {
        // do not process args if used elsewhere
        if (call_node->args.use_count() < 2) {
          for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
            fpush_to_stack(*it);
          }
        }
      } else if (const auto* tuple_node = node.as<TupleNode>()) {
        // do not process fields if used elsewhere
        if (tuple_node->fields.use_count() < 2) {
          for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) {
            fpush_to_stack(*it);
          }
        }
      } else if (const auto* tuple_get_item_node = node.as<TupleGetItemNode>()) {
        // do not process tuple if used elsewhere
        if (tuple_get_item_node->tuple.use_count() < 2) {
          fpush_to_stack(tuple_get_item_node->tuple);
        }
      } else if (const auto* let_node = node.as<LetNode>()) {
        // do not process let if used elsewhere
        // Only the body is traversed; the bound value is left to ordinary destruction.
        if (let_node->body.use_count() < 2) {
          fpush_to_stack(let_node->body);
        }
      }
    }
  }
}
662 | |
/*
 * Non-recursive destructor
 *
 * If this reference is the only one (or the node is unreferenced) and the CallNode still has
 * arguments, the argument graph is torn down iteratively with Dismantle() before the base
 * reference releases the node. The intent is to avoid one nested destructor call per level of a
 * long Call chain.
 */
Call::~Call() {
  // attempt to dismantle if referenced one or zero times
  if (this->use_count() < 2) {
    // Skip calls that are null or already have no args: nothing to dismantle.
    if (this->as<CallNode>() && this->as<CallNode>()->args.size()) {
      Dismantle(*this);
    }
  }
}
674 | |
/*
 * CallNode's deleter
 *
 * Runs in place of the node's original deleter, which is kept in `saved_deleter_` (where this
 * deleter gets installed is not visible in this file — presumably CallNode's construction path;
 * verify). It restores the original deleter, then wraps the node in a temporary Call so that
 * ~Call (and thus Dismantle) runs when that temporary goes out of scope. Presumably the final
 * release then frees the node through the restored deleter — NOTE(review): confirm.
 */
void CallNode::Deleter_(Object* ptr) {
  auto p = reinterpret_cast<CallNode*>(ptr);
  // restore original deleter
  p->deleter_ = p->saved_deleter_;
  // create Call reference in order to invoke ~Call
  // (`c` is intentionally unused: its destruction at scope exit is the point.)
  auto c = GetRef<Call>(p);
}
685 | |
/*
 * Non-recursive destructor
 *
 * If this reference is the only one (or the node is unreferenced) and the LetNode still has a
 * body, the body chain is torn down iteratively with Dismantle() before the base reference
 * releases the node. The intent is to avoid one nested destructor call per level of a long
 * chain of nested Lets.
 */
Let::~Let() {
  // attempt to dismantle if referenced one or zero times
  if (this->use_count() < 2) {
    // Skip lets that are null or whose body has already been cleared.
    if (this->as<LetNode>() && this->as<LetNode>()->body.defined()) {
      Dismantle(*this);
    }
  }
}
697 | |
/*
 * LetNode's deleter
 *
 * Runs in place of the node's original deleter, which is kept in `saved_deleter_` (where this
 * deleter gets installed is not visible in this file — presumably LetNode's construction path;
 * verify). It restores the original deleter, then wraps the node in a temporary Let so that
 * ~Let (and thus Dismantle) runs when that temporary goes out of scope. Presumably the final
 * release then frees the node through the restored deleter — NOTE(review): confirm.
 */
void LetNode::Deleter_(Object* ptr) {
  auto p = reinterpret_cast<LetNode*>(ptr);
  // restore original deleter
  p->deleter_ = p->saved_deleter_;
  // create Let reference in order to invoke ~Let
  // (`c` is intentionally unused: its destruction at scope exit is the point.)
  auto c = GetRef<Let>(p);
}
708 | |
709 | } // namespace relay |
710 | } // namespace tvm |
711 | |