1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file expr.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/tir/expr.h> |
25 | #include <tvm/tir/op.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | |
28 | #include "../../support/str_escape.h" |
29 | #include "buffer_common.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
// Generates `Name::Name(a, b, span)` for a binary expression node.
// The constructor requires both operands to be defined and to have the same
// dtype; that dtype becomes the dtype of the new node.
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name)                                                   \
  Name::Name(PrimExpr a, PrimExpr b, Span span) {                                            \
    using NodeType = Name::ContainerType;                                                    \
    ICHECK(a.defined()) << "ValueError: a is undefined\n";                                   \
    ICHECK(b.defined()) << "ValueError: b is undefined\n";                                   \
    CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
                                  << b.dtype() << "\n";                                      \
    const DataType result_dtype = a.dtype();                                                 \
    ObjectPtr<NodeType> n = make_object<NodeType>();                                         \
    n->span = std::move(span);                                                               \
    n->b = std::move(b);                                                                     \
    n->a = std::move(a);                                                                     \
    n->dtype = result_dtype;                                                                 \
    data_ = std::move(n);                                                                    \
  }
48 | |
// Generates `Name::Name(a, b, span)` for a comparison expression node.
// Operand validation matches the binary-op macro; the result dtype is a
// boolean type with the same number of lanes as operand `a`.
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name)                                                   \
  Name::Name(PrimExpr a, PrimExpr b, Span span) {                                            \
    using NodeType = Name::ContainerType;                                                    \
    ICHECK(a.defined()) << "ValueError: a is undefined\n";                                   \
    ICHECK(b.defined()) << "ValueError: b is undefined\n";                                   \
    CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
                                  << b.dtype() << "\n";                                      \
    const DataType result_dtype = DataType::Bool(a.dtype().lanes());                         \
    ObjectPtr<NodeType> n = make_object<NodeType>();                                         \
    n->span = std::move(span);                                                               \
    n->b = std::move(b);                                                                     \
    n->a = std::move(a);                                                                     \
    n->dtype = result_dtype;                                                                 \
    data_ = std::move(n);                                                                    \
  }
63 | |
64 | // Var |
65 | Var::Var(String name_hint, DataType dtype, Span span) { |
66 | auto n = make_object<VarNode>(); |
67 | n->name_hint = std::move(name_hint); |
68 | n->type_annotation = GetTypeFromRuntimeDataType(dtype); |
69 | n->dtype = std::move(dtype); |
70 | n->span = std::move(span); |
71 | data_ = std::move(n); |
72 | } |
73 | |
74 | Var::Var(String name_hint, Type type_annotation, Span span) { |
75 | auto n = make_object<VarNode>(); |
76 | n->name_hint = std::move(name_hint); |
77 | n->dtype = GetRuntimeDataType(type_annotation); |
78 | n->type_annotation = std::move(type_annotation); |
79 | n->span = std::move(span); |
80 | data_ = std::move(n); |
81 | } |
82 | |
83 | Var Var::copy_with_suffix(const String& suffix) const { |
84 | const VarNode* node = get(); |
85 | ObjectPtr<VarNode> new_ptr; |
86 | if (auto* ptr = this->as<SizeVarNode>()) { |
87 | new_ptr = make_object<SizeVarNode>(*ptr); |
88 | } else { |
89 | new_ptr = make_object<VarNode>(*node); |
90 | } |
91 | new_ptr->name_hint = new_ptr->name_hint + suffix; |
92 | return Var(new_ptr); |
93 | } |
94 | |
95 | Var Var::copy_with_dtype(DataType dtype) const { |
96 | const VarNode* node = get(); |
97 | ObjectPtr<VarNode> new_ptr; |
98 | if (auto* ptr = this->as<SizeVarNode>()) { |
99 | new_ptr = make_object<SizeVarNode>(*ptr); |
100 | } else { |
101 | new_ptr = make_object<VarNode>(*node); |
102 | } |
103 | new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype); |
104 | new_ptr->dtype = std::move(dtype); |
105 | return Var(new_ptr); |
106 | } |
107 | |
108 | TVM_REGISTER_GLOBAL("tir.Var" ).set_body_typed([](String name_hint, runtime::TVMArgValue type, |
109 | Span span) { |
110 | if (type.IsObjectRef<Type>()) { |
111 | return Var(name_hint, type.operator Type(), span); |
112 | } else { |
113 | return Var(name_hint, type.operator DataType(), span); |
114 | } |
115 | }); |
116 | |
117 | TVM_REGISTER_NODE_TYPE(VarNode); |
118 | |
119 | // SizeVar |
120 | SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { |
121 | auto n = make_object<SizeVarNode>(); |
122 | n->name_hint = std::move(name_hint); |
123 | n->type_annotation = GetTypeFromRuntimeDataType(dtype); |
124 | n->dtype = std::move(dtype); |
125 | n->span = std::move(span); |
126 | data_ = std::move(n); |
127 | } |
128 | |
129 | TVM_REGISTER_GLOBAL("tir.SizeVar" ).set_body_typed([](String s, DataType t, Span span) { |
130 | return SizeVar(s, t, span); |
131 | }); |
132 | |
133 | TVM_REGISTER_NODE_TYPE(SizeVarNode); |
134 | |
135 | // IterVar |
136 | IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) { |
137 | ObjectPtr<IterVarNode> n = make_object<IterVarNode>(); |
138 | if (dom.defined() && dom->extent.defined()) { |
139 | CHECK(dom->extent.dtype().is_int()) |
140 | << "The dtype of the domain of an IterVar must be an integer type. However, the domain's " |
141 | "dtype is " |
142 | << dom->extent.dtype(); |
143 | CHECK_EQ(dom->extent.dtype(), var.dtype()) |
144 | << "The dtype of the extent of an IterVar (" << dom->extent.dtype() |
145 | << ") must match its associated Var's dtype (" << var.dtype() << ")" ; |
146 | } |
147 | n->dom = dom; |
148 | n->var = var; |
149 | n->iter_type = t; |
150 | n->thread_tag = thread_tag; |
151 | n->span = std::move(span); |
152 | data_ = std::move(n); |
153 | } |
154 | |
155 | TVM_REGISTER_GLOBAL("tir.IterVar" ) |
156 | .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) { |
157 | return IterVar(dom, var, static_cast<IterVarType>(iter_type), thread_tag, span); |
158 | }); |
159 | |
160 | TVM_REGISTER_NODE_TYPE(IterVarNode); |
161 | |
162 | // StringImm |
163 | StringImm::StringImm(String value, Span span) { |
164 | ObjectPtr<StringImmNode> node = make_object<StringImmNode>(); |
165 | node->dtype = DataType::Handle(); |
166 | node->value = std::move(value); |
167 | node->span = std::move(span); |
168 | data_ = std::move(node); |
169 | } |
170 | |
171 | TVM_REGISTER_GLOBAL("tir.StringImm" ).set_body_typed([](String value, Span span) { |
172 | return StringImm(value, span); |
173 | }); |
174 | |
175 | TVM_REGISTER_NODE_TYPE(StringImmNode); |
176 | |
177 | // Cast |
178 | Cast::Cast(DataType t, PrimExpr value, Span span) { |
179 | ICHECK(value.defined()); |
180 | ICHECK_EQ(t.lanes(), value.dtype().lanes()); |
181 | ObjectPtr<CastNode> node = make_object<CastNode>(); |
182 | node->dtype = t; |
183 | node->value = std::move(value); |
184 | node->span = std::move(span); |
185 | data_ = std::move(node); |
186 | } |
187 | |
188 | TVM_REGISTER_GLOBAL("tir.Cast" ).set_body_typed([](DataType dtype, PrimExpr value, Span span) { |
189 | return Cast(dtype, value, span); |
190 | }); |
191 | |
192 | TVM_REGISTER_NODE_TYPE(CastNode); |
193 | |
194 | // Add |
195 | TVM_DEFINE_BINOP_CONSTRUCTOR(Add); |
196 | |
197 | TVM_REGISTER_GLOBAL("tir.Add" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
198 | return Add(a, b, span); |
199 | }); |
200 | |
201 | TVM_REGISTER_NODE_TYPE(AddNode); |
202 | |
203 | // Sub |
204 | TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); |
205 | |
206 | TVM_REGISTER_GLOBAL("tir.Sub" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
207 | return Sub(a, b, span); |
208 | }); |
209 | |
210 | TVM_REGISTER_NODE_TYPE(SubNode); |
211 | |
212 | // Mul |
213 | TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); |
214 | |
215 | TVM_REGISTER_GLOBAL("tir.Mul" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
216 | return Mul(a, b, span); |
217 | }); |
218 | |
219 | TVM_REGISTER_NODE_TYPE(MulNode); |
220 | |
221 | // Div |
222 | TVM_DEFINE_BINOP_CONSTRUCTOR(Div); |
223 | |
224 | TVM_REGISTER_GLOBAL("tir.Div" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
225 | return Div(a, b, span); |
226 | }); |
227 | |
228 | TVM_REGISTER_NODE_TYPE(DivNode); |
229 | |
230 | // Mod |
231 | TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); |
232 | |
233 | TVM_REGISTER_GLOBAL("tir.Mod" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
234 | return Mod(a, b, span); |
235 | }); |
236 | |
237 | TVM_REGISTER_NODE_TYPE(ModNode); |
238 | |
239 | // FloorDiv |
240 | TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); |
241 | |
242 | TVM_REGISTER_GLOBAL("tir.FloorDiv" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
243 | return FloorDiv(a, b, span); |
244 | }); |
245 | |
246 | TVM_REGISTER_NODE_TYPE(FloorDivNode); |
247 | |
248 | // FloorMod |
249 | TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); |
250 | |
251 | TVM_REGISTER_GLOBAL("tir.FloorMod" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
252 | return FloorMod(a, b, span); |
253 | }); |
254 | |
255 | TVM_REGISTER_NODE_TYPE(FloorModNode); |
256 | |
257 | // Min |
258 | TVM_DEFINE_BINOP_CONSTRUCTOR(Min); |
259 | |
260 | TVM_REGISTER_GLOBAL("tir.Min" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
261 | return Min(a, b, span); |
262 | }); |
263 | |
264 | TVM_REGISTER_NODE_TYPE(MinNode); |
265 | |
266 | // Max |
267 | TVM_DEFINE_BINOP_CONSTRUCTOR(Max); |
268 | |
269 | TVM_REGISTER_GLOBAL("tir.Max" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
270 | return Max(a, b, span); |
271 | }); |
272 | |
273 | TVM_REGISTER_NODE_TYPE(MaxNode); |
274 | |
275 | // EQ |
276 | TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); |
277 | |
278 | TVM_REGISTER_GLOBAL("tir.EQ" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
279 | return EQ(a, b, span); |
280 | }); |
281 | |
282 | TVM_REGISTER_NODE_TYPE(EQNode); |
283 | |
284 | // NE |
285 | TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); |
286 | |
287 | TVM_REGISTER_GLOBAL("tir.NE" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
288 | return NE(a, b, span); |
289 | }); |
290 | |
291 | TVM_REGISTER_NODE_TYPE(NENode); |
292 | |
293 | // LT |
294 | TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); |
295 | |
296 | TVM_REGISTER_GLOBAL("tir.LT" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
297 | return LT(a, b, span); |
298 | }); |
299 | |
300 | TVM_REGISTER_NODE_TYPE(LTNode); |
301 | |
302 | // LE |
303 | TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); |
304 | |
305 | TVM_REGISTER_GLOBAL("tir.LE" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
306 | return LE(a, b, span); |
307 | }); |
308 | |
309 | TVM_REGISTER_NODE_TYPE(LENode); |
310 | |
311 | // GT |
312 | TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); |
313 | |
314 | TVM_REGISTER_GLOBAL("tir.GT" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
315 | return GT(a, b, span); |
316 | }); |
317 | |
318 | TVM_REGISTER_NODE_TYPE(GTNode); |
319 | |
320 | // GE |
321 | TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); |
322 | |
323 | TVM_REGISTER_GLOBAL("tir.GE" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
324 | return GE(a, b, span); |
325 | }); |
326 | |
327 | TVM_REGISTER_NODE_TYPE(GENode); |
328 | |
329 | // And |
330 | And::And(PrimExpr a, PrimExpr b, Span span) { |
331 | ICHECK(a.defined()) << "ValueError: a is undefined" ; |
332 | ICHECK(b.defined()) << "ValueError: b is undefined" ; |
333 | ICHECK(a.dtype().is_bool()); |
334 | ICHECK(b.dtype().is_bool()); |
335 | ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types" ; |
336 | |
337 | ObjectPtr<AndNode> node = make_object<AndNode>(); |
338 | node->dtype = DataType::Bool(a.dtype().lanes()); |
339 | node->a = std::move(a); |
340 | node->b = std::move(b); |
341 | node->span = std::move(span); |
342 | data_ = std::move(node); |
343 | } |
344 | |
345 | TVM_REGISTER_GLOBAL("tir.And" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
346 | return And(a, b, span); |
347 | }); |
348 | |
349 | TVM_REGISTER_NODE_TYPE(AndNode); |
350 | |
351 | // Or |
352 | Or::Or(PrimExpr a, PrimExpr b, Span span) { |
353 | ICHECK(a.defined()) << "ValueError: a is undefined" ; |
354 | ICHECK(b.defined()) << "ValueError: b is undefined" ; |
355 | ICHECK(a.dtype().is_bool()); |
356 | ICHECK(b.dtype().is_bool()); |
357 | ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types" ; |
358 | |
359 | ObjectPtr<OrNode> node = make_object<OrNode>(); |
360 | node->dtype = DataType::Bool(a.dtype().lanes()); |
361 | node->a = std::move(a); |
362 | node->b = std::move(b); |
363 | node->span = std::move(span); |
364 | data_ = std::move(node); |
365 | } |
366 | |
367 | TVM_REGISTER_GLOBAL("tir.Or" ).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { |
368 | return Or(a, b, span); |
369 | }); |
370 | |
371 | TVM_REGISTER_NODE_TYPE(OrNode); |
372 | |
373 | // Not |
374 | Not::Not(PrimExpr a, Span span) { |
375 | ICHECK(a.defined()) << "ValueError: a is undefined" ; |
376 | ICHECK(a.dtype().is_bool()); |
377 | |
378 | ObjectPtr<NotNode> node = make_object<NotNode>(); |
379 | node->dtype = DataType::Bool(a.dtype().lanes()); |
380 | node->a = std::move(a); |
381 | node->span = std::move(span); |
382 | data_ = std::move(node); |
383 | } |
384 | |
385 | TVM_REGISTER_GLOBAL("tir.Not" ).set_body_typed([](PrimExpr a, Span span) { return Not(a, span); }); |
386 | |
387 | TVM_REGISTER_NODE_TYPE(NotNode); |
388 | |
389 | // Select |
390 | Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { |
391 | ICHECK(condition.defined()) << "ValueError: condition is undefined" ; |
392 | ICHECK(true_value.defined()) << "ValueError: true_value is undefined" ; |
393 | ICHECK(false_value.defined()) << "ValueError: true_value is undefined" ; |
394 | ICHECK(condition.dtype().is_bool()); |
395 | ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); |
396 | ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types" ; |
397 | |
398 | ObjectPtr<SelectNode> node = make_object<SelectNode>(); |
399 | node->dtype = true_value.dtype(); |
400 | node->condition = std::move(condition); |
401 | node->true_value = std::move(true_value); |
402 | node->false_value = std::move(false_value); |
403 | node->span = std::move(span); |
404 | data_ = std::move(node); |
405 | } |
406 | |
407 | TVM_REGISTER_GLOBAL("tir.Select" ) |
408 | .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { |
409 | return Select(condition, true_value, false_value, span); |
410 | }); |
411 | |
412 | TVM_REGISTER_NODE_TYPE(SelectNode); |
413 | |
414 | // Load |
415 | Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) { |
416 | LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint |
417 | << ". Use BufferStore instead." ; |
418 | ICHECK(buffer_var.defined()); |
419 | ICHECK(predicate.defined()); |
420 | ICHECK(index.defined()); |
421 | |
422 | // Assume that the array elements have 1 lane, unless a type |
423 | // annotation tells us otherwise. |
424 | int element_lanes = 1; |
425 | auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); |
426 | if (pointer_type.has_value()) { |
427 | // Cannot check element type of array, as it may be different than |
428 | // the loaded type in some cases. |
429 | // |
430 | // 1. Booleans use DataType::Int(8) while stored, and the codegens |
431 | // handle cast to boolean. |
432 | // |
433 | // 2. The StorageRewrite pass can merge multiple allocations at |
434 | // the same scope, regardless of element type. The codegen is |
435 | // then responsible for casting to the output type. |
436 | |
437 | // TODO(Lunderberg): Uncomment this check once it can be applied. |
438 | // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 |
439 | // for discussion. |
440 | |
441 | // ICHECK(dtype.element_of() == pointer_type->element_of()) |
442 | // << "Type mismatch, cannot load type " << dtype << " from buffer " << |
443 | // buffer_var->name_hint |
444 | // << " of type " << pointer_type.value(); |
445 | element_lanes = pointer_type->lanes(); |
446 | } |
447 | |
448 | // The C-based codegens assume that all loads occur on a array with |
449 | // non-vectorized elements, and cast between |
450 | // vectorized/non-vectorized arrays as needed. Ideally, these |
451 | // should be changed to explicit casts in the TIR graph, rather than |
452 | // being handled at the code-gen level. |
453 | ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || |
454 | (dtype.lanes() == index.dtype().lanes())); |
455 | ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || |
456 | (dtype.lanes() == index.dtype().lanes())); |
457 | |
458 | ObjectPtr<LoadNode> node = make_object<LoadNode>(); |
459 | node->dtype = dtype; |
460 | node->buffer_var = std::move(buffer_var); |
461 | node->index = std::move(index); |
462 | node->predicate = std::move(predicate); |
463 | node->span = std::move(span); |
464 | |
465 | data_ = std::move(node); |
466 | } |
467 | |
468 | TVM_REGISTER_GLOBAL("tir.Load" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
469 | DataType t = args[0]; |
470 | if (args.size() == 3) { |
471 | *ret = Load(t, args[1], args[2], const_true(t.lanes()), Span()); |
472 | } else if (args.size() == 4) { |
473 | *ret = Load(t, args[1], args[2], args[3], Span()); |
474 | } else { |
475 | *ret = Load(t, args[1], args[2], args[3], args[4]); |
476 | } |
477 | }); |
478 | |
479 | TVM_REGISTER_NODE_TYPE(LoadNode); |
480 | |
481 | // Ramp |
482 | Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { |
483 | ICHECK(base.defined()); |
484 | ICHECK(stride.defined()); |
485 | ICHECK(base.dtype().is_scalar()); |
486 | ICHECK(stride.dtype().is_scalar()); |
487 | ICHECK_GT(lanes, 1); |
488 | ICHECK_EQ(stride.dtype(), base.dtype()); |
489 | |
490 | ObjectPtr<RampNode> node = make_object<RampNode>(); |
491 | node->dtype = base.dtype().with_lanes(lanes); |
492 | node->base = base; |
493 | node->stride = stride; |
494 | node->lanes = lanes; |
495 | node->span = std::move(span); |
496 | data_ = std::move(node); |
497 | } |
498 | |
499 | TVM_REGISTER_GLOBAL("tir.Ramp" ) |
500 | .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) { |
501 | return Ramp(base, stride, lanes, span); |
502 | }); |
503 | |
504 | TVM_REGISTER_NODE_TYPE(RampNode); |
505 | |
506 | // Broadcast |
507 | Broadcast::Broadcast(PrimExpr value, int lanes, Span span) { |
508 | ICHECK(value.defined()); |
509 | ICHECK(value.dtype().is_scalar()); |
510 | ICHECK_GT(lanes, 1); |
511 | |
512 | ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>(); |
513 | node->dtype = value.dtype().with_lanes(lanes); |
514 | node->value = std::move(value); |
515 | node->lanes = lanes; |
516 | node->span = std::move(span); |
517 | data_ = node; |
518 | } |
519 | |
520 | TVM_REGISTER_GLOBAL("tir.Broadcast" ).set_body_typed([](PrimExpr value, int lanes, Span span) { |
521 | return Broadcast(value, lanes, span); |
522 | }); |
523 | |
524 | TVM_REGISTER_NODE_TYPE(BroadcastNode); |
525 | |
526 | // Let |
527 | Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { |
528 | ICHECK(value.defined()); |
529 | ICHECK(body.defined()); |
530 | ICHECK_EQ(value.dtype(), var.dtype()); |
531 | |
532 | ObjectPtr<LetNode> node = make_object<LetNode>(); |
533 | node->dtype = body.dtype(); |
534 | node->var = std::move(var); |
535 | node->value = std::move(value); |
536 | node->body = std::move(body); |
537 | node->span = std::move(span); |
538 | data_ = std::move(node); |
539 | } |
540 | |
541 | TVM_REGISTER_GLOBAL("tir.Let" ).set_body_typed([](Var var, PrimExpr value, PrimExpr body, |
542 | Span span) { |
543 | return Let(var, value, body, span); |
544 | }); |
545 | |
546 | TVM_REGISTER_NODE_TYPE(LetNode); |
547 | |
548 | // Call |
549 | Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) { |
550 | for (size_t i = 0; i < args.size(); ++i) { |
551 | ICHECK(args[i].defined()) << "arg " << i << " is not defined()" ; |
552 | } |
553 | |
554 | ObjectPtr<CallNode> node = make_object<CallNode>(); |
555 | node->dtype = dtype; |
556 | node->op = std::move(op); |
557 | node->args = std::move(args); |
558 | node->span = std::move(span); |
559 | data_ = std::move(node); |
560 | } |
561 | |
562 | TVM_REGISTER_GLOBAL("tir.Call" ) |
563 | .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) { |
564 | Array<PrimExpr> prim_expr_args; |
565 | for (const auto& it : args) { |
566 | ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() || |
567 | it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>()) |
568 | << "Argument " << it << " is not a string or primexpr" ; |
569 | if (const auto* str = it.as<runtime::StringObj>()) { |
570 | prim_expr_args.push_back(StringImm(str->data)); |
571 | } else if (const auto* iter_var = it.as<IterVarNode>()) { |
572 | prim_expr_args.push_back(GetRef<IterVar>(iter_var)->var); |
573 | } else if (const auto* br = it.as<BufferRegionNode>()) { |
574 | Array<PrimExpr> indices; |
575 | for (Range r : br->region) { |
576 | if (is_one(r->extent)) { |
577 | indices.push_back(r->min); |
578 | } else if (const auto* extent = r->extent.as<IntImmNode>()) { |
579 | indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), extent->value)); |
580 | } else { |
581 | LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " |
582 | << GetRef<BufferRegion>(br); |
583 | } |
584 | } |
585 | prim_expr_args.push_back(BufferLoad(br->buffer, indices)); |
586 | } else { |
587 | prim_expr_args.push_back(Downcast<PrimExpr>(it)); |
588 | } |
589 | } |
590 | return Call(type, op, prim_expr_args, span); |
591 | }); |
592 | |
593 | TVM_REGISTER_NODE_TYPE(CallNode); |
594 | |
595 | // Shuffle |
596 | Shuffle::Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span) { |
597 | ICHECK_NE(vectors.size(), 0U); |
598 | ICHECK_NE(indices.size(), 0U); |
599 | |
600 | DataType base_type = vectors[0].dtype().element_of(); |
601 | int total_lanes = 0; |
602 | |
603 | for (PrimExpr val : vectors) { |
604 | ICHECK(val.dtype().element_of() == base_type); |
605 | total_lanes += val.dtype().lanes(); |
606 | } |
607 | ICHECK_LE(indices.size(), static_cast<size_t>(total_lanes)); |
608 | |
609 | ObjectPtr<ShuffleNode> node = make_object<ShuffleNode>(); |
610 | node->dtype = base_type.with_lanes(static_cast<int>(indices.size())); |
611 | node->vectors = std::move(vectors); |
612 | node->indices = std::move(indices); |
613 | node->span = std::move(span); |
614 | data_ = node; |
615 | } |
616 | |
617 | PrimExpr Shuffle::Concat(Array<PrimExpr> vectors, Span span) { |
618 | ICHECK_NE(vectors.size(), 0); |
619 | if (vectors.size() == 1) { |
620 | return vectors[0]; |
621 | } |
622 | Array<PrimExpr> indices; |
623 | int index = 0; |
624 | for (const PrimExpr& e : vectors) { |
625 | for (int i = 0; i < e.dtype().lanes(); ++i) { |
626 | indices.push_back(IntImm(DataType::Int(32), index++)); |
627 | } |
628 | } |
629 | return Shuffle(vectors, indices, span); |
630 | } |
631 | |
632 | PrimExpr Shuffle::(PrimExpr vector, int index, Span span) { |
633 | return Shuffle({vector}, {Integer(index)}, span); |
634 | } |
635 | |
636 | TVM_REGISTER_GLOBAL("tir.Shuffle" ) |
637 | .set_body_typed([](Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span) { |
638 | return Shuffle(vectors, indices, span); |
639 | }); |
640 | |
641 | TVM_REGISTER_NODE_TYPE(ShuffleNode); |
642 | |
643 | // CommReducer |
// Builds a commutative reducer from paired input variables (`lhs`, `rhs`),
// the combining expressions (`result`) and the identity elements.
//
// All four arrays must have the same length. Each lhs/rhs variable is
// replaced by a copy whose dtype is that of the matching identity element,
// and `result` is rewritten to refer to the replacement variables.
CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
                         Array<PrimExpr> identity_element, Span span) {
  size_t n_group = result.size();
  CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the "
                                   "number of elements in `results`";
  CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the "
                                   "number of elements in `results`";
  CHECK_EQ(identity_element.size(), n_group)
      << "ValueError: The number of identities must equal to the number of elements in `results`";

  // Change the dtype of input vars to adapt to the dtype of identities
  ArrayNode* p_lhs = lhs.CopyOnWrite();
  ArrayNode* p_rhs = rhs.CopyOnWrite();
  // Maps each original variable node to its re-typed replacement.
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  var_map.reserve(n_group * 2);
  for (int i = 0; i < static_cast<int>(n_group); ++i) {
    DataType dtype = identity_element[i].dtype();
    Var l = lhs[i].copy_with_dtype(dtype);
    Var r = rhs[i].copy_with_dtype(dtype);
    // Record the mapping while lhs[i]/rhs[i] still hold the original vars;
    // the SetItem calls below overwrite those slots.
    var_map[lhs[i].get()] = l;
    var_map[rhs[i].get()] = r;

    p_lhs->SetItem(i, l);
    p_rhs->SetItem(i, r);
  }

  // Rewrite every result expression in terms of the replacement variables.
  ArrayNode* p_result = result.CopyOnWrite();
  for (int i = 0; i < static_cast<int>(n_group); ++i) {
    p_result->SetItem(i, Substitute(result[i], var_map));
  }

  auto node = make_object<CommReducerNode>();
  node->lhs = lhs;
  node->rhs = rhs;
  node->result = result;
  node->identity_element = identity_element;
  node->span = std::move(span);
  data_ = std::move(node);
}
683 | |
684 | Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b) const { |
685 | ICHECK_EQ(a.size(), b.size()); |
686 | ICHECK_EQ(lhs.size(), a.size()); |
687 | ICHECK_EQ(rhs.size(), b.size()); |
688 | Map<Var, PrimExpr> value_map; |
689 | for (size_t i = 0; i < a.size(); ++i) { |
690 | value_map.Set(lhs[i], a[i]); |
691 | value_map.Set(rhs[i], b[i]); |
692 | } |
693 | auto ret = this->result.Map([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); |
694 | return ret; |
695 | } |
696 | |
697 | TVM_REGISTER_GLOBAL("tir.CommReducer" ) |
698 | .set_body_typed([](Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result, |
699 | Array<PrimExpr> identity_element, Span span) { |
700 | return CommReducer(lhs, rhs, result, identity_element, span); |
701 | }); |
702 | |
703 | TVM_REGISTER_GLOBAL("tir.CommReducerCombine" ) |
704 | .set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator()); |
705 | |
706 | TVM_REGISTER_NODE_TYPE(CommReducerNode); |
707 | |
708 | // Reduce |
709 | Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis, |
710 | PrimExpr condition, int value_index, Array<PrimExpr> init, Span span) { |
711 | for (size_t i = 0; i < axis.size(); ++i) { |
712 | ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis" ; |
713 | } |
714 | if (!condition.defined()) { |
715 | condition = const_true(); |
716 | } |
717 | auto n = make_object<ReduceNode>(); |
718 | ICHECK(source.defined()); |
719 | for (size_t i = 0; i < axis.size(); ++i) { |
720 | ICHECK(axis[i].defined()); |
721 | } |
722 | if (!init.empty()) { |
723 | ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs" ; |
724 | for (size_t i = 0; i < init.size(); i++) { |
725 | ICHECK(init[i]->IsInstance<ProducerLoadNode>() || init[i]->IsInstance<IntImmNode>() || |
726 | init[i]->IsInstance<FloatImmNode>()) |
727 | << "init can only be a IntImm, FloatImm or ProducerLoad" ; |
728 | } |
729 | } |
730 | n->dtype = source[value_index].dtype(); |
731 | n->combiner = std::move(combiner); |
732 | n->source = std::move(source); |
733 | n->init = std::move(init); |
734 | n->axis = std::move(axis); |
735 | n->condition = condition; |
736 | n->value_index = value_index; |
737 | n->span = std::move(span); |
738 | data_ = std::move(n); |
739 | } |
740 | |
741 | TVM_REGISTER_GLOBAL("tir.Reduce" ) |
742 | .set_body_typed([](CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis, |
743 | PrimExpr condition, int value_index, Array<PrimExpr> init, Span span) { |
744 | return Reduce(combiner, source, axis, condition, value_index, init, span); |
745 | }); |
746 | |
747 | TVM_REGISTER_NODE_TYPE(ReduceNode); |
748 | |
749 | // Any |
750 | Any::Any(Span span) { |
751 | auto n = make_object<AnyNode>(); |
752 | n->dtype = DataType::Int(32); |
753 | n->span = std::move(span); |
754 | data_ = std::move(n); |
755 | } |
756 | |
757 | TVM_REGISTER_GLOBAL("tir.Any" ).set_body_typed([](Span span) { return Any(span); }); |
758 | |
759 | TVM_REGISTER_NODE_TYPE(AnyNode); |
760 | |
761 | // BufferLoad |
762 | void BufferLoadNode::LegalizeDType() { |
763 | for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) { |
764 | ICHECK(indices[i].dtype().is_scalar()) |
765 | << "Only the last index of a buffer access may be a vector type." ; |
766 | } |
767 | |
768 | int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; |
769 | int buffer_lanes = buffer->dtype.lanes(); |
770 | |
771 | this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); |
772 | } |
773 | |
774 | BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) { |
775 | ICHECK_EQ(buffer->shape.size(), indices.size()) |
776 | << "Buffer " << buffer->name << " is " << buffer->shape.size() |
777 | << "-dimensional, cannot be indexed with the " << indices.size() |
778 | << "-dimensional indices provided." ; |
779 | |
780 | ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>(); |
781 | node->buffer = std::move(buffer); |
782 | node->indices = std::move(indices); |
783 | node->span = std::move(span); |
784 | node->LegalizeDType(); |
785 | data_ = std::move(node); |
786 | } |
787 | |
788 | TVM_REGISTER_GLOBAL("tir.BufferLoad" ) |
789 | .set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Span span) { |
790 | return BufferLoad(buffer, indices, span); |
791 | }); |
792 | |
793 | TVM_REGISTER_NODE_TYPE(BufferLoadNode); |
794 | |
795 | // ProducerLoad |
796 | ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) { |
797 | ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>(); |
798 | node->dtype = producer->GetDataType(); |
799 | node->producer = std::move(producer); |
800 | node->indices = std::move(indices); |
801 | node->span = std::move(span); |
802 | data_ = std::move(node); |
803 | } |
804 | |
805 | TVM_REGISTER_GLOBAL("tir.ProducerLoad" ) |
806 | .set_body_typed([](DataProducer producer, Array<PrimExpr> indices, Span span) { |
807 | return ProducerLoad(producer, indices, span); |
808 | }); |
809 | |
810 | TVM_REGISTER_NODE_TYPE(ProducerLoadNode); |
811 | |
812 | } // namespace tir |
813 | } // namespace tvm |
814 | |