1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file expr.cc
22 */
23#include <tvm/runtime/registry.h>
24#include <tvm/tir/expr.h>
25#include <tvm/tir/op.h>
26#include <tvm/tir/stmt_functor.h>
27
28#include "../../support/str_escape.h"
29#include "buffer_common.h"
30
31namespace tvm {
32namespace tir {
33
// Generates `Name::Name(PrimExpr a, PrimExpr b, Span span)` for a binary
// operator node. Both operands must be defined and must have the same dtype;
// the node's dtype is copied from operand `a`.
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
  Name::Name(PrimExpr a, PrimExpr b, Span span) { \
    using NodeType = Name::ContainerType; \
    ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
    ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
    CHECK(a.dtype() == b.dtype()) \
        << "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \
    auto op_node = make_object<NodeType>(); \
    op_node->dtype = a.dtype(); \
    op_node->a = std::move(a); \
    op_node->b = std::move(b); \
    op_node->span = std::move(span); \
    data_ = std::move(op_node); \
  }
48
// Generates `Name::Name(PrimExpr a, PrimExpr b, Span span)` for a comparison
// operator node. Both operands must be defined and must have the same dtype;
// the node's dtype is DataType::Bool with the lane count of operand `a`.
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
  Name::Name(PrimExpr a, PrimExpr b, Span span) { \
    using NodeType = Name::ContainerType; \
    ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
    ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
    CHECK(a.dtype() == b.dtype()) \
        << "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \
    auto cmp_node = make_object<NodeType>(); \
    cmp_node->dtype = DataType::Bool(a.dtype().lanes()); \
    cmp_node->a = std::move(a); \
    cmp_node->b = std::move(b); \
    cmp_node->span = std::move(span); \
    data_ = std::move(cmp_node); \
  }
63
64// Var
65Var::Var(String name_hint, DataType dtype, Span span) {
66 auto n = make_object<VarNode>();
67 n->name_hint = std::move(name_hint);
68 n->type_annotation = GetTypeFromRuntimeDataType(dtype);
69 n->dtype = std::move(dtype);
70 n->span = std::move(span);
71 data_ = std::move(n);
72}
73
74Var::Var(String name_hint, Type type_annotation, Span span) {
75 auto n = make_object<VarNode>();
76 n->name_hint = std::move(name_hint);
77 n->dtype = GetRuntimeDataType(type_annotation);
78 n->type_annotation = std::move(type_annotation);
79 n->span = std::move(span);
80 data_ = std::move(n);
81}
82
83Var Var::copy_with_suffix(const String& suffix) const {
84 const VarNode* node = get();
85 ObjectPtr<VarNode> new_ptr;
86 if (auto* ptr = this->as<SizeVarNode>()) {
87 new_ptr = make_object<SizeVarNode>(*ptr);
88 } else {
89 new_ptr = make_object<VarNode>(*node);
90 }
91 new_ptr->name_hint = new_ptr->name_hint + suffix;
92 return Var(new_ptr);
93}
94
95Var Var::copy_with_dtype(DataType dtype) const {
96 const VarNode* node = get();
97 ObjectPtr<VarNode> new_ptr;
98 if (auto* ptr = this->as<SizeVarNode>()) {
99 new_ptr = make_object<SizeVarNode>(*ptr);
100 } else {
101 new_ptr = make_object<VarNode>(*node);
102 }
103 new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype);
104 new_ptr->dtype = std::move(dtype);
105 return Var(new_ptr);
106}
107
108TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type,
109 Span span) {
110 if (type.IsObjectRef<Type>()) {
111 return Var(name_hint, type.operator Type(), span);
112 } else {
113 return Var(name_hint, type.operator DataType(), span);
114 }
115});
116
117TVM_REGISTER_NODE_TYPE(VarNode);
118
119// SizeVar
120SizeVar::SizeVar(String name_hint, DataType dtype, Span span) {
121 auto n = make_object<SizeVarNode>();
122 n->name_hint = std::move(name_hint);
123 n->type_annotation = GetTypeFromRuntimeDataType(dtype);
124 n->dtype = std::move(dtype);
125 n->span = std::move(span);
126 data_ = std::move(n);
127}
128
129TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) {
130 return SizeVar(s, t, span);
131});
132
133TVM_REGISTER_NODE_TYPE(SizeVarNode);
134
135// IterVar
136IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) {
137 ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
138 if (dom.defined() && dom->extent.defined()) {
139 CHECK(dom->extent.dtype().is_int())
140 << "The dtype of the domain of an IterVar must be an integer type. However, the domain's "
141 "dtype is "
142 << dom->extent.dtype();
143 CHECK_EQ(dom->extent.dtype(), var.dtype())
144 << "The dtype of the extent of an IterVar (" << dom->extent.dtype()
145 << ") must match its associated Var's dtype (" << var.dtype() << ")";
146 }
147 n->dom = dom;
148 n->var = var;
149 n->iter_type = t;
150 n->thread_tag = thread_tag;
151 n->span = std::move(span);
152 data_ = std::move(n);
153}
154
155TVM_REGISTER_GLOBAL("tir.IterVar")
156 .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) {
157 return IterVar(dom, var, static_cast<IterVarType>(iter_type), thread_tag, span);
158 });
159
160TVM_REGISTER_NODE_TYPE(IterVarNode);
161
162// StringImm
163StringImm::StringImm(String value, Span span) {
164 ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
165 node->dtype = DataType::Handle();
166 node->value = std::move(value);
167 node->span = std::move(span);
168 data_ = std::move(node);
169}
170
171TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) {
172 return StringImm(value, span);
173});
174
175TVM_REGISTER_NODE_TYPE(StringImmNode);
176
177// Cast
178Cast::Cast(DataType t, PrimExpr value, Span span) {
179 ICHECK(value.defined());
180 ICHECK_EQ(t.lanes(), value.dtype().lanes());
181 ObjectPtr<CastNode> node = make_object<CastNode>();
182 node->dtype = t;
183 node->value = std::move(value);
184 node->span = std::move(span);
185 data_ = std::move(node);
186}
187
188TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) {
189 return Cast(dtype, value, span);
190});
191
192TVM_REGISTER_NODE_TYPE(CastNode);
193
194// Add
195TVM_DEFINE_BINOP_CONSTRUCTOR(Add);
196
197TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
198 return Add(a, b, span);
199});
200
201TVM_REGISTER_NODE_TYPE(AddNode);
202
203// Sub
204TVM_DEFINE_BINOP_CONSTRUCTOR(Sub);
205
206TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
207 return Sub(a, b, span);
208});
209
210TVM_REGISTER_NODE_TYPE(SubNode);
211
212// Mul
213TVM_DEFINE_BINOP_CONSTRUCTOR(Mul);
214
215TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
216 return Mul(a, b, span);
217});
218
219TVM_REGISTER_NODE_TYPE(MulNode);
220
221// Div
222TVM_DEFINE_BINOP_CONSTRUCTOR(Div);
223
224TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
225 return Div(a, b, span);
226});
227
228TVM_REGISTER_NODE_TYPE(DivNode);
229
230// Mod
231TVM_DEFINE_BINOP_CONSTRUCTOR(Mod);
232
233TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
234 return Mod(a, b, span);
235});
236
237TVM_REGISTER_NODE_TYPE(ModNode);
238
239// FloorDiv
240TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv);
241
242TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
243 return FloorDiv(a, b, span);
244});
245
246TVM_REGISTER_NODE_TYPE(FloorDivNode);
247
248// FloorMod
249TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod);
250
251TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
252 return FloorMod(a, b, span);
253});
254
255TVM_REGISTER_NODE_TYPE(FloorModNode);
256
257// Min
258TVM_DEFINE_BINOP_CONSTRUCTOR(Min);
259
260TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
261 return Min(a, b, span);
262});
263
264TVM_REGISTER_NODE_TYPE(MinNode);
265
266// Max
267TVM_DEFINE_BINOP_CONSTRUCTOR(Max);
268
269TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
270 return Max(a, b, span);
271});
272
273TVM_REGISTER_NODE_TYPE(MaxNode);
274
275// EQ
276TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ);
277
278TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
279 return EQ(a, b, span);
280});
281
282TVM_REGISTER_NODE_TYPE(EQNode);
283
284// NE
285TVM_DEFINE_CMPOP_CONSTRUCTOR(NE);
286
287TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
288 return NE(a, b, span);
289});
290
291TVM_REGISTER_NODE_TYPE(NENode);
292
293// LT
294TVM_DEFINE_CMPOP_CONSTRUCTOR(LT);
295
296TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
297 return LT(a, b, span);
298});
299
300TVM_REGISTER_NODE_TYPE(LTNode);
301
302// LE
303TVM_DEFINE_CMPOP_CONSTRUCTOR(LE);
304
305TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
306 return LE(a, b, span);
307});
308
309TVM_REGISTER_NODE_TYPE(LENode);
310
311// GT
312TVM_DEFINE_CMPOP_CONSTRUCTOR(GT);
313
314TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
315 return GT(a, b, span);
316});
317
318TVM_REGISTER_NODE_TYPE(GTNode);
319
320// GE
321TVM_DEFINE_CMPOP_CONSTRUCTOR(GE);
322
323TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
324 return GE(a, b, span);
325});
326
327TVM_REGISTER_NODE_TYPE(GENode);
328
329// And
330And::And(PrimExpr a, PrimExpr b, Span span) {
331 ICHECK(a.defined()) << "ValueError: a is undefined";
332 ICHECK(b.defined()) << "ValueError: b is undefined";
333 ICHECK(a.dtype().is_bool());
334 ICHECK(b.dtype().is_bool());
335 ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
336
337 ObjectPtr<AndNode> node = make_object<AndNode>();
338 node->dtype = DataType::Bool(a.dtype().lanes());
339 node->a = std::move(a);
340 node->b = std::move(b);
341 node->span = std::move(span);
342 data_ = std::move(node);
343}
344
345TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
346 return And(a, b, span);
347});
348
349TVM_REGISTER_NODE_TYPE(AndNode);
350
351// Or
352Or::Or(PrimExpr a, PrimExpr b, Span span) {
353 ICHECK(a.defined()) << "ValueError: a is undefined";
354 ICHECK(b.defined()) << "ValueError: b is undefined";
355 ICHECK(a.dtype().is_bool());
356 ICHECK(b.dtype().is_bool());
357 ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
358
359 ObjectPtr<OrNode> node = make_object<OrNode>();
360 node->dtype = DataType::Bool(a.dtype().lanes());
361 node->a = std::move(a);
362 node->b = std::move(b);
363 node->span = std::move(span);
364 data_ = std::move(node);
365}
366
367TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
368 return Or(a, b, span);
369});
370
371TVM_REGISTER_NODE_TYPE(OrNode);
372
373// Not
374Not::Not(PrimExpr a, Span span) {
375 ICHECK(a.defined()) << "ValueError: a is undefined";
376 ICHECK(a.dtype().is_bool());
377
378 ObjectPtr<NotNode> node = make_object<NotNode>();
379 node->dtype = DataType::Bool(a.dtype().lanes());
380 node->a = std::move(a);
381 node->span = std::move(span);
382 data_ = std::move(node);
383}
384
385TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { return Not(a, span); });
386
387TVM_REGISTER_NODE_TYPE(NotNode);
388
389// Select
390Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) {
391 ICHECK(condition.defined()) << "ValueError: condition is undefined";
392 ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
393 ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
394 ICHECK(condition.dtype().is_bool());
395 ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
396 ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
397
398 ObjectPtr<SelectNode> node = make_object<SelectNode>();
399 node->dtype = true_value.dtype();
400 node->condition = std::move(condition);
401 node->true_value = std::move(true_value);
402 node->false_value = std::move(false_value);
403 node->span = std::move(span);
404 data_ = std::move(node);
405}
406
407TVM_REGISTER_GLOBAL("tir.Select")
408 .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) {
409 return Select(condition, true_value, false_value, span);
410 });
411
412TVM_REGISTER_NODE_TYPE(SelectNode);
413
414// Load
415Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) {
416 LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint
417 << ". Use BufferStore instead.";
418 ICHECK(buffer_var.defined());
419 ICHECK(predicate.defined());
420 ICHECK(index.defined());
421
422 // Assume that the array elements have 1 lane, unless a type
423 // annotation tells us otherwise.
424 int element_lanes = 1;
425 auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
426 if (pointer_type.has_value()) {
427 // Cannot check element type of array, as it may be different than
428 // the loaded type in some cases.
429 //
430 // 1. Booleans use DataType::Int(8) while stored, and the codegens
431 // handle cast to boolean.
432 //
433 // 2. The StorageRewrite pass can merge multiple allocations at
434 // the same scope, regardless of element type. The codegen is
435 // then responsible for casting to the output type.
436
437 // TODO(Lunderberg): Uncomment this check once it can be applied.
438 // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
439 // for discussion.
440
441 // ICHECK(dtype.element_of() == pointer_type->element_of())
442 // << "Type mismatch, cannot load type " << dtype << " from buffer " <<
443 // buffer_var->name_hint
444 // << " of type " << pointer_type.value();
445 element_lanes = pointer_type->lanes();
446 }
447
448 // The C-based codegens assume that all loads occur on a array with
449 // non-vectorized elements, and cast between
450 // vectorized/non-vectorized arrays as needed. Ideally, these
451 // should be changed to explicit casts in the TIR graph, rather than
452 // being handled at the code-gen level.
453 ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) ||
454 (dtype.lanes() == index.dtype().lanes()));
455 ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) ||
456 (dtype.lanes() == index.dtype().lanes()));
457
458 ObjectPtr<LoadNode> node = make_object<LoadNode>();
459 node->dtype = dtype;
460 node->buffer_var = std::move(buffer_var);
461 node->index = std::move(index);
462 node->predicate = std::move(predicate);
463 node->span = std::move(span);
464
465 data_ = std::move(node);
466}
467
468TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) {
469 DataType t = args[0];
470 if (args.size() == 3) {
471 *ret = Load(t, args[1], args[2], const_true(t.lanes()), Span());
472 } else if (args.size() == 4) {
473 *ret = Load(t, args[1], args[2], args[3], Span());
474 } else {
475 *ret = Load(t, args[1], args[2], args[3], args[4]);
476 }
477});
478
479TVM_REGISTER_NODE_TYPE(LoadNode);
480
481// Ramp
482Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
483 ICHECK(base.defined());
484 ICHECK(stride.defined());
485 ICHECK(base.dtype().is_scalar());
486 ICHECK(stride.dtype().is_scalar());
487 ICHECK_GT(lanes, 1);
488 ICHECK_EQ(stride.dtype(), base.dtype());
489
490 ObjectPtr<RampNode> node = make_object<RampNode>();
491 node->dtype = base.dtype().with_lanes(lanes);
492 node->base = base;
493 node->stride = stride;
494 node->lanes = lanes;
495 node->span = std::move(span);
496 data_ = std::move(node);
497}
498
499TVM_REGISTER_GLOBAL("tir.Ramp")
500 .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) {
501 return Ramp(base, stride, lanes, span);
502 });
503
504TVM_REGISTER_NODE_TYPE(RampNode);
505
506// Broadcast
507Broadcast::Broadcast(PrimExpr value, int lanes, Span span) {
508 ICHECK(value.defined());
509 ICHECK(value.dtype().is_scalar());
510 ICHECK_GT(lanes, 1);
511
512 ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>();
513 node->dtype = value.dtype().with_lanes(lanes);
514 node->value = std::move(value);
515 node->lanes = lanes;
516 node->span = std::move(span);
517 data_ = node;
518}
519
520TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes, Span span) {
521 return Broadcast(value, lanes, span);
522});
523
524TVM_REGISTER_NODE_TYPE(BroadcastNode);
525
526// Let
527Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) {
528 ICHECK(value.defined());
529 ICHECK(body.defined());
530 ICHECK_EQ(value.dtype(), var.dtype());
531
532 ObjectPtr<LetNode> node = make_object<LetNode>();
533 node->dtype = body.dtype();
534 node->var = std::move(var);
535 node->value = std::move(value);
536 node->body = std::move(body);
537 node->span = std::move(span);
538 data_ = std::move(node);
539}
540
541TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body,
542 Span span) {
543 return Let(var, value, body, span);
544});
545
546TVM_REGISTER_NODE_TYPE(LetNode);
547
548// Call
549Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) {
550 for (size_t i = 0; i < args.size(); ++i) {
551 ICHECK(args[i].defined()) << "arg " << i << " is not defined()";
552 }
553
554 ObjectPtr<CallNode> node = make_object<CallNode>();
555 node->dtype = dtype;
556 node->op = std::move(op);
557 node->args = std::move(args);
558 node->span = std::move(span);
559 data_ = std::move(node);
560}
561
562TVM_REGISTER_GLOBAL("tir.Call")
563 .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
564 Array<PrimExpr> prim_expr_args;
565 for (const auto& it : args) {
566 ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
567 it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
568 << "Argument " << it << " is not a string or primexpr";
569 if (const auto* str = it.as<runtime::StringObj>()) {
570 prim_expr_args.push_back(StringImm(str->data));
571 } else if (const auto* iter_var = it.as<IterVarNode>()) {
572 prim_expr_args.push_back(GetRef<IterVar>(iter_var)->var);
573 } else if (const auto* br = it.as<BufferRegionNode>()) {
574 Array<PrimExpr> indices;
575 for (Range r : br->region) {
576 if (is_one(r->extent)) {
577 indices.push_back(r->min);
578 } else if (const auto* extent = r->extent.as<IntImmNode>()) {
579 indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), extent->value));
580 } else {
581 LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: "
582 << GetRef<BufferRegion>(br);
583 }
584 }
585 prim_expr_args.push_back(BufferLoad(br->buffer, indices));
586 } else {
587 prim_expr_args.push_back(Downcast<PrimExpr>(it));
588 }
589 }
590 return Call(type, op, prim_expr_args, span);
591 });
592
593TVM_REGISTER_NODE_TYPE(CallNode);
594
595// Shuffle
596Shuffle::Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span) {
597 ICHECK_NE(vectors.size(), 0U);
598 ICHECK_NE(indices.size(), 0U);
599
600 DataType base_type = vectors[0].dtype().element_of();
601 int total_lanes = 0;
602
603 for (PrimExpr val : vectors) {
604 ICHECK(val.dtype().element_of() == base_type);
605 total_lanes += val.dtype().lanes();
606 }
607 ICHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
608
609 ObjectPtr<ShuffleNode> node = make_object<ShuffleNode>();
610 node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
611 node->vectors = std::move(vectors);
612 node->indices = std::move(indices);
613 node->span = std::move(span);
614 data_ = node;
615}
616
617PrimExpr Shuffle::Concat(Array<PrimExpr> vectors, Span span) {
618 ICHECK_NE(vectors.size(), 0);
619 if (vectors.size() == 1) {
620 return vectors[0];
621 }
622 Array<PrimExpr> indices;
623 int index = 0;
624 for (const PrimExpr& e : vectors) {
625 for (int i = 0; i < e.dtype().lanes(); ++i) {
626 indices.push_back(IntImm(DataType::Int(32), index++));
627 }
628 }
629 return Shuffle(vectors, indices, span);
630}
631
632PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) {
633 return Shuffle({vector}, {Integer(index)}, span);
634}
635
636TVM_REGISTER_GLOBAL("tir.Shuffle")
637 .set_body_typed([](Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span) {
638 return Shuffle(vectors, indices, span);
639 });
640
641TVM_REGISTER_NODE_TYPE(ShuffleNode);
642
643// CommReducer
644CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
645 Array<PrimExpr> identity_element, Span span) {
646 size_t n_group = result.size();
647 CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the "
648 "number of elements in `results`";
649 CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the "
650 "number of elements in `results`";
651 CHECK_EQ(identity_element.size(), n_group)
652 << "ValueError: The number of identities must equal to the number of elements in `results`";
653
654 // Change the dtype of input vars to adapt to the dtype of identities
655 ArrayNode* p_lhs = lhs.CopyOnWrite();
656 ArrayNode* p_rhs = rhs.CopyOnWrite();
657 std::unordered_map<const VarNode*, PrimExpr> var_map;
658 var_map.reserve(n_group * 2);
659 for (int i = 0; i < static_cast<int>(n_group); ++i) {
660 DataType dtype = identity_element[i].dtype();
661 Var l = lhs[i].copy_with_dtype(dtype);
662 Var r = rhs[i].copy_with_dtype(dtype);
663 var_map[lhs[i].get()] = l;
664 var_map[rhs[i].get()] = r;
665
666 p_lhs->SetItem(i, l);
667 p_rhs->SetItem(i, r);
668 }
669
670 ArrayNode* p_result = result.CopyOnWrite();
671 for (int i = 0; i < static_cast<int>(n_group); ++i) {
672 p_result->SetItem(i, Substitute(result[i], var_map));
673 }
674
675 auto node = make_object<CommReducerNode>();
676 node->lhs = lhs;
677 node->rhs = rhs;
678 node->result = result;
679 node->identity_element = identity_element;
680 node->span = std::move(span);
681 data_ = std::move(node);
682}
683
684Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b) const {
685 ICHECK_EQ(a.size(), b.size());
686 ICHECK_EQ(lhs.size(), a.size());
687 ICHECK_EQ(rhs.size(), b.size());
688 Map<Var, PrimExpr> value_map;
689 for (size_t i = 0; i < a.size(); ++i) {
690 value_map.Set(lhs[i], a[i]);
691 value_map.Set(rhs[i], b[i]);
692 }
693 auto ret = this->result.Map([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
694 return ret;
695}
696
697TVM_REGISTER_GLOBAL("tir.CommReducer")
698 .set_body_typed([](Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
699 Array<PrimExpr> identity_element, Span span) {
700 return CommReducer(lhs, rhs, result, identity_element, span);
701 });
702
703TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
704 .set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
705
706TVM_REGISTER_NODE_TYPE(CommReducerNode);
707
708// Reduce
709Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
710 PrimExpr condition, int value_index, Array<PrimExpr> init, Span span) {
711 for (size_t i = 0; i < axis.size(); ++i) {
712 ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis";
713 }
714 if (!condition.defined()) {
715 condition = const_true();
716 }
717 auto n = make_object<ReduceNode>();
718 ICHECK(source.defined());
719 for (size_t i = 0; i < axis.size(); ++i) {
720 ICHECK(axis[i].defined());
721 }
722 if (!init.empty()) {
723 ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs";
724 for (size_t i = 0; i < init.size(); i++) {
725 ICHECK(init[i]->IsInstance<ProducerLoadNode>() || init[i]->IsInstance<IntImmNode>() ||
726 init[i]->IsInstance<FloatImmNode>())
727 << "init can only be a IntImm, FloatImm or ProducerLoad";
728 }
729 }
730 n->dtype = source[value_index].dtype();
731 n->combiner = std::move(combiner);
732 n->source = std::move(source);
733 n->init = std::move(init);
734 n->axis = std::move(axis);
735 n->condition = condition;
736 n->value_index = value_index;
737 n->span = std::move(span);
738 data_ = std::move(n);
739}
740
741TVM_REGISTER_GLOBAL("tir.Reduce")
742 .set_body_typed([](CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
743 PrimExpr condition, int value_index, Array<PrimExpr> init, Span span) {
744 return Reduce(combiner, source, axis, condition, value_index, init, span);
745 });
746
747TVM_REGISTER_NODE_TYPE(ReduceNode);
748
749// Any
750Any::Any(Span span) {
751 auto n = make_object<AnyNode>();
752 n->dtype = DataType::Int(32);
753 n->span = std::move(span);
754 data_ = std::move(n);
755}
756
757TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return Any(span); });
758
759TVM_REGISTER_NODE_TYPE(AnyNode);
760
761// BufferLoad
762void BufferLoadNode::LegalizeDType() {
763 for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
764 ICHECK(indices[i].dtype().is_scalar())
765 << "Only the last index of a buffer access may be a vector type.";
766 }
767
768 int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
769 int buffer_lanes = buffer->dtype.lanes();
770
771 this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes);
772}
773
774BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
775 ICHECK_EQ(buffer->shape.size(), indices.size())
776 << "Buffer " << buffer->name << " is " << buffer->shape.size()
777 << "-dimensional, cannot be indexed with the " << indices.size()
778 << "-dimensional indices provided.";
779
780 ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
781 node->buffer = std::move(buffer);
782 node->indices = std::move(indices);
783 node->span = std::move(span);
784 node->LegalizeDType();
785 data_ = std::move(node);
786}
787
788TVM_REGISTER_GLOBAL("tir.BufferLoad")
789 .set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Span span) {
790 return BufferLoad(buffer, indices, span);
791 });
792
793TVM_REGISTER_NODE_TYPE(BufferLoadNode);
794
795// ProducerLoad
796ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) {
797 ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
798 node->dtype = producer->GetDataType();
799 node->producer = std::move(producer);
800 node->indices = std::move(indices);
801 node->span = std::move(span);
802 data_ = std::move(node);
803}
804
805TVM_REGISTER_GLOBAL("tir.ProducerLoad")
806 .set_body_typed([](DataProducer producer, Array<PrimExpr> indices, Span span) {
807 return ProducerLoad(producer, indices, span);
808 });
809
810TVM_REGISTER_NODE_TYPE(ProducerLoadNode);
811
812} // namespace tir
813} // namespace tvm
814