1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/op/op.cc |
22 | * |
23 | * Common operator definitions for ops in tir/op.h |
24 | */ |
25 | |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/op_attr_types.h> |
31 | |
32 | #include <cmath> |
33 | // Centralized header for constant folders. |
34 | #include "../../arith/const_fold.h" |
35 | #include "../../target/datatype/registry.h" |
36 | |
37 | namespace tvm { |
38 | |
39 | using namespace tir; |
40 | |
// Shared helper: register a TIR op that is marked pure (CallEffectKind::kPure) and
// declares NumInputs operands.
#define TVM_TIR_REGISTER_PURE_OP_WITH_ARITY_(OpName, NumInputs) \
  TVM_TIR_REGISTER_OP(OpName)                                  \
      .set_num_inputs(NumInputs)                               \
      .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))

// Macro to register a pure unary op (one input).
#define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_TIR_REGISTER_PURE_OP_WITH_ARITY_(OpName, 1)

// Macro to register a pure binary op (two inputs).
#define TVM_TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_TIR_REGISTER_PURE_OP_WITH_ARITY_(OpName, 2)
50 | |
51 | runtime::DataType GetRuntimeDataType(const Type& type) { |
52 | if (auto* n = type.as<PrimTypeNode>()) { |
53 | return n->dtype; |
54 | } else if (type.as<PointerTypeNode>()) { |
55 | return DataType::Handle(); |
56 | } else if (IsVoidType(type)) { |
57 | return DataType::Void(); |
58 | } else { |
59 | LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType" ; |
60 | } |
61 | } |
62 | |
63 | Type GetType(const PrimExpr& expr) { |
64 | // TODO(tqchen): add recursive type inference for Call here |
65 | // once we introduced the corresponding fields to the IR. |
66 | if (auto* ptr = expr.as<tir::VarNode>()) { |
67 | // If Var has a more refined type annotation, |
68 | // return the type anotation |
69 | if (ptr->type_annotation.defined()) { |
70 | return ptr->type_annotation; |
71 | } |
72 | } |
73 | // Default: return the type indicated by the dtype. |
74 | runtime::DataType dtype = expr.dtype(); |
75 | return GetTypeFromRuntimeDataType(dtype); |
76 | } |
77 | |
78 | Type GetTypeFromRuntimeDataType(const DataType& dtype) { |
79 | if (dtype.is_void()) { |
80 | return VoidType(); |
81 | } |
82 | return PrimType(dtype); |
83 | } |
84 | |
85 | // LargeUIntImm |
86 | PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { |
87 | return tir::Call( |
88 | t, tir::builtin::large_uint_imm(), |
89 | {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, |
90 | span); |
91 | } |
92 | |
93 | // Q-multiplication |
94 | PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span) { |
95 | return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(), |
96 | {x, y, q, s}, span); |
97 | } |
98 | |
99 | // The public function with a quick checking path. |
100 | void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) |
101 | CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator" ; |
102 | CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator" ; |
103 | if (lhs.dtype() == rhs.dtype()) return; |
104 | DataType ltype = lhs.dtype(); |
105 | DataType rtype = rhs.dtype(); |
106 | if (ltype.lanes() == 1 && rtype.lanes() != 1) { |
107 | lhs = tir::Broadcast(lhs, rtype.lanes()); |
108 | } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { |
109 | rhs = tir::Broadcast(rhs, ltype.lanes()); |
110 | } else { |
111 | ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; |
112 | } |
113 | if (lhs.dtype() == rhs.dtype()) return; |
114 | |
115 | ltype = lhs.dtype(); |
116 | rtype = rhs.dtype(); |
117 | // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by |
118 | // operators. This can be helpful for users to find potential type conversion problems. The |
119 | // following are exceptions: |
120 | if (ltype.is_float() && rtype.is_float()) { |
121 | // Given two dissimilar floats, cast the lower bit version to the higher bit version. |
122 | // E.g. fp16 + fp32 --> fp32 + fp32 |
123 | if (ltype.bits() < rtype.bits()) { |
124 | lhs = cast(rtype, lhs); |
125 | } else { |
126 | rhs = cast(ltype, rhs); |
127 | } |
128 | } else if (!ltype.is_float() && |
129 | (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { |
130 | // Cast int->float when the other operand is a float |
131 | lhs = cast(rtype, lhs); |
132 | } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && |
133 | !rtype.is_float()) { |
134 | // Cast int->float when the other operand is a float |
135 | rhs = cast(ltype, rhs); |
136 | } else if (!ltype.is_bfloat16() && |
137 | (rtype.is_bfloat16() || |
138 | datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { |
139 | // Cast int->bfloat16 when the other operand is a bfloat16 |
140 | lhs = cast(rtype, lhs); |
141 | } else if ((ltype.is_bfloat16() || |
142 | datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && |
143 | !rtype.is_bfloat16()) { |
144 | // Cast int->bfloat16 when the other operand is a bfloat16 |
145 | rhs = cast(ltype, rhs); |
146 | } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { |
147 | // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 |
148 | if (ltype.bits() < rtype.bits()) { |
149 | lhs = cast(rtype, lhs); |
150 | } else { |
151 | rhs = cast(ltype, rhs); |
152 | } |
153 | } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) { |
154 | // Handle mixing signed and unsigned integers |
155 | if (ltype.bits() < rtype.bits()) { |
156 | lhs = cast(rtype, lhs); |
157 | } else if (ltype.bits() > rtype.bits()) { |
158 | rhs = cast(ltype, rhs); |
159 | } else { |
160 | // The width of signed and unsigned integers is same. |
161 | if (ltype.is_uint()) { |
162 | rhs = cast(ltype, rhs); |
163 | } else { |
164 | lhs = cast(rtype, lhs); |
165 | } |
166 | } |
167 | } else { |
168 | LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; |
169 | } |
170 | } |
171 | |
172 | PrimExpr ret(PrimExpr value, Span span) { |
173 | return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); |
174 | } |
175 | |
176 | // maximum and min limits |
177 | PrimExpr max_value(const DataType& dtype, Span span) { |
178 | using namespace tir; |
179 | ICHECK_EQ(dtype.lanes(), 1); |
180 | if (dtype.is_int()) { |
181 | if (dtype.bits() == 64) { |
182 | return IntImm(dtype, std::numeric_limits<int64_t>::max(), span); |
183 | } else if (dtype.bits() < 64) { |
184 | int64_t val = 1; |
185 | val = (val << (dtype.bits() - 1)) - 1; |
186 | return IntImm(dtype, val, span); |
187 | } |
188 | } else if (dtype.is_uint()) { |
189 | if (dtype.bits() == 64) { |
190 | return make_const(dtype, std::numeric_limits<uint64_t>::max(), span); |
191 | } else if (dtype.bits() < 64) { |
192 | uint64_t val = 1; |
193 | val = (val << static_cast<uint64_t>(dtype.bits())) - 1; |
194 | return IntImm(dtype, static_cast<int64_t>(val), span); |
195 | } |
196 | } else if (dtype.is_float()) { |
197 | if (dtype.bits() == 64) { |
198 | return FloatImm(dtype, std::numeric_limits<double>::max(), span); |
199 | } else if (dtype.bits() == 32) { |
200 | return FloatImm(dtype, std::numeric_limits<float>::max(), span); |
201 | } else if (dtype.bits() == 16) { |
202 | return FloatImm(dtype, 65504.0, span); |
203 | } |
204 | } else if (dtype.is_bfloat16()) { |
205 | return FloatImm(dtype, std::numeric_limits<float>::max(), span); |
206 | } |
207 | LOG(FATAL) << "Cannot decide max_value for type" << dtype; |
208 | } |
209 | |
210 | PrimExpr min_value(const DataType& dtype, Span span) { |
211 | using namespace tir; |
212 | ICHECK_EQ(dtype.lanes(), 1); |
213 | if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) { |
214 | // TODO(tkonolige): need to convert all registered min functions to use the span. |
215 | auto f = datatype::GetMinFunc(dtype.code()); |
216 | ICHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code(); |
217 | // TODO(@hypercubestart) Document this change (and others associated with the overflowing |
218 | // floatimm min bug) |
219 | return (*f)(dtype.bits()); |
220 | } else if (dtype.is_int()) { |
221 | if (dtype.bits() == 64) { |
222 | return IntImm(dtype, std::numeric_limits<int64_t>::lowest(), span); |
223 | } else if (dtype.bits() < 64) { |
224 | int64_t val = 1; |
225 | val = -(val << (dtype.bits() - 1)); |
226 | return IntImm(dtype, val, span); |
227 | } |
228 | } else if (dtype.is_uint()) { |
229 | return IntImm(dtype, 0, span); |
230 | } else if (dtype.is_float()) { |
231 | if (dtype.bits() == 64) { |
232 | return FloatImm(dtype, std::numeric_limits<double>::lowest(), span); |
233 | } else if (dtype.bits() == 32) { |
234 | return FloatImm(dtype, std::numeric_limits<float>::lowest(), span); |
235 | } else if (dtype.bits() == 16) { |
236 | return FloatImm(dtype, -65504.0, span); |
237 | } |
238 | } else if (dtype.is_bfloat16()) { |
239 | return FloatImm(dtype, std::numeric_limits<float>::lowest(), span); |
240 | } |
241 | LOG(FATAL) << "Cannot decide min_value for type" << dtype; |
242 | } |
243 | |
244 | // infinity |
245 | PrimExpr infinity(const DataType& dtype, Span span) { |
246 | using namespace tir; |
247 | ICHECK_EQ(dtype.lanes(), 1); |
248 | if (dtype.is_float()) { |
249 | if (dtype.bits() == 64) { |
250 | return FloatImm(dtype, std::numeric_limits<double>::infinity(), span); |
251 | } else if (dtype.bits() == 32 || dtype.bits() == 16) { |
252 | return FloatImm(dtype, std::numeric_limits<float>::infinity(), span); |
253 | } |
254 | } |
255 | LOG(FATAL) << "Cannot decide infinity for type " << dtype; |
256 | } |
257 | |
258 | namespace tir { |
259 | template <typename ValueType> |
260 | inline bool ConstPowerHelper(ValueType val, int* shift) { |
261 | if (val <= 0) return false; |
262 | shift[0] = 0; |
263 | while (val != 0) { |
264 | if (val & 1) { |
265 | return (val == 1); |
266 | } |
267 | ++shift[0]; |
268 | val = val >> 1; |
269 | } |
270 | return true; |
271 | } |
272 | |
273 | bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { |
274 | if (const auto* op = x.as<tir::IntImmNode>()) { |
275 | return ConstPowerHelper(op->value, shift); |
276 | } else { |
277 | return false; |
278 | } |
279 | } |
280 | } // namespace tir |
281 | |
282 | PrimExpr cast(const DataType& t, PrimExpr value, Span span) { |
283 | using tir::FloatImmNode; |
284 | if (value.dtype() == t) return value; |
285 | // const fold IntImm as they are used in index computations |
286 | if (t.lanes() == 1) { |
287 | if (const IntImmNode* op = value.as<IntImmNode>()) { |
288 | return make_const(t, op->value, op->span); |
289 | } else if (const FloatImmNode* op = value.as<FloatImmNode>()) { |
290 | return make_const(t, op->value, op->span); |
291 | } |
292 | ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types." ; |
293 | return tir::Cast(t, value, span); |
294 | } else { |
295 | DataType vtype = t.element_of(); |
296 | if (value.dtype().lanes() == 1) { |
297 | // manually unroll cast |
298 | if (value.dtype() != vtype) { |
299 | if (const IntImmNode* op = value.as<IntImmNode>()) { |
300 | value = make_const(vtype, op->value, op->span); |
301 | } else if (const FloatImmNode* op = value.as<FloatImmNode>()) { |
302 | value = make_const(vtype, op->value, op->span); |
303 | } else { |
304 | value = tir::Cast(vtype, value, span); |
305 | } |
306 | } |
307 | return tir::Broadcast(value, t.lanes(), span); |
308 | } else { |
309 | ICHECK(value.dtype().lanes() == t.lanes()); |
310 | if (const auto* broadcast = value.as<tir::BroadcastNode>()) { |
311 | return tir::Broadcast(cast(vtype, broadcast->value, span), t.lanes(), span); |
312 | } else if (const auto* ramp = value.as<tir::RampNode>()) { |
313 | if (t.is_int() || t.is_uint()) { |
314 | // only cast to index data type can be folded to ramp |
315 | return tir::Ramp(cast(vtype, ramp->base, span), cast(vtype, ramp->stride, span), |
316 | ramp->lanes, span); |
317 | } |
318 | } |
319 | return tir::Cast(t, value, span); |
320 | } |
321 | } |
322 | } |
323 | |
324 | // reinterpret |
325 | PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { |
326 | if (value.dtype() == t) return value; |
327 | return tir::Call(t, tir::builtin::reinterpret(), {value}, span); |
328 | } |
329 | |
330 | // operator+ |
331 | PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, b); } |
332 | |
333 | PrimExpr add(PrimExpr a, PrimExpr b, Span span) { |
334 | BinaryOpMatchTypes(a, b, span); |
335 | if (auto ret = arith::TryConstFold<tir::Add>(a, b)) return ret.value(); |
336 | return tir::Add(a, b, span); |
337 | } |
338 | |
339 | // negation |
340 | PrimExpr operator-(PrimExpr a) { return neg(a); } |
341 | |
342 | PrimExpr neg(PrimExpr a, Span span) { |
343 | using tir::FloatImmNode; |
344 | using tir::IntImmNode; |
345 | const IntImmNode* pa = a.as<IntImmNode>(); |
346 | const FloatImmNode* fa = a.as<FloatImmNode>(); |
347 | if (pa) return IntImm(a.dtype(), -pa->value, span); |
348 | if (fa) return FloatImm(a.dtype(), -fa->value, span); |
349 | return make_zero(a.dtype(), span) - a; |
350 | } |
351 | |
352 | PrimExpr operator-(PrimExpr a, PrimExpr b) { return sub(a, b); } |
353 | |
354 | PrimExpr sub(PrimExpr a, PrimExpr b, Span span) { |
355 | BinaryOpMatchTypes(a, b, span); |
356 | if (auto ret = arith::TryConstFold<tir::Sub>(a, b)) return ret.value(); |
357 | return tir::Sub(a, b, span); |
358 | } |
359 | |
360 | PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); } |
361 | PrimExpr mul(PrimExpr a, PrimExpr b, Span span) { |
362 | BinaryOpMatchTypes(a, b, span); |
363 | if (auto ret = arith::TryConstFold<tir::Mul>(a, b)) return ret.value(); |
364 | return tir::Mul(a, b, span); |
365 | } |
366 | |
367 | PrimExpr div(PrimExpr a, PrimExpr b, Span span) { |
368 | BinaryOpMatchTypes(a, b, span); |
369 | if (auto ret = arith::TryConstFold<tir::Div>(a, b)) return ret.value(); |
370 | return tir::Div(a, b, span); |
371 | } |
372 | |
373 | PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) { |
374 | ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; |
375 | ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; |
376 | return div(a, b, span); |
377 | } |
378 | |
379 | PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) { |
380 | BinaryOpMatchTypes(a, b, span); |
381 | if (auto ret = arith::TryConstFold<tir::Mod>(a, b)) return ret.value(); |
382 | return tir::Mod(a, b, span); |
383 | } |
384 | |
385 | PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } |
386 | |
387 | PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } |
388 | |
389 | // TODO(tqchen): switch to floordiv |
390 | PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, span); } |
391 | |
392 | PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span); } |
393 | |
394 | PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); } |
395 | |
396 | PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { |
397 | ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; |
398 | ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; |
399 | BinaryOpMatchTypes(a, b, span); |
400 | if (auto ret = arith::TryConstFold<tir::FloorDiv>(a, b)) return ret.value(); |
401 | return tir::FloorDiv(a, b, span); |
402 | } |
403 | |
404 | PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { |
405 | ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; |
406 | ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; |
407 | BinaryOpMatchTypes(a, b, span); |
408 | if (auto ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b)) return ret.value(); |
409 | return tir::FloorDiv(a + b - 1, b, span); |
410 | } |
411 | |
412 | PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { |
413 | ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; |
414 | ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; |
415 | BinaryOpMatchTypes(a, b, span); |
416 | if (auto ret = arith::TryConstFold<tir::FloorMod>(a, b)) return ret.value(); |
417 | return tir::FloorMod(a, b, span); |
418 | } |
419 | |
420 | PrimExpr min(PrimExpr a, PrimExpr b, Span span) { |
421 | // inf-aware simplificaiton |
422 | using arith::is_neg_inf; |
423 | using arith::is_pos_inf; |
424 | if (is_pos_inf(a)) return b; |
425 | if (is_neg_inf(a)) return a; |
426 | if (is_pos_inf(b)) return a; |
427 | if (is_neg_inf(b)) return b; |
428 | BinaryOpMatchTypes(a, b, span); |
429 | if (auto ret = arith::TryConstFold<tir::Min>(a, b)) return ret.value(); |
430 | return tir::Min(a, b, span); |
431 | } |
432 | |
433 | PrimExpr max(PrimExpr a, PrimExpr b, Span span) { |
434 | // inf-aware simplificaiton |
435 | using arith::is_neg_inf; |
436 | using arith::is_pos_inf; |
437 | if (is_pos_inf(a)) return a; |
438 | if (is_neg_inf(a)) return b; |
439 | if (is_pos_inf(b)) return b; |
440 | if (is_neg_inf(b)) return a; |
441 | BinaryOpMatchTypes(a, b, span); |
442 | if (auto ret = arith::TryConstFold<tir::Max>(a, b)) return ret.value(); |
443 | return tir::Max(a, b, span); |
444 | } |
445 | |
446 | // if_then_else |
447 | PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { |
448 | ICHECK(cond.dtype() == DataType::Bool(1)) |
449 | << "if_then_else only accept the condition to be boolean type." ; |
450 | BinaryOpMatchTypes(true_value, false_value, span); |
451 | if (const IntImmNode* op = cond.as<IntImmNode>()) { |
452 | if (op->value != 0) { |
453 | return true_value; |
454 | } else { |
455 | return false_value; |
456 | } |
457 | } |
458 | |
459 | return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), |
460 | {cond, true_value, false_value}, span); |
461 | } |
462 | |
463 | // likely |
464 | PrimExpr likely(PrimExpr cond, Span span) { |
465 | if (is_const_int(cond)) return cond; |
466 | return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span); |
467 | } |
468 | |
469 | // operator> |
470 | PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); } |
471 | PrimExpr greater(PrimExpr a, PrimExpr b, Span span) { |
472 | BinaryOpMatchTypes(a, b, span); |
473 | if (auto ret = arith::TryConstFold<tir::GT>(a, b)) return ret.value(); |
474 | return tir::GT(a, b, span); |
475 | } |
476 | |
477 | PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); } |
478 | PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) { |
479 | BinaryOpMatchTypes(a, b, span); |
480 | if (auto ret = arith::TryConstFold<tir::GE>(a, b)) return ret.value(); |
481 | return tir::GE(a, b, span); |
482 | } |
483 | |
484 | PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); } |
485 | PrimExpr less(PrimExpr a, PrimExpr b, Span span) { |
486 | BinaryOpMatchTypes(a, b, span); |
487 | if (auto ret = arith::TryConstFold<tir::LT>(a, b)) return ret.value(); |
488 | return tir::LT(a, b, span); |
489 | } |
490 | |
491 | PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); } |
492 | PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) { |
493 | BinaryOpMatchTypes(a, b, span); |
494 | if (auto ret = arith::TryConstFold<tir::LE>(a, b)) return ret.value(); |
495 | return tir::LE(a, b, span); |
496 | } |
497 | |
498 | PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); } |
499 | PrimExpr equal(PrimExpr a, PrimExpr b, Span span) { |
500 | BinaryOpMatchTypes(a, b, span); |
501 | if (auto ret = arith::TryConstFold<tir::EQ>(a, b)) return ret.value(); |
502 | return tir::EQ(a, b, span); |
503 | } |
504 | |
505 | PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); } |
506 | PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) { |
507 | BinaryOpMatchTypes(a, b, span); |
508 | if (auto ret = arith::TryConstFold<tir::NE>(a, b)) return ret.value(); |
509 | return tir::NE(a, b, span); |
510 | } |
511 | |
512 | namespace { |
513 | void type_check_boolean_args(const PrimExpr& arg, const char* op) { |
514 | ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received " |
515 | << arg << " of type " << arg.dtype(); |
516 | } |
517 | void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { |
518 | ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received " |
519 | << lhs << " of type " << lhs.dtype(); |
520 | ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received " |
521 | << rhs << " of type " << rhs.dtype(); |
522 | } |
523 | |
524 | void type_check_integer_args(const PrimExpr& arg, const char* op) { |
525 | ICHECK(arg.dtype().is_int() || arg.dtype().is_uint()) |
526 | << "Expected integer argument for " << op << ", but received " << arg << " of type " |
527 | << arg.dtype(); |
528 | } |
529 | |
530 | void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { |
531 | ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint()) |
532 | << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type " |
533 | << lhs.dtype(); |
534 | ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint()) |
535 | << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " |
536 | << rhs.dtype(); |
537 | } |
538 | } // namespace |
539 | |
540 | PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); } |
541 | PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) { |
542 | type_check_boolean_args(a, b, "&& operator (logical AND)" ); |
543 | if (auto ret = arith::TryConstFold<tir::And>(a, b)) return ret.value(); |
544 | return tir::And(a, b, span); |
545 | } |
546 | |
547 | PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); } |
548 | PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) { |
549 | type_check_boolean_args(a, b, "|| operator (logical OR)" ); |
550 | if (auto ret = arith::TryConstFold<tir::Or>(a, b)) return ret.value(); |
551 | return tir::Or(a, b, span); |
552 | } |
553 | |
554 | PrimExpr operator!(PrimExpr a) { return logical_not(a); } |
555 | PrimExpr logical_not(PrimExpr a, Span span) { |
556 | type_check_boolean_args(a, "! operator (logical NOT)" ); |
557 | if (auto ret = arith::TryConstFold<tir::Not>(a)) return ret.value(); |
558 | return tir::Not(a, span); |
559 | } |
560 | |
561 | // shift right |
562 | PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); } |
563 | |
564 | PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { |
565 | type_check_integer_args(a, b, ">> operator (right shift)" ); |
566 | |
567 | BinaryOpMatchTypes(a, b, span); |
568 | TVM_INDEX_CONST_PROPAGATION({ |
569 | const DataType& rtype = a.dtype(); |
570 | if (pb) |
571 | ICHECK(pb->value >= 0 && pb->value < rtype.bits()) |
572 | << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " |
573 | << rtype; |
574 | if (pa && pb) { |
575 | return IntImm(rtype, (pa->value >> pb->value), span); |
576 | } |
577 | if (pb) { |
578 | if (pb->value == 0) return a; |
579 | } |
580 | }); |
581 | |
582 | return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, span); |
583 | } |
584 | |
585 | // shift left |
586 | PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); } |
587 | PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { |
588 | type_check_integer_args(a, b, "<< operator (left shift)" ); |
589 | BinaryOpMatchTypes(a, b, span); |
590 | TVM_INDEX_CONST_PROPAGATION({ |
591 | const DataType& rtype = a.dtype(); |
592 | if (pb) |
593 | ICHECK(pb->value >= 0 && pb->value < rtype.bits()) |
594 | << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " |
595 | << rtype; |
596 | if (pa && pb) return IntImm(rtype, (pa->value << pb->value), span); |
597 | if (pb) { |
598 | if (pb->value == 0) return a; |
599 | } |
600 | }); |
601 | return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, span); |
602 | } |
603 | |
604 | // bitwise and |
605 | PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); } |
606 | PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { |
607 | type_check_integer_args(a, b, "& operator (bitwise AND)" ); |
608 | BinaryOpMatchTypes(a, b, span); |
609 | TVM_INDEX_CONST_PROPAGATION({ |
610 | const DataType& rtype = a.dtype(); |
611 | if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span); |
612 | }); |
613 | return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, span); |
614 | } |
615 | |
616 | // bitwise_or |
617 | PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); } |
618 | PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { |
619 | type_check_integer_args(a, b, "| operator (bitwise OR)" ); |
620 | BinaryOpMatchTypes(a, b, span); |
621 | TVM_INDEX_CONST_PROPAGATION({ |
622 | const DataType& rtype = a.dtype(); |
623 | if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span); |
624 | }); |
625 | return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, span); |
626 | } |
627 | |
628 | // bitwise_xor |
629 | PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); } |
630 | PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { |
631 | type_check_integer_args(a, b, "^ operator (bitwise XOR)" ); |
632 | BinaryOpMatchTypes(a, b, span); |
633 | TVM_INDEX_CONST_PROPAGATION({ |
634 | const DataType& rtype = a.dtype(); |
635 | if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span); |
636 | }); |
637 | return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, span); |
638 | } |
639 | |
640 | // bitwise_not |
641 | PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); } |
642 | |
643 | PrimExpr bitwise_neg(PrimExpr a, Span span) { |
644 | type_check_integer_args(a, "~ operator (bitwise NOT)" ); |
645 | return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); |
646 | } |
647 | |
648 | TVM_REGISTER_GLOBAL("tir.bitwise_not" ).set_body_typed([](PrimExpr a, Span span) { |
649 | return bitwise_neg(a, span); |
650 | }); |
651 | |
652 | // pow |
653 | PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { |
654 | BinaryOpMatchTypes(x, y, span); |
655 | ICHECK(x.dtype().is_float()) << "power only applies to float" ; |
656 | static auto op = Op::Get("tir.pow" ); |
657 | return tir::Call(x.dtype(), op, {x, y}, span); |
658 | } |
659 | |
660 | TVM_TIR_REGISTER_PURE_BINARY_OP("pow" ).set_attr<TVectorizable>("TVectorizable" , true); |
661 | |
662 | // abs |
663 | PrimExpr abs(PrimExpr x, Span span) { |
664 | if (x.dtype().is_int()) { |
665 | using tir::IntImmNode; |
666 | const IntImmNode* px = x.as<IntImmNode>(); |
667 | if (px) { |
668 | return IntImm(x.dtype(), std::abs(px->value), px->span); |
669 | } |
670 | return tir::Select(x >= make_zero(x.dtype()), x, -x, span); |
671 | } else if (x.dtype().is_float()) { |
672 | using tir::FloatImmNode; |
673 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
674 | if (fx) { |
675 | return FloatImm(x.dtype(), std::fabs(fx->value), fx->span); |
676 | } |
677 | static auto op = Op::Get("tir.fabs" ); |
678 | return tir::Call(x.dtype(), op, {x}, span); |
679 | } else if (x.dtype().is_uint()) { |
680 | return x; |
681 | } else { |
682 | LOG(FATAL) << "Data type " << x.dtype() |
683 | << " not supported for absolute op. Skipping absolute op..." ; |
684 | return x; |
685 | } |
686 | } |
687 | |
688 | TVM_TIR_REGISTER_PURE_UNARY_OP("fabs" ).set_attr<TVectorizable>("TVectorizable" , true); |
689 | |
690 | // isnan |
691 | PrimExpr isnan(PrimExpr x, Span span) { |
692 | DataType t = DataType::Bool(x.dtype().lanes()); |
693 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
694 | return make_const(t, false); |
695 | } else if (x.dtype().is_float()) { |
696 | using tir::FloatImmNode; |
697 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
698 | if (fx) { |
699 | return make_const(t, std::isnan(fx->value), fx->span); |
700 | } |
701 | static auto op = Op::Get("tir.isnan" ); |
702 | if (x.dtype().bits() == 16) { |
703 | return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span); |
704 | } else { |
705 | return tir::Call(t, op, {x}, span); |
706 | } |
707 | } else { |
708 | LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..." ; |
709 | } |
710 | } |
711 | |
712 | // isinf |
713 | PrimExpr isinf(PrimExpr x, Span span) { |
714 | DataType t = DataType::Bool(x.dtype().lanes()); |
715 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
716 | return make_const(t, false, span); |
717 | } else if (x.dtype().is_float()) { |
718 | PrimExpr infX = infinity(x.dtype(), span); |
719 | return abs(x, span) == infX && !isnan(x, span); |
720 | } else { |
721 | LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it..." ; |
722 | } |
723 | } |
724 | |
725 | // isfinite |
726 | PrimExpr isfinite(PrimExpr x, Span span) { return !isinf(x, span) && !isnan(x, span); } |
727 | |
728 | PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
729 | Var x("x" , source.dtype(), span), y("y" , source.dtype(), span); |
730 | PrimExpr result = tir::Add(x, y, span); |
731 | PrimExpr identity_element = make_zero(source.dtype(), span); |
732 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
733 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
734 | } |
735 | |
736 | PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
737 | type_check_boolean_args(source, "tvm::all" ); |
738 | Var x("x" , source.dtype(), span), y("y" , source.dtype()); |
739 | PrimExpr result = tir::And(x, y, span); |
740 | PrimExpr identity_element = make_const(source.dtype(), true, span); |
741 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
742 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
743 | } |
744 | |
745 | PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
746 | type_check_boolean_args(source, "tvm::any" ); |
747 | Var x("x" , source.dtype(), span), y("y" , source.dtype(), span); |
748 | PrimExpr result = tir::Or(x, y, span); |
749 | PrimExpr identity_element = make_const(source.dtype(), false, span); |
750 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
751 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
752 | } |
753 | |
754 | PrimExpr max(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
755 | Var x("x" , source.dtype(), span), y("y" , source.dtype(), span); |
756 | PrimExpr result = tir::Max(x, y, span); |
757 | PrimExpr identity_element = min_value(source.dtype(), span); |
758 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
759 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
760 | } |
761 | |
762 | PrimExpr min(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
763 | Var x("x" , source.dtype(), span), y("y" , source.dtype(), span); |
764 | PrimExpr result = tir::Min(x, y, span); |
765 | PrimExpr identity_element = max_value(source.dtype(), span); |
766 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
767 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
768 | } |
769 | |
770 | PrimExpr prod(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) { |
771 | Var x("x" , source.dtype(), span), y("y" , source.dtype(), span); |
772 | PrimExpr result = tir::Mul(x, y, span); |
773 | PrimExpr identity_element = make_const(source.dtype(), 1, span); |
774 | tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); |
775 | return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); |
776 | } |
777 | |
778 | // fmod |
779 | PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { |
780 | BinaryOpMatchTypes(x, y, span); |
781 | ICHECK(x.dtype().is_float()) << "fmod only applies to float" ; |
782 | static auto op = Op::Get("tir.fmod" ); |
783 | return tir::Call(x.dtype(), op, {x, y}, span); |
784 | } |
785 | |
786 | TVM_TIR_REGISTER_PURE_UNARY_OP("fmod" ); |
787 | |
788 | // floor |
789 | PrimExpr floor(PrimExpr x, Span span) { |
790 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
791 | return x; |
792 | } |
793 | using tir::FloatImmNode; |
794 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
795 | if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span); |
796 | static auto op = Op::Get("tir.floor" ); |
797 | return tir::Call(x.dtype(), op, {x}, span); |
798 | } |
799 | |
800 | TVM_TIR_REGISTER_PURE_UNARY_OP("floor" ).set_attr<TVectorizable>("TVectorizable" , true); |
801 | |
802 | // ceil |
803 | PrimExpr ceil(PrimExpr x, Span span) { |
804 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
805 | return x; |
806 | } |
807 | using tir::FloatImmNode; |
808 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
809 | if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span); |
810 | static auto op = Op::Get("tir.ceil" ); |
811 | return tir::Call(x.dtype(), op, {x}, span); |
812 | } |
813 | |
814 | TVM_TIR_REGISTER_PURE_UNARY_OP("ceil" ).set_attr<TVectorizable>("TVectorizable" , true); |
815 | |
816 | // round |
817 | PrimExpr round(PrimExpr x, Span span) { |
818 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
819 | return x; |
820 | } |
821 | using tir::FloatImmNode; |
822 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
823 | if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); |
824 | static auto op = Op::Get("tir.round" ); |
825 | return tir::Call(x.dtype(), op, {x}, span); |
826 | } |
827 | |
828 | TVM_TIR_REGISTER_PURE_UNARY_OP("round" ).set_attr<TVectorizable>("TVectorizable" , true); |
829 | |
830 | // nearbyint |
831 | PrimExpr nearbyint(PrimExpr x, Span span) { |
832 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
833 | return x; |
834 | } |
835 | using tir::FloatImmNode; |
836 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
837 | if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); |
838 | static auto op = Op::Get("tir.nearbyint" ); |
839 | return tir::Call(x.dtype(), op, {x}, span); |
840 | } |
841 | |
842 | TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint" ); |
843 | |
844 | // trunc |
845 | PrimExpr trunc(PrimExpr x, Span span) { |
846 | if (x.dtype().is_int() || x.dtype().is_uint()) { |
847 | return x; |
848 | } |
849 | using tir::FloatImmNode; |
850 | const FloatImmNode* fx = x.as<FloatImmNode>(); |
851 | if (fx) { |
852 | return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)), |
853 | fx->span); |
854 | } |
855 | static auto op = Op::Get("tir.trunc" ); |
856 | return tir::Call(x.dtype(), op, {x}, span); |
857 | } |
858 | |
859 | TVM_TIR_REGISTER_PURE_UNARY_OP("trunc" ).set_attr<TVectorizable>("TVectorizable" , true); |
860 | |
// unary op registration.
// Each entry registers "tir.<name>" through TVM_TIR_REGISTER_PURE_UNARY_OP, which
// (see the macro near the top of this file) sets num_inputs(1) and
// TCallEffectKind = kPure. Entries that also set TVectorizable = true are tagged
// as vectorizable; the untagged ones (erf, rsqrt, log1p, the inverse trig/hyperbolic
// functions, clz) are presumably left scalar by the vectorizer — confirm in the
// vectorize pass before changing a tag.

// Exponentials.
TVM_TIR_REGISTER_PURE_UNARY_OP("exp").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("exp2").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("exp10").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("erf");

// Activation-style functions.
TVM_TIR_REGISTER_PURE_UNARY_OP("tanh").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("sigmoid").set_attr<TVectorizable>("TVectorizable", true);

// Roots.
TVM_TIR_REGISTER_PURE_UNARY_OP("sqrt").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("rsqrt");

// Logarithms.
TVM_TIR_REGISTER_PURE_UNARY_OP("log").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("log2").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("log1p");

TVM_TIR_REGISTER_PURE_UNARY_OP("log10").set_attr<TVectorizable>("TVectorizable", true);

// Trigonometric and hyperbolic functions.
TVM_TIR_REGISTER_PURE_UNARY_OP("tan").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("cos").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("cosh").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("sin").set_attr<TVectorizable>("TVectorizable", true);

TVM_TIR_REGISTER_PURE_UNARY_OP("sinh").set_attr<TVectorizable>("TVectorizable", true);

// Inverse trigonometric / hyperbolic functions (not tagged vectorizable).
TVM_TIR_REGISTER_PURE_UNARY_OP("asin");

TVM_TIR_REGISTER_PURE_UNARY_OP("acos");

TVM_TIR_REGISTER_PURE_UNARY_OP("atan");

TVM_TIR_REGISTER_PURE_UNARY_OP("acosh");

TVM_TIR_REGISTER_PURE_UNARY_OP("asinh");

TVM_TIR_REGISTER_PURE_UNARY_OP("atanh");

// Bit manipulation: count leading zeros.
TVM_TIR_REGISTER_PURE_UNARY_OP("clz");
909 | |
// binary intrinsics
// Each entry registers "tir.<name>" through TVM_TIR_REGISTER_PURE_BINARY_OP, which
// (see the macro near the top of this file) sets num_inputs(2) and
// TCallEffectKind = kPure. None of these are tagged TVectorizable.
TVM_TIR_REGISTER_PURE_BINARY_OP("atan2");

TVM_TIR_REGISTER_PURE_BINARY_OP("nextafter");

TVM_TIR_REGISTER_PURE_BINARY_OP("hypot");

TVM_TIR_REGISTER_PURE_BINARY_OP("copysign");

TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp");
920 | |
921 | TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace" ) |
922 | .set_num_inputs(5) |
923 | .set_attr<TGlobalSymbol>("TGlobalSymbol" , "TVMBackendAllocWorkspace" ) |
924 | .set_attr<TCallEffectKind>("TCallEffectKind" , Integer(CallEffectKind::kOpaque)); |
925 | |
926 | TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace" ) |
927 | .set_num_inputs(3) |
928 | .set_attr<TGlobalSymbol>("TGlobalSymbol" , "TVMBackendFreeWorkspace" ) |
929 | .set_attr<TCallEffectKind>("TCallEffectKind" , Integer(CallEffectKind::kOpaque)); |
930 | |
931 | // expose basic functions to node namespace |
932 | TVM_REGISTER_GLOBAL("node._const" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
933 | if (args[0].type_code() == kDLInt) { |
934 | *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]); |
935 | } else if (args[0].type_code() == kDLFloat) { |
936 | *ret = tir::make_const(args[1], args[0].operator double(), args[2]); |
937 | } else { |
938 | LOG(FATAL) << "only accept int or float" ; // FIXME |
939 | } |
940 | }); |
941 | |
942 | TVM_REGISTER_GLOBAL("node.LargeUIntImm" ).set_body_typed(LargeUIntImm); |
943 | |
944 | TVM_REGISTER_GLOBAL("tir.min_value" ).set_body_typed(min_value); |
945 | |
946 | TVM_REGISTER_GLOBAL("tir.max_value" ).set_body_typed(max_value); |
947 | |
948 | TVM_REGISTER_GLOBAL("tir.infinity" ).set_body_typed(infinity); |
949 | |
950 | TVM_REGISTER_GLOBAL("tir.abs" ).set_body_typed(tvm::abs); |
951 | |
952 | TVM_REGISTER_GLOBAL("tir.likely" ).set_body_typed(tvm::likely); |
953 | |
954 | TVM_REGISTER_GLOBAL("tir.isnan" ).set_body_typed(tvm::isnan); |
955 | |
956 | TVM_REGISTER_GLOBAL("tir.isfinite" ).set_body_typed(tvm::isfinite); |
957 | |
958 | TVM_REGISTER_GLOBAL("tir.isinf" ).set_body_typed(tvm::isinf); |
959 | |
960 | TVM_REGISTER_GLOBAL("tir.floor" ).set_body_typed(tvm::floor); |
961 | |
962 | TVM_REGISTER_GLOBAL("tir.ceil" ).set_body_typed(tvm::ceil); |
963 | |
964 | TVM_REGISTER_GLOBAL("tir.round" ).set_body_typed(tvm::round); |
965 | |
966 | TVM_REGISTER_GLOBAL("tir.nearbyint" ).set_body_typed(tvm::nearbyint); |
967 | |
968 | TVM_REGISTER_GLOBAL("tir.trunc" ).set_body_typed(tvm::trunc); |
969 | |
970 | TVM_REGISTER_GLOBAL("tir._cast" ).set_body_typed(tvm::cast); |
971 | |
// operator overloading, smarter than make
// Each macro registers a packed function "tir.<Node>" that forwards to the C++
// helper `Func` instead of constructing the IR node directly. The helpers presumably
// apply operand type matching and constant folding (cf. the const_fold.h include
// at the top of this file); confirm per helper.
//
// REGISTER_MAKE_BINARY_OP(Node, Func): typed body (PrimExpr a, PrimExpr b, Span span)
// that returns Func(a, b, span).
#define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
  TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
    return (Func(a, b, span));                                                             \
  })

// REGISTER_MAKE_BIT_OP(Node, Func): untyped body that checks the runtime type code
// of the first two packed arguments:
//   - lhs is a plain integer  -> Func(int, PrimExpr, args[2])
//   - else rhs is a plain int -> Func(PrimExpr, int, args[2])
//   - otherwise               -> Func(PrimExpr, PrimExpr, args[2])
// When both operands are plain integers, the lhs branch wins and the rhs is
// converted to PrimExpr. args[2] is forwarded unchanged as Func's third argument.
#define REGISTER_MAKE_BIT_OP(Node, Func)                                                \
  TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) {       \
    bool lhs_is_int = args[0].type_code() == kDLInt;                                    \
    bool rhs_is_int = args[1].type_code() == kDLInt;                                    \
    if (lhs_is_int) {                                                                   \
      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr(), args[2]));      \
    } else if (rhs_is_int) {                                                            \
      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int(), args[2]));      \
    } else {                                                                            \
      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr(), args[2])); \
    }                                                                                   \
  })
990 | |
// Arithmetic operators. Note that _OpMod and _OpTruncMod both bind truncmod
// (truncating modulo); floor-style modulo is exposed separately as _OpFloorMod.
REGISTER_MAKE_BINARY_OP(_OpAdd, add);
REGISTER_MAKE_BINARY_OP(_OpSub, sub);
REGISTER_MAKE_BINARY_OP(_OpMul, mul);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
// Comparisons.
REGISTER_MAKE_BINARY_OP(_OpEQ, equal);
REGISTER_MAKE_BINARY_OP(_OpNE, not_equal);
REGISTER_MAKE_BINARY_OP(_OpLT, less);           // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, less_equal);     // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, greater);        // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, greater_equal);
// Logical connectives.
REGISTER_MAKE_BINARY_OP(_OpAnd, logical_and);
REGISTER_MAKE_BINARY_OP(_OpOr, logical_or);
// Bitwise operators and shifts. These accept a plain int on either side
// (see REGISTER_MAKE_BIT_OP).
REGISTER_MAKE_BIT_OP(bitwise_and, bitwise_and);
REGISTER_MAKE_BIT_OP(bitwise_or, bitwise_or);
REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
REGISTER_MAKE_BIT_OP(left_shift, left_shift);  // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, right_shift);
1019 | |
1020 | TVM_REGISTER_GLOBAL("tir._OpIfThenElse" ) |
1021 | .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { |
1022 | return if_then_else(cond, true_value, false_value, span); |
1023 | }); |
1024 | |
1025 | TVM_REGISTER_GLOBAL("tir.const_true" ).set_body_typed([](DataType t, Span span) { |
1026 | return const_true(t.lanes(), span); |
1027 | }); |
1028 | |
1029 | } // namespace tvm |
1030 | |