1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/op/op.cc
22 *
23 * Common operator definitions for ops in tir/op.h
24 */
25
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/op_attr_types.h>
31
32#include <cmath>
33// Centralized header for constant folders.
34#include "../../arith/const_fold.h"
35#include "../../target/datatype/registry.h"
36
37namespace tvm {
38
39using namespace tir;
40
// Macro to register a unary op: one input, call effect kind set to kPure.
#define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) \
  TVM_TIR_REGISTER_OP(OpName)                  \
      .set_num_inputs(1)                       \
      .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
45
// Macro to register a binary op: two inputs, call effect kind set to kPure.
#define TVM_TIR_REGISTER_PURE_BINARY_OP(OpName) \
  TVM_TIR_REGISTER_OP(OpName)                   \
      .set_num_inputs(2)                        \
      .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
50
51runtime::DataType GetRuntimeDataType(const Type& type) {
52 if (auto* n = type.as<PrimTypeNode>()) {
53 return n->dtype;
54 } else if (type.as<PointerTypeNode>()) {
55 return DataType::Handle();
56 } else if (IsVoidType(type)) {
57 return DataType::Void();
58 } else {
59 LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType";
60 }
61}
62
63Type GetType(const PrimExpr& expr) {
64 // TODO(tqchen): add recursive type inference for Call here
65 // once we introduced the corresponding fields to the IR.
66 if (auto* ptr = expr.as<tir::VarNode>()) {
67 // If Var has a more refined type annotation,
68 // return the type anotation
69 if (ptr->type_annotation.defined()) {
70 return ptr->type_annotation;
71 }
72 }
73 // Default: return the type indicated by the dtype.
74 runtime::DataType dtype = expr.dtype();
75 return GetTypeFromRuntimeDataType(dtype);
76}
77
78Type GetTypeFromRuntimeDataType(const DataType& dtype) {
79 if (dtype.is_void()) {
80 return VoidType();
81 }
82 return PrimType(dtype);
83}
84
85// LargeUIntImm
86PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
87 return tir::Call(
88 t, tir::builtin::large_uint_imm(),
89 {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)},
90 span);
91}
92
93// Q-multiplication
94PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span) {
95 return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(),
96 {x, y, q, s}, span);
97}
98
99// The public function with a quick checking path.
100void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
101 CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
102 CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator";
103 if (lhs.dtype() == rhs.dtype()) return;
104 DataType ltype = lhs.dtype();
105 DataType rtype = rhs.dtype();
106 if (ltype.lanes() == 1 && rtype.lanes() != 1) {
107 lhs = tir::Broadcast(lhs, rtype.lanes());
108 } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
109 rhs = tir::Broadcast(rhs, ltype.lanes());
110 } else {
111 ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
112 }
113 if (lhs.dtype() == rhs.dtype()) return;
114
115 ltype = lhs.dtype();
116 rtype = rhs.dtype();
117 // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by
118 // operators. This can be helpful for users to find potential type conversion problems. The
119 // following are exceptions:
120 if (ltype.is_float() && rtype.is_float()) {
121 // Given two dissimilar floats, cast the lower bit version to the higher bit version.
122 // E.g. fp16 + fp32 --> fp32 + fp32
123 if (ltype.bits() < rtype.bits()) {
124 lhs = cast(rtype, lhs);
125 } else {
126 rhs = cast(ltype, rhs);
127 }
128 } else if (!ltype.is_float() &&
129 (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
130 // Cast int->float when the other operand is a float
131 lhs = cast(rtype, lhs);
132 } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
133 !rtype.is_float()) {
134 // Cast int->float when the other operand is a float
135 rhs = cast(ltype, rhs);
136 } else if (!ltype.is_bfloat16() &&
137 (rtype.is_bfloat16() ||
138 datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
139 // Cast int->bfloat16 when the other operand is a bfloat16
140 lhs = cast(rtype, lhs);
141 } else if ((ltype.is_bfloat16() ||
142 datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
143 !rtype.is_bfloat16()) {
144 // Cast int->bfloat16 when the other operand is a bfloat16
145 rhs = cast(ltype, rhs);
146 } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
147 // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
148 if (ltype.bits() < rtype.bits()) {
149 lhs = cast(rtype, lhs);
150 } else {
151 rhs = cast(ltype, rhs);
152 }
153 } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
154 // Handle mixing signed and unsigned integers
155 if (ltype.bits() < rtype.bits()) {
156 lhs = cast(rtype, lhs);
157 } else if (ltype.bits() > rtype.bits()) {
158 rhs = cast(ltype, rhs);
159 } else {
160 // The width of signed and unsigned integers is same.
161 if (ltype.is_uint()) {
162 rhs = cast(ltype, rhs);
163 } else {
164 lhs = cast(rtype, lhs);
165 }
166 }
167 } else {
168 LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
169 }
170}
171
172PrimExpr ret(PrimExpr value, Span span) {
173 return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span);
174}
175
176// maximum and min limits
177PrimExpr max_value(const DataType& dtype, Span span) {
178 using namespace tir;
179 ICHECK_EQ(dtype.lanes(), 1);
180 if (dtype.is_int()) {
181 if (dtype.bits() == 64) {
182 return IntImm(dtype, std::numeric_limits<int64_t>::max(), span);
183 } else if (dtype.bits() < 64) {
184 int64_t val = 1;
185 val = (val << (dtype.bits() - 1)) - 1;
186 return IntImm(dtype, val, span);
187 }
188 } else if (dtype.is_uint()) {
189 if (dtype.bits() == 64) {
190 return make_const(dtype, std::numeric_limits<uint64_t>::max(), span);
191 } else if (dtype.bits() < 64) {
192 uint64_t val = 1;
193 val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
194 return IntImm(dtype, static_cast<int64_t>(val), span);
195 }
196 } else if (dtype.is_float()) {
197 if (dtype.bits() == 64) {
198 return FloatImm(dtype, std::numeric_limits<double>::max(), span);
199 } else if (dtype.bits() == 32) {
200 return FloatImm(dtype, std::numeric_limits<float>::max(), span);
201 } else if (dtype.bits() == 16) {
202 return FloatImm(dtype, 65504.0, span);
203 }
204 } else if (dtype.is_bfloat16()) {
205 return FloatImm(dtype, std::numeric_limits<float>::max(), span);
206 }
207 LOG(FATAL) << "Cannot decide max_value for type" << dtype;
208}
209
210PrimExpr min_value(const DataType& dtype, Span span) {
211 using namespace tir;
212 ICHECK_EQ(dtype.lanes(), 1);
213 if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
214 // TODO(tkonolige): need to convert all registered min functions to use the span.
215 auto f = datatype::GetMinFunc(dtype.code());
216 ICHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code();
217 // TODO(@hypercubestart) Document this change (and others associated with the overflowing
218 // floatimm min bug)
219 return (*f)(dtype.bits());
220 } else if (dtype.is_int()) {
221 if (dtype.bits() == 64) {
222 return IntImm(dtype, std::numeric_limits<int64_t>::lowest(), span);
223 } else if (dtype.bits() < 64) {
224 int64_t val = 1;
225 val = -(val << (dtype.bits() - 1));
226 return IntImm(dtype, val, span);
227 }
228 } else if (dtype.is_uint()) {
229 return IntImm(dtype, 0, span);
230 } else if (dtype.is_float()) {
231 if (dtype.bits() == 64) {
232 return FloatImm(dtype, std::numeric_limits<double>::lowest(), span);
233 } else if (dtype.bits() == 32) {
234 return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
235 } else if (dtype.bits() == 16) {
236 return FloatImm(dtype, -65504.0, span);
237 }
238 } else if (dtype.is_bfloat16()) {
239 return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
240 }
241 LOG(FATAL) << "Cannot decide min_value for type" << dtype;
242}
243
244// infinity
245PrimExpr infinity(const DataType& dtype, Span span) {
246 using namespace tir;
247 ICHECK_EQ(dtype.lanes(), 1);
248 if (dtype.is_float()) {
249 if (dtype.bits() == 64) {
250 return FloatImm(dtype, std::numeric_limits<double>::infinity(), span);
251 } else if (dtype.bits() == 32 || dtype.bits() == 16) {
252 return FloatImm(dtype, std::numeric_limits<float>::infinity(), span);
253 }
254 }
255 LOG(FATAL) << "Cannot decide infinity for type " << dtype;
256}
257
258namespace tir {
// Test whether `val` is a positive power of two.
//
// - val <= 0: returns false and leaves *shift untouched.
// - val > 0:  *shift receives the number of trailing zero bits of `val`;
//             the result is true exactly when the remaining value is 1,
//             in which case *shift == log2(val).
template <typename ValueType>
inline bool ConstPowerHelper(ValueType val, int* shift) {
  if (val <= 0) return false;
  int trailing_zeros = 0;
  // A positive value always has a set bit, so this loop terminates.
  for (; (val & 1) == 0; val = val >> 1) {
    ++trailing_zeros;
  }
  *shift = trailing_zeros;
  return val == 1;
}
272
273bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
274 if (const auto* op = x.as<tir::IntImmNode>()) {
275 return ConstPowerHelper(op->value, shift);
276 } else {
277 return false;
278 }
279}
280} // namespace tir
281
282PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
283 using tir::FloatImmNode;
284 if (value.dtype() == t) return value;
285 // const fold IntImm as they are used in index computations
286 if (t.lanes() == 1) {
287 if (const IntImmNode* op = value.as<IntImmNode>()) {
288 return make_const(t, op->value, op->span);
289 } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
290 return make_const(t, op->value, op->span);
291 }
292 ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types.";
293 return tir::Cast(t, value, span);
294 } else {
295 DataType vtype = t.element_of();
296 if (value.dtype().lanes() == 1) {
297 // manually unroll cast
298 if (value.dtype() != vtype) {
299 if (const IntImmNode* op = value.as<IntImmNode>()) {
300 value = make_const(vtype, op->value, op->span);
301 } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
302 value = make_const(vtype, op->value, op->span);
303 } else {
304 value = tir::Cast(vtype, value, span);
305 }
306 }
307 return tir::Broadcast(value, t.lanes(), span);
308 } else {
309 ICHECK(value.dtype().lanes() == t.lanes());
310 if (const auto* broadcast = value.as<tir::BroadcastNode>()) {
311 return tir::Broadcast(cast(vtype, broadcast->value, span), t.lanes(), span);
312 } else if (const auto* ramp = value.as<tir::RampNode>()) {
313 if (t.is_int() || t.is_uint()) {
314 // only cast to index data type can be folded to ramp
315 return tir::Ramp(cast(vtype, ramp->base, span), cast(vtype, ramp->stride, span),
316 ramp->lanes, span);
317 }
318 }
319 return tir::Cast(t, value, span);
320 }
321 }
322}
323
324// reinterpret
325PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
326 if (value.dtype() == t) return value;
327 return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
328}
329
330// operator+
331PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, b); }
332
333PrimExpr add(PrimExpr a, PrimExpr b, Span span) {
334 BinaryOpMatchTypes(a, b, span);
335 if (auto ret = arith::TryConstFold<tir::Add>(a, b)) return ret.value();
336 return tir::Add(a, b, span);
337}
338
339// negation
340PrimExpr operator-(PrimExpr a) { return neg(a); }
341
342PrimExpr neg(PrimExpr a, Span span) {
343 using tir::FloatImmNode;
344 using tir::IntImmNode;
345 const IntImmNode* pa = a.as<IntImmNode>();
346 const FloatImmNode* fa = a.as<FloatImmNode>();
347 if (pa) return IntImm(a.dtype(), -pa->value, span);
348 if (fa) return FloatImm(a.dtype(), -fa->value, span);
349 return make_zero(a.dtype(), span) - a;
350}
351
352PrimExpr operator-(PrimExpr a, PrimExpr b) { return sub(a, b); }
353
354PrimExpr sub(PrimExpr a, PrimExpr b, Span span) {
355 BinaryOpMatchTypes(a, b, span);
356 if (auto ret = arith::TryConstFold<tir::Sub>(a, b)) return ret.value();
357 return tir::Sub(a, b, span);
358}
359
360PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); }
361PrimExpr mul(PrimExpr a, PrimExpr b, Span span) {
362 BinaryOpMatchTypes(a, b, span);
363 if (auto ret = arith::TryConstFold<tir::Mul>(a, b)) return ret.value();
364 return tir::Mul(a, b, span);
365}
366
367PrimExpr div(PrimExpr a, PrimExpr b, Span span) {
368 BinaryOpMatchTypes(a, b, span);
369 if (auto ret = arith::TryConstFold<tir::Div>(a, b)) return ret.value();
370 return tir::Div(a, b, span);
371}
372
373PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) {
374 ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
375 ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
376 return div(a, b, span);
377}
378
379PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) {
380 BinaryOpMatchTypes(a, b, span);
381 if (auto ret = arith::TryConstFold<tir::Mod>(a, b)) return ret.value();
382 return tir::Mod(a, b, span);
383}
384
385PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); }
386
387PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); }
388
389// TODO(tqchen): switch to floordiv
390PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, span); }
391
392PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span); }
393
394PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); }
395
396PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
397 ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
398 ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
399 BinaryOpMatchTypes(a, b, span);
400 if (auto ret = arith::TryConstFold<tir::FloorDiv>(a, b)) return ret.value();
401 return tir::FloorDiv(a, b, span);
402}
403
404PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
405 ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
406 ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
407 BinaryOpMatchTypes(a, b, span);
408 if (auto ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b)) return ret.value();
409 return tir::FloorDiv(a + b - 1, b, span);
410}
411
412PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
413 ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
414 ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
415 BinaryOpMatchTypes(a, b, span);
416 if (auto ret = arith::TryConstFold<tir::FloorMod>(a, b)) return ret.value();
417 return tir::FloorMod(a, b, span);
418}
419
420PrimExpr min(PrimExpr a, PrimExpr b, Span span) {
421 // inf-aware simplificaiton
422 using arith::is_neg_inf;
423 using arith::is_pos_inf;
424 if (is_pos_inf(a)) return b;
425 if (is_neg_inf(a)) return a;
426 if (is_pos_inf(b)) return a;
427 if (is_neg_inf(b)) return b;
428 BinaryOpMatchTypes(a, b, span);
429 if (auto ret = arith::TryConstFold<tir::Min>(a, b)) return ret.value();
430 return tir::Min(a, b, span);
431}
432
433PrimExpr max(PrimExpr a, PrimExpr b, Span span) {
434 // inf-aware simplificaiton
435 using arith::is_neg_inf;
436 using arith::is_pos_inf;
437 if (is_pos_inf(a)) return a;
438 if (is_neg_inf(a)) return b;
439 if (is_pos_inf(b)) return b;
440 if (is_neg_inf(b)) return a;
441 BinaryOpMatchTypes(a, b, span);
442 if (auto ret = arith::TryConstFold<tir::Max>(a, b)) return ret.value();
443 return tir::Max(a, b, span);
444}
445
446// if_then_else
447PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
448 ICHECK(cond.dtype() == DataType::Bool(1))
449 << "if_then_else only accept the condition to be boolean type.";
450 BinaryOpMatchTypes(true_value, false_value, span);
451 if (const IntImmNode* op = cond.as<IntImmNode>()) {
452 if (op->value != 0) {
453 return true_value;
454 } else {
455 return false_value;
456 }
457 }
458
459 return tir::Call(true_value.dtype(), tir::builtin::if_then_else(),
460 {cond, true_value, false_value}, span);
461}
462
463// likely
464PrimExpr likely(PrimExpr cond, Span span) {
465 if (is_const_int(cond)) return cond;
466 return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span);
467}
468
469// operator>
470PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); }
471PrimExpr greater(PrimExpr a, PrimExpr b, Span span) {
472 BinaryOpMatchTypes(a, b, span);
473 if (auto ret = arith::TryConstFold<tir::GT>(a, b)) return ret.value();
474 return tir::GT(a, b, span);
475}
476
477PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); }
478PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) {
479 BinaryOpMatchTypes(a, b, span);
480 if (auto ret = arith::TryConstFold<tir::GE>(a, b)) return ret.value();
481 return tir::GE(a, b, span);
482}
483
484PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); }
485PrimExpr less(PrimExpr a, PrimExpr b, Span span) {
486 BinaryOpMatchTypes(a, b, span);
487 if (auto ret = arith::TryConstFold<tir::LT>(a, b)) return ret.value();
488 return tir::LT(a, b, span);
489}
490
491PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); }
492PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) {
493 BinaryOpMatchTypes(a, b, span);
494 if (auto ret = arith::TryConstFold<tir::LE>(a, b)) return ret.value();
495 return tir::LE(a, b, span);
496}
497
498PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); }
499PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
500 BinaryOpMatchTypes(a, b, span);
501 if (auto ret = arith::TryConstFold<tir::EQ>(a, b)) return ret.value();
502 return tir::EQ(a, b, span);
503}
504
505PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); }
506PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
507 BinaryOpMatchTypes(a, b, span);
508 if (auto ret = arith::TryConstFold<tir::NE>(a, b)) return ret.value();
509 return tir::NE(a, b, span);
510}
511
512namespace {
513void type_check_boolean_args(const PrimExpr& arg, const char* op) {
514 ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received "
515 << arg << " of type " << arg.dtype();
516}
517void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
518 ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received "
519 << lhs << " of type " << lhs.dtype();
520 ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received "
521 << rhs << " of type " << rhs.dtype();
522}
523
524void type_check_integer_args(const PrimExpr& arg, const char* op) {
525 ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
526 << "Expected integer argument for " << op << ", but received " << arg << " of type "
527 << arg.dtype();
528}
529
530void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
531 ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint())
532 << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type "
533 << lhs.dtype();
534 ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint())
535 << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
536 << rhs.dtype();
537}
538} // namespace
539
540PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
541PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
542 type_check_boolean_args(a, b, "&& operator (logical AND)");
543 if (auto ret = arith::TryConstFold<tir::And>(a, b)) return ret.value();
544 return tir::And(a, b, span);
545}
546
547PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
548PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
549 type_check_boolean_args(a, b, "|| operator (logical OR)");
550 if (auto ret = arith::TryConstFold<tir::Or>(a, b)) return ret.value();
551 return tir::Or(a, b, span);
552}
553
554PrimExpr operator!(PrimExpr a) { return logical_not(a); }
555PrimExpr logical_not(PrimExpr a, Span span) {
556 type_check_boolean_args(a, "! operator (logical NOT)");
557 if (auto ret = arith::TryConstFold<tir::Not>(a)) return ret.value();
558 return tir::Not(a, span);
559}
560
561// shift right
562PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); }
563
564PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
565 type_check_integer_args(a, b, ">> operator (right shift)");
566
567 BinaryOpMatchTypes(a, b, span);
568 TVM_INDEX_CONST_PROPAGATION({
569 const DataType& rtype = a.dtype();
570 if (pb)
571 ICHECK(pb->value >= 0 && pb->value < rtype.bits())
572 << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
573 << rtype;
574 if (pa && pb) {
575 return IntImm(rtype, (pa->value >> pb->value), span);
576 }
577 if (pb) {
578 if (pb->value == 0) return a;
579 }
580 });
581
582 return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, span);
583}
584
585// shift left
586PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); }
587PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
588 type_check_integer_args(a, b, "<< operator (left shift)");
589 BinaryOpMatchTypes(a, b, span);
590 TVM_INDEX_CONST_PROPAGATION({
591 const DataType& rtype = a.dtype();
592 if (pb)
593 ICHECK(pb->value >= 0 && pb->value < rtype.bits())
594 << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
595 << rtype;
596 if (pa && pb) return IntImm(rtype, (pa->value << pb->value), span);
597 if (pb) {
598 if (pb->value == 0) return a;
599 }
600 });
601 return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, span);
602}
603
604// bitwise and
605PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
606PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
607 type_check_integer_args(a, b, "& operator (bitwise AND)");
608 BinaryOpMatchTypes(a, b, span);
609 TVM_INDEX_CONST_PROPAGATION({
610 const DataType& rtype = a.dtype();
611 if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span);
612 });
613 return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, span);
614}
615
616// bitwise_or
617PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
618PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
619 type_check_integer_args(a, b, "| operator (bitwise OR)");
620 BinaryOpMatchTypes(a, b, span);
621 TVM_INDEX_CONST_PROPAGATION({
622 const DataType& rtype = a.dtype();
623 if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span);
624 });
625 return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, span);
626}
627
628// bitwise_xor
629PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
630PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
631 type_check_integer_args(a, b, "^ operator (bitwise XOR)");
632 BinaryOpMatchTypes(a, b, span);
633 TVM_INDEX_CONST_PROPAGATION({
634 const DataType& rtype = a.dtype();
635 if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span);
636 });
637 return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, span);
638}
639
640// bitwise_not
641PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
642
643PrimExpr bitwise_neg(PrimExpr a, Span span) {
644 type_check_integer_args(a, "~ operator (bitwise NOT)");
645 return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
646}
647
648TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) {
649 return bitwise_neg(a, span);
650});
651
652// pow
653PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
654 BinaryOpMatchTypes(x, y, span);
655 ICHECK(x.dtype().is_float()) << "power only applies to float";
656 static auto op = Op::Get("tir.pow");
657 return tir::Call(x.dtype(), op, {x, y}, span);
658}
659
660TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr<TVectorizable>("TVectorizable", true);
661
662// abs
663PrimExpr abs(PrimExpr x, Span span) {
664 if (x.dtype().is_int()) {
665 using tir::IntImmNode;
666 const IntImmNode* px = x.as<IntImmNode>();
667 if (px) {
668 return IntImm(x.dtype(), std::abs(px->value), px->span);
669 }
670 return tir::Select(x >= make_zero(x.dtype()), x, -x, span);
671 } else if (x.dtype().is_float()) {
672 using tir::FloatImmNode;
673 const FloatImmNode* fx = x.as<FloatImmNode>();
674 if (fx) {
675 return FloatImm(x.dtype(), std::fabs(fx->value), fx->span);
676 }
677 static auto op = Op::Get("tir.fabs");
678 return tir::Call(x.dtype(), op, {x}, span);
679 } else if (x.dtype().is_uint()) {
680 return x;
681 } else {
682 LOG(FATAL) << "Data type " << x.dtype()
683 << " not supported for absolute op. Skipping absolute op...";
684 return x;
685 }
686}
687
688TVM_TIR_REGISTER_PURE_UNARY_OP("fabs").set_attr<TVectorizable>("TVectorizable", true);
689
690// isnan
691PrimExpr isnan(PrimExpr x, Span span) {
692 DataType t = DataType::Bool(x.dtype().lanes());
693 if (x.dtype().is_int() || x.dtype().is_uint()) {
694 return make_const(t, false);
695 } else if (x.dtype().is_float()) {
696 using tir::FloatImmNode;
697 const FloatImmNode* fx = x.as<FloatImmNode>();
698 if (fx) {
699 return make_const(t, std::isnan(fx->value), fx->span);
700 }
701 static auto op = Op::Get("tir.isnan");
702 if (x.dtype().bits() == 16) {
703 return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span);
704 } else {
705 return tir::Call(t, op, {x}, span);
706 }
707 } else {
708 LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
709 }
710}
711
712// isinf
713PrimExpr isinf(PrimExpr x, Span span) {
714 DataType t = DataType::Bool(x.dtype().lanes());
715 if (x.dtype().is_int() || x.dtype().is_uint()) {
716 return make_const(t, false, span);
717 } else if (x.dtype().is_float()) {
718 PrimExpr infX = infinity(x.dtype(), span);
719 return abs(x, span) == infX && !isnan(x, span);
720 } else {
721 LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it...";
722 }
723}
724
725// isfinite
726PrimExpr isfinite(PrimExpr x, Span span) { return !isinf(x, span) && !isnan(x, span); }
727
728PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
729 Var x("x", source.dtype(), span), y("y", source.dtype(), span);
730 PrimExpr result = tir::Add(x, y, span);
731 PrimExpr identity_element = make_zero(source.dtype(), span);
732 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
733 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
734}
735
736PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
737 type_check_boolean_args(source, "tvm::all");
738 Var x("x", source.dtype(), span), y("y", source.dtype());
739 PrimExpr result = tir::And(x, y, span);
740 PrimExpr identity_element = make_const(source.dtype(), true, span);
741 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
742 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
743}
744
745PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
746 type_check_boolean_args(source, "tvm::any");
747 Var x("x", source.dtype(), span), y("y", source.dtype(), span);
748 PrimExpr result = tir::Or(x, y, span);
749 PrimExpr identity_element = make_const(source.dtype(), false, span);
750 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
751 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
752}
753
754PrimExpr max(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
755 Var x("x", source.dtype(), span), y("y", source.dtype(), span);
756 PrimExpr result = tir::Max(x, y, span);
757 PrimExpr identity_element = min_value(source.dtype(), span);
758 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
759 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
760}
761
762PrimExpr min(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
763 Var x("x", source.dtype(), span), y("y", source.dtype(), span);
764 PrimExpr result = tir::Min(x, y, span);
765 PrimExpr identity_element = max_value(source.dtype(), span);
766 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
767 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
768}
769
770PrimExpr prod(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
771 Var x("x", source.dtype(), span), y("y", source.dtype(), span);
772 PrimExpr result = tir::Mul(x, y, span);
773 PrimExpr identity_element = make_const(source.dtype(), 1, span);
774 tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span);
775 return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span);
776}
777
778// fmod
779PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) {
780 BinaryOpMatchTypes(x, y, span);
781 ICHECK(x.dtype().is_float()) << "fmod only applies to float";
782 static auto op = Op::Get("tir.fmod");
783 return tir::Call(x.dtype(), op, {x, y}, span);
784}
785
786TVM_TIR_REGISTER_PURE_UNARY_OP("fmod");
787
788// floor
789PrimExpr floor(PrimExpr x, Span span) {
790 if (x.dtype().is_int() || x.dtype().is_uint()) {
791 return x;
792 }
793 using tir::FloatImmNode;
794 const FloatImmNode* fx = x.as<FloatImmNode>();
795 if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span);
796 static auto op = Op::Get("tir.floor");
797 return tir::Call(x.dtype(), op, {x}, span);
798}
799
800TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable", true);
801
802// ceil
803PrimExpr ceil(PrimExpr x, Span span) {
804 if (x.dtype().is_int() || x.dtype().is_uint()) {
805 return x;
806 }
807 using tir::FloatImmNode;
808 const FloatImmNode* fx = x.as<FloatImmNode>();
809 if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span);
810 static auto op = Op::Get("tir.ceil");
811 return tir::Call(x.dtype(), op, {x}, span);
812}
813
814TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable", true);
815
816// round
817PrimExpr round(PrimExpr x, Span span) {
818 if (x.dtype().is_int() || x.dtype().is_uint()) {
819 return x;
820 }
821 using tir::FloatImmNode;
822 const FloatImmNode* fx = x.as<FloatImmNode>();
823 if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
824 static auto op = Op::Get("tir.round");
825 return tir::Call(x.dtype(), op, {x}, span);
826}
827
828TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable", true);
829
830// nearbyint
831PrimExpr nearbyint(PrimExpr x, Span span) {
832 if (x.dtype().is_int() || x.dtype().is_uint()) {
833 return x;
834 }
835 using tir::FloatImmNode;
836 const FloatImmNode* fx = x.as<FloatImmNode>();
837 if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
838 static auto op = Op::Get("tir.nearbyint");
839 return tir::Call(x.dtype(), op, {x}, span);
840}
841
842TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
843
844// trunc
845PrimExpr trunc(PrimExpr x, Span span) {
846 if (x.dtype().is_int() || x.dtype().is_uint()) {
847 return x;
848 }
849 using tir::FloatImmNode;
850 const FloatImmNode* fx = x.as<FloatImmNode>();
851 if (fx) {
852 return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)),
853 fx->span);
854 }
855 static auto op = Op::Get("tir.trunc");
856 return tir::Call(x.dtype(), op, {x}, span);
857}
858
859TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr<TVectorizable>("TVectorizable", true);
860
861// unary op registration.
862TVM_TIR_REGISTER_PURE_UNARY_OP("exp").set_attr<TVectorizable>("TVectorizable", true);
863
864TVM_TIR_REGISTER_PURE_UNARY_OP("exp2").set_attr<TVectorizable>("TVectorizable", true);
865
866TVM_TIR_REGISTER_PURE_UNARY_OP("exp10").set_attr<TVectorizable>("TVectorizable", true);
867
868TVM_TIR_REGISTER_PURE_UNARY_OP("erf");
869
870TVM_TIR_REGISTER_PURE_UNARY_OP("tanh").set_attr<TVectorizable>("TVectorizable", true);
871
872TVM_TIR_REGISTER_PURE_UNARY_OP("sigmoid").set_attr<TVectorizable>("TVectorizable", true);
873
874TVM_TIR_REGISTER_PURE_UNARY_OP("sqrt").set_attr<TVectorizable>("TVectorizable", true);
875
876TVM_TIR_REGISTER_PURE_UNARY_OP("rsqrt");
877
878TVM_TIR_REGISTER_PURE_UNARY_OP("log").set_attr<TVectorizable>("TVectorizable", true);
879
880TVM_TIR_REGISTER_PURE_UNARY_OP("log2").set_attr<TVectorizable>("TVectorizable", true);
881
882TVM_TIR_REGISTER_PURE_UNARY_OP("log1p");
883
884TVM_TIR_REGISTER_PURE_UNARY_OP("log10").set_attr<TVectorizable>("TVectorizable", true);
885
886TVM_TIR_REGISTER_PURE_UNARY_OP("tan").set_attr<TVectorizable>("TVectorizable", true);
887
888TVM_TIR_REGISTER_PURE_UNARY_OP("cos").set_attr<TVectorizable>("TVectorizable", true);
889
890TVM_TIR_REGISTER_PURE_UNARY_OP("cosh").set_attr<TVectorizable>("TVectorizable", true);
891
892TVM_TIR_REGISTER_PURE_UNARY_OP("sin").set_attr<TVectorizable>("TVectorizable", true);
893
894TVM_TIR_REGISTER_PURE_UNARY_OP("sinh").set_attr<TVectorizable>("TVectorizable", true);
895
896TVM_TIR_REGISTER_PURE_UNARY_OP("asin");
897
898TVM_TIR_REGISTER_PURE_UNARY_OP("acos");
899
900TVM_TIR_REGISTER_PURE_UNARY_OP("atan");
901
902TVM_TIR_REGISTER_PURE_UNARY_OP("acosh");
903
904TVM_TIR_REGISTER_PURE_UNARY_OP("asinh");
905
906TVM_TIR_REGISTER_PURE_UNARY_OP("atanh");
907
908TVM_TIR_REGISTER_PURE_UNARY_OP("clz");
909
910// binary intrinsics
911TVM_TIR_REGISTER_PURE_BINARY_OP("atan2");
912
913TVM_TIR_REGISTER_PURE_BINARY_OP("nextafter");
914
915TVM_TIR_REGISTER_PURE_BINARY_OP("hypot");
916
917TVM_TIR_REGISTER_PURE_BINARY_OP("copysign");
918
919TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp");
920
921TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace")
922 .set_num_inputs(5)
923 .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace")
924 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
925
926TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace")
927 .set_num_inputs(3)
928 .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace")
929 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
930
931// expose basic functions to node namespace
932TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
933 if (args[0].type_code() == kDLInt) {
934 *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]);
935 } else if (args[0].type_code() == kDLFloat) {
936 *ret = tir::make_const(args[1], args[0].operator double(), args[2]);
937 } else {
938 LOG(FATAL) << "only accept int or float"; // FIXME
939 }
940});
941
942TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm);
943
944TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);
945
946TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
947
948TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity);
949
950TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
951
952TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
953
954TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
955
956TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
957
958TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf);
959
960TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor);
961
962TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil);
963
964TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round);
965
966TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint);
967
968TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
969
970TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
971
// operator overloading, smarter than make
// Registers the global "tir.<Node>" as a typed function taking (lhs, rhs, span)
// and returning Func(lhs, rhs, span).
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
  TVM_REGISTER_GLOBAL("tir." #Node) \
      .set_body_typed([](PrimExpr lhs, PrimExpr rhs, Span span) { \
        return (Func(lhs, rhs, span)); \
      })
977
// Registers the global "tir.<Node>" as a packed function over (lhs, rhs, span).
// Dispatch is on the runtime type code of the operands:
//   - lhs is a plain int  -> Func(int, PrimExpr, span)
//   - else rhs is an int  -> Func(PrimExpr, int, span)
//   - otherwise           -> Func(PrimExpr, PrimExpr, span)
// The lhs check takes priority when both operands are plain ints.
#define REGISTER_MAKE_BIT_OP(Node, Func) \
  TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \
    if (args[0].type_code() == kDLInt) { \
      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr(), args[2])); \
      return; \
    } \
    if (args[1].type_code() == kDLInt) { \
      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int(), args[2])); \
      return; \
    } \
    *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr(), args[2])); \
  })
990
991REGISTER_MAKE_BINARY_OP(_OpAdd, add);
992REGISTER_MAKE_BINARY_OP(_OpSub, sub);
993REGISTER_MAKE_BINARY_OP(_OpMul, mul);
994REGISTER_MAKE_BINARY_OP(_OpDiv, div);
995REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
996REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
997REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
998REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
999REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
1000REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
1001REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
1002REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv);
1003REGISTER_MAKE_BINARY_OP(_OpPow, pow);
1004REGISTER_MAKE_BINARY_OP(_OpMin, min);
1005REGISTER_MAKE_BINARY_OP(_OpMax, max);
1006REGISTER_MAKE_BINARY_OP(_OpEQ, equal);
1007REGISTER_MAKE_BINARY_OP(_OpNE, not_equal);
1008REGISTER_MAKE_BINARY_OP(_OpLT, less); // NOLINT(*)
1009REGISTER_MAKE_BINARY_OP(_OpLE, less_equal); // NOLINT(*)
1010REGISTER_MAKE_BINARY_OP(_OpGT, greater); // NOLINT(*)
1011REGISTER_MAKE_BINARY_OP(_OpGE, greater_equal);
1012REGISTER_MAKE_BINARY_OP(_OpAnd, logical_and);
1013REGISTER_MAKE_BINARY_OP(_OpOr, logical_or);
1014REGISTER_MAKE_BIT_OP(bitwise_and, bitwise_and);
1015REGISTER_MAKE_BIT_OP(bitwise_or, bitwise_or);
1016REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
1017REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
1018REGISTER_MAKE_BIT_OP(right_shift, right_shift);
1019
1020TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
1021 .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
1022 return if_then_else(cond, true_value, false_value, span);
1023 });
1024
1025TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) {
1026 return const_true(t.lanes(), span);
1027});
1028
1029} // namespace tvm
1030