1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/op.h
22 * \brief Common operators defined for Expr.
23 *
 * \note Most of the operators defined here perform simple constant folding
 *       when the type is int32 or int64, to simplify index expressions.
26 */
27// Acknowledgement: Most operator APIs originate from Halide.
28#ifndef TVM_TIR_OP_H_
29#define TVM_TIR_OP_H_
30
31#include <tvm/ir/expr.h>
32#include <tvm/ir/op.h>
33#include <tvm/ir/type.h>
34#include <tvm/tir/expr.h>
35#include <tvm/tir/stmt.h>
36
37#include <algorithm>
38#include <limits>
39#include <type_traits>
40
41namespace tvm {
42
/*!
 * \brief Register a TIR operator under the registry name "tir." Name and attach the
 *        `TScriptPrinterName` attribute whose value is Name.
 * \param Name A string literal with the operator's short name (it is concatenated
 *        with the literal "tir." at compile time).
 */
#define TVM_TIR_REGISTER_OP(Name) \
  TVM_REGISTER_OP("tir." Name)   \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", Name)
45
// Most common operators can be overloaded by argument type (PrimExpr),
// so we put them under the root namespace.
48//
49// We put more developer oriented APIs -- make_const and is_const under tir
50// as they are more specific to the tir namespace.
51
52/*!
53 * \brief Get the type of the expression under the unified type system.
54 *
55 * This function could return a more refined type than
56 * the runtime type provided by expr->dtype
57 *
58 * \param expr The input parameter.
59 * \return The result type.
60 *
61 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
62 */
63TVM_DLL Type GetType(const PrimExpr& expr);
64
65/*!
66 * \brief Get the type corresponding to DataType
67 * \param dtype The data type
68 * \return The result type
69 *
70 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
71 */
72TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype);
73
74/*!
75 * \brief Get the implied DataType for storing values with type during runtime.
76 *
77 * \param type The input type.
78 * \return The result runtime::DataType.
79 *
80 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
81 */
82TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
83
84/*!
85 * \brief Return the value.
86 *
87 * \param value The returned value.
88 * \param span The location of this operation in the source.
89 * \return The return expression.
90 */
91TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
92
93/*!
 * \brief Query the maximum possible value of dtype.
95 * \param dtype The data type.
96 * \param span The location of this operation in the source.
97 * \return the maximum possible value in this format.
98 */
99TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span());
100
101/*!
 * \brief Query the minimum possible value of dtype.
103 * \param dtype The data type.
104 * \param span The location of this operation in the source.
105 * \return the minimum possible value in this format.
106 */
107TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span());
108
109/*!
 * \brief Get the value of infinity.
111 * \param dtype The data type.
112 * \param span The location of this operation in the source.
113 * \return the infinity value in this format.
114 */
115TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span());
116
117/*!
118 * \brief cast value to type.
119 *
120 * \param t the target type.
121 * \param value The value
122 * \param span The location of this operation in the source.
123 * \return The result expression.
124 * \note This function may return value if the type is the same.
125 */
126TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());
127/*!
128 * \brief perform reinterpret cast value to type.
129 *
130 * \param t the target type.
131 * \param value The value
132 * \param span The location of this operation in the source.
133 * \return The result expression.
134 * \note This function may return value if the type is the same.
135 */
136TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span());
137/*!
138 * \brief add operator
139 *
140 * \param a left operand
141 * \param b right operand
142 * \param span The location of this operation in the source.
143 * \return The result expression.
144 * \note this function does eager constant folding for
145 * index types(int32, int64) when possible.
146 */
147TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
148/*!
149 * \brief subtraction operator
150 *
151 * \param a left operand
152 * \param b right operand
153 * \param span The location of this operation in the source.
154 * \return The result expression.
155 * \note this function does eager constant folding for
156 * index types(int32, int64) when possible.
157 */
158TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
159/*!
160 * \brief negation.
161 *
162 * \param a input.
163 * \param span The location of this operation in the source.
164 * \return The result expression.
165 * \note this function does eager constant folding for
166 * index types(int32, int64) when possible.
167 */
168TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
169/*!
170 * \brief multiplication operator
171 *
172 * \param a left operand
173 * \param b right operand
174 * \param span The location of this operation in the source.
175 * \return The result expression.
176 * \note this function does eager constant folding for
177 * index types(int32, int64) when possible.
178 */
179TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
180/*!
181 * \brief left shift operator
182 *
183 * \param a left operand
184 * \param b right operand
185 * \param span The location of this operation in the source.
186 * \return The result expression.
187 * \note this function does eager constant folding for
188 * index types(int32, int64) when possible.
189 */
190TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
191/*!
192 * \brief right shift operator
193 *
194 * \param a left operand
195 * \param b right operand
196 * \param span The location of this operation in the source.
197 * \return The result expression.
198 * \note this function does eager constant folding for
199 * index types(int32, int64) when possible.
200 */
201TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
202/*!
203 * \brief greater
204 *
205 * \param a left operand
206 * \param b right operand
207 * \param span The location of this operation in the source.
208 * \return The result expression.
209 * \note this function does eager constant folding for
210 * index types(int32, int64) when possible.
211 */
212TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
213/*!
214 * \brief greater_equal
215 *
216 * \param a left operand
217 * \param b right operand
218 * \param span The location of this operation in the source.
219 * \return The result expression.
220 * \note this function does eager constant folding for
221 * index types(int32, int64) when possible.
222 */
223TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span());
224/*!
225 * \brief less
226 *
227 * \param a left operand
228 * \param b right operand
229 * \param span The location of this operation in the source.
230 * \return The result expression.
231 * \note this function does eager constant folding for
232 * index types(int32, int64) when possible.
233 */
234TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
235/*!
236 * \brief less_equal
237 *
238 * \param a left operand
239 * \param b right operand
240 * \param span The location of this operation in the source.
241 * \return The result expression.
242 * \note this function does eager constant folding for
243 * index types(int32, int64) when possible.
244 */
245TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
246/*!
247 * \brief equal
248 *
249 * \param a left operand
250 * \param b right operand
251 * \param span The location of this operation in the source.
252 * \return The result expression.
253 * \note this function does eager constant folding for
254 * index types(int32, int64) when possible.
255 */
256TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
257/*!
258 * \brief not_equal
259 *
260 * \param a left operand
261 * \param b right operand
262 * \param span The location of this operation in the source.
263 * \return The result expression.
264 * \note this function does eager constant folding for
265 * index types(int32, int64) when possible.
266 */
267TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
268/*!
269 * \brief and
270 *
271 * \param a left operand
272 * \param b right operand
273 * \param span The location of this operation in the source.
274 * \return The result expression.
275 * \note This operator does eager constant folding.
276 */
277TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
278/*!
279 * \brief or
280 *
281 * \param a left operand
282 * \param b right operand
283 * \param span The location of this operation in the source.
284 * \return The result expression.
285 * \note This operator does eager constant folding.
286 */
287TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
288/*!
289 * \brief not
290 *
 * \param a The input operand.
292 * \param span The location of this operation in the source.
293 * \return The result expression.
294 * \note This operator does eager constant folding.
295 */
296TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
297/*!
298 * \brief compute division in C semantics.
299 *
300 * a / b as in C/C++.
301 *
302 * When operands are integers, it directly corresponds to truncdiv.
303 *
304 * \param a left operand
305 * \param b right operand
306 * \param span The location of this operation in the source.
307 * \return The result expression.
308 * \note this function does eager constant folding for
309 * index types(int32, int64) when possible.
310 */
311TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span());
312/*!
313 * \brief compute trunc(a / b)
314 *
315 * This is the default integer division behavior in C.
316 *
317 * \param a left operand
318 * \param b right operand
319 * \param span The location of this operation in the source.
320 * \return The result expression.
321 * \note this function does eager constant folding for
322 * index types(int32, int64) when possible.
323 */
324TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span());
325/*!
326 * \brief compute the remainder of truncdiv
327 *
 * This is the default integer remainder (%) behavior in C.
329 *
330 * \param a left operand
331 * \param b right operand
332 * \param span The location of this operation in the source.
333 * \return The result expression.
334 * \note this function does eager constant folding for
335 * index types(int32, int64) when possible.
336 */
337TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
338/*!
339 * \brief compute floor(a / b) where a and b are non-negative.
340 *
341 * Use this function for index split calculation.
342 *
343 * This function might take advantage of the fact
344 * that a and b are non-negative.
345 *
346 * \param a left operand
347 * \param b right operand
348 * \param span The location of this operation in the source.
349 * \return The result expression.
350 * \note this function does eager constant folding for
351 * index types(int32, int64) when possible.
352 */
353TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
354/*!
355 * \brief compute ceil(a / b) where a and b are non-negative.
356 *
357 * Use this function for shape split calculation.
358 *
359 * This function might take advantage of the fact
360 * that a and b are non-negative.
361 *
362 * \param a left operand
363 * \param b right operand
364 * \param span The location of this operation in the source.
365 * \return The result expression.
366 * \note this function does eager constant folding for
367 * shape types(int32, int64) when possible.
368 */
369TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
370/*!
 * \brief compute the remainder of floor(a / b) where a and b are non-negative.
372 *
373 * Use this function for index split calculation.
374 * This function might take advantage of the fact
375 * that a and b are non-negative.
376 *
377 * \param a left operand
378 * \param b right operand
379 * \param span The location of this operation in the source.
380 * \return The result expression.
381 * \note this function does eager constant folding for
382 * index types(int32, int64) when possible.
383 */
384TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
385/*!
386 * \brief compute floor(a / b)
387 *
388 * \param a left operand
389 * \param b right operand
390 * \param span The location of this operation in the source.
391 * \return The result expression.
392 * \note this function does eager constant folding for
393 * index types(int32, int64) when possible.
394 */
395TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
396/*!
397 * \brief compute ceil(a / b)
398 *
399 * \param a left operand
400 * \param b right operand
401 * \param span The location of this operation in the source.
402 * \return The result expression.
403 * \note this function does eager constant folding for
404 * index types(int32, int64) when possible.
405 */
406TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
407/*!
408 * \brief compute the remainder of floordiv
409 *
410 * \param a left operand
411 * \param b right operand
412 * \param span The location of this operation in the source.
413 * \return The result expression.
414 * \note this function does eager constant folding for
415 * index types(int32, int64) when possible.
416 */
417TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span());
418/*!
419 * \brief take maximum of two values
420 *
421 * \param a left operand
422 * \param b right operand
423 * \param span The location of this operation in the source.
424 * \return The result expression.
425 * \note this function does eager constant folding for
426 * index types(int32, int64) when possible.
427 */
428TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span());
429/*!
430 * \brief take minimum of two values
431 *
432 * \param a left operand
433 * \param b right operand
434 * \param span The location of this operation in the source.
435 * \return The result expression.
436 * \note this function does eager constant folding for
437 * index types(int32, int64) when possible.
438 */
439TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
440/*!
441 * \brief take bitwise and of two values
442 *
443 * \param a left operand
444 * \param b right operand
445 * \param span The location of this operation in the source.
446 * \return The result expression.
447 * \note this function does eager constant folding for
448 * index types(int32, int64) when possible.
449 */
450TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
451/*!
452 * \brief take bitwise or of two values
453 *
454 * \param a left operand
455 * \param b right operand
456 * \param span The location of this operation in the source.
457 * \return The result expression.
458 * \note this function does eager constant folding for
459 * index types(int32, int64) when possible.
460 */
461TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
462/*!
463 * \brief take bitwise xor of two values
464 *
465 * \param a left operand
466 * \param b right operand
467 * \param span The location of this operation in the source.
468 * \return The result expression.
469 * \note this function does eager constant folding for
470 * index types(int32, int64) when possible.
471 */
472TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
473/*!
 * \brief take bitwise negation of a value
475 *
476 * \param a the input expression.
477 * \param span The location of this operation in the source.
478 * \return The result expression.
479 * \note this function does eager constant folding for
480 * index types(int32, int64) when possible.
481 */
482TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
483/*!
484 * \brief Conditional expression.
485 *
486 * \param cond The condition
487 * \param true_value The value when results are true.
488 * \param false_value The value when results are false.
489 * \param span The location of this operation in the source.
490 * \return The result expression.
491 * \note this function does eager constant folding for
492 * index types(int32, int64) when possible.
493 */
494TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
495 Span span = Span());
496/*!
497 * \brief Mark condition as likely.
498 * \param cond The condition
499 * \param span The location of this operation in the source.
500 * \return The marked expression.
501 */
502TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
503/*!
504 * \brief Calculate power(x, y)
505 * \param x The left operand.
506 * \param y The right operand.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
509TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span());
510/*!
511 * \brief Calculate absolute value of x.
512 * \param x The input data
513 * \param span The location of this operation in the source.
514 *
 * \return The absolute value of input data x
516 */
517TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span());
518/*!
519 * \brief Check if x is NaN.
520 * \param x The input data
521 * \param span The location of this operation in the source.
522 * \return The result expression.
523 */
524TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
525
526/*!
527 * \brief Check if x is finite.
528 * \param x The input data
529 * \param span The location of this operation in the source.
530 * \return The result expression.
531 */
532TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
533
534/*!
535 * \brief Check if x is infinite.
536 * \param x The input data
537 * \param span The location of this operation in the source.
538 * \return The result expression.
539 */
540TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span());
541
542/*!
543 * \brief sum of source expression over axis
544 * \param source The source expression.
545 * \param axis List of iteration variables that will be used for reduction.
546 * \param init The value with which to initialize the output.
547 * \param span The location of this operation in the source.
548 * \return The result.
549 */
550TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
551 Span span = Span());
552
553/*!
554 * \brief logical And of source expression over axis
555 * \param source The source expression.
556 * \param axis List of iteration variables that will be used for reduction.
557 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
560TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
561 Span span = Span());
562
563/*!
564 * \brief logical Or of source expression over axis
565 * \param source The source expression.
566 * \param axis List of iteration variables that will be used for reduction.
567 * \param init The value with which to initialize the output.
568 * \param span The location of this operation in the source.
569 * \return The result.
570 */
571TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
572 Span span = Span());
573
574/*!
575 * \brief max of source expression over axis
576 * \param source The source expression.
577 * \param axis List of iteration variables that will be used for reduction.
578 * \param init The value with which to initialize the output.
579 * \param span The location of this operation in the source.
580 * \return The result.
581 */
582TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
583 Span span = Span());
584
585/*!
 * \brief min of source expression over axis
587 * \param source The source expression.
588 * \param axis List of iteration variables that will be used for reduction.
589 * \param init The value with which to initialize the output.
590 * \param span The location of this operation in the source.
591 * \return The result.
592 */
593TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
594 Span span = Span());
595
596/*!
597 * \brief product of source expression over axis
598 * \param source The source expression.
599 * \param axis List of iteration variables that will be used for reduction.
600 * \param init The value with which to initialize the output.
601 * \param span The location of this operation in the source.
602 * \return The result.
603 */
604TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
605 Span span = Span());
606
607/*!
608 * \brief Calculate floor(x)
609 * \param x The input expression.
610 * \param span The location of this operation in the source.
611 * \return The result expression.
612 */
613TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
614
615/*!
616 * \brief Calculate ceil(x)
617 * \param x The input expression.
618 * \param span The location of this operation in the source.
619 * \return The result expression.
620 */
621TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
622
623/*!
624 * \brief Calculate round(x)
625 * \param x The input expression.
626 * \param span The location of this operation in the source.
627 * \return The result expression.
628 */
629TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
630
631/*!
632 * \brief Calculates std::nearbyint(x)
633 * \param x The input expression.
634 * \param span The location of this operation in the source.
635 * \return The result expression.
 * \note This is a faster alternative to round.
637 */
638TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
639
640/*!
641 * \brief Calculate trunc(x)
642 * \param x The input expression.
643 * \param span The location of this operation in the source.
644 * \return The result expression.
645 */
646TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span());
647
648/*!
649 * \brief Construct a large uint constant by its low 32 bits and high 32bits.
650 * \param dtype The final data type.
651 * \param low The lower 32 bits.
652 * \param high The higher 32 bits.
653 * \param span The location of this operation in the source.
654 * \return The constructed expression.
655 */
656TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span());
657
658/*!
659 * \brief Execute a multiplication between two Q-numbers x and y
660 * followed by a right shift s. The mathematical expression is:
661 *
662 * out = round(x*y*2^-s)
663 *
664 * Please note that the two Q-numbers x and y are supposed to have
665 * the same number of fractional bits q.
666 *
667 * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
668 *
669 * The rounding rule is to the nearest value, rounding half up
670 * (i.e., round(x.1) = x and round (x.5) = x+1)
671 * \param x first Q-number
672 * \param y second Q-number
673 * \param q number of fractional bits in x and y. Needs to be > 0
674 * \param s integer right shift
675 * \param span The location of this operation in the source.
676 * \return The constructed expression.
677 */
678TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
679 Span span = Span());
680
// Intrinsic operators
/*!
 * \brief Define an inline unary function `OpName(x, span)` that emits a call to the
 *        "tir.OpName" op with x's dtype.
 *
 * bfloat16 inputs are not passed through directly: x is cast to float32 (same lane
 * count), the op is applied in float32, and the result is cast back to x's dtype.
 */
#define TVM_DECLARE_INTRIN_UNARY(OpName)                                  \
  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {               \
    static const Op& op = Op::Get("tir." #OpName);                       \
    const DataType in_dtype = x.dtype();                                  \
    if (!in_dtype.is_bfloat16()) {                                        \
      return tir::Call(in_dtype, op, {x}, span);                          \
    }                                                                     \
    const DataType f32_dtype(kDLFloat, 32, in_dtype.lanes());             \
    PrimExpr widened = tir::Cast(f32_dtype, {x}, span);                   \
    PrimExpr computed = tir::Call(f32_dtype, op, {widened}, span);        \
    return tir::Cast(in_dtype, {computed}, span);                         \
  }
695
// Each invocation below defines an inline function `name(PrimExpr x, Span span = Span())`
// (see TVM_DECLARE_INTRIN_UNARY) that emits a call to the "tir.<name>" op.
// Exponentials, error function and activation-style functions.
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
// Roots and logarithms.
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(log1p);
// Bit counting (population count).
TVM_DECLARE_INTRIN_UNARY(popcount);
// Trigonometric and hyperbolic functions and their inverses.
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(asin);
TVM_DECLARE_INTRIN_UNARY(acos);
TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
// Bit counting (leading zeros).
TVM_DECLARE_INTRIN_UNARY(clz);
721
/*!
 * \brief Define an inline binary function `OpName(x, y, span)` that emits a call to
 *        the "tir.OpName" op.
 *
 * The result dtype is taken from the first operand x. Unlike TVM_DECLARE_INTRIN_UNARY,
 * no bfloat16 widening is performed.
 */
#define TVM_DECLARE_INTRIN_BINARY(OpName)                                 \
  inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) {   \
    static const Op& op = Op::Get("tir." #OpName);                       \
    const DataType result_dtype = x.dtype();                              \
    return tir::Call(result_dtype, op, {x, y}, span);                     \
  }
727
// Each invocation below defines an inline function `name(PrimExpr x, PrimExpr y,
// Span span = Span())` (see TVM_DECLARE_INTRIN_BINARY) that emits a call to the
// "tir.<name>" op, with the result dtype taken from x.
TVM_DECLARE_INTRIN_BINARY(atan2);
TVM_DECLARE_INTRIN_BINARY(nextafter);
TVM_DECLARE_INTRIN_BINARY(copysign);
TVM_DECLARE_INTRIN_BINARY(hypot);
TVM_DECLARE_INTRIN_BINARY(ldexp);
733
734namespace tir {
735
736/*!
737 * \brief Check if type is a pointer to a runtime element type.
738 * \param type The type to be checked.
739 * \param element_type The corresponding element type.
740 * \return The check results
741 */
742inline bool IsPointerType(const Type& type, const DataType& element_type) {
743 if (!type.defined()) return false;
744 if (const auto* ptr_type = type.as<PointerTypeNode>()) {
745 if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
746 return prim_type->dtype == element_type;
747 }
748 }
749 return false;
750}
751
752/*!
753 * \brief Make a const value with certain data type.
754 * \param t The target type.
755 * \param value The input value
756 * \return the result expression.
757 * \tparam ValueType The constant value type
758 * \param span The location of this operation in the source.
759 */
760template <typename ValueType,
761 typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
762inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
763/*!
764 * \brief Make a const zero expr.
765 * \param t The target type.
766 * \param span The location of this operation in the source.
767 * \return the result expression.
768 */
769inline PrimExpr make_zero(DataType t, Span span = Span());
770/*!
771 * \brief Make a constant true expression.
772 * \param lanes The number of lanes in the bool
773 * \param span The location of this operation in the source.
774 * \return The result expression.
775 */
776inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
777 return make_const(DataType::UInt(1, lanes), 1);
778}
779/*!
780 * \brief Make a constant false expression.
781 * \param lanes The number of lanes in the bool
782 * \param span The location of this operation in the source.
783 * \return The result expression.
784 */
785inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
786 return make_const(DataType::UInt(1, lanes), 0);
787}
788/*!
789 * \brief Get x as constant int expression.
790 * \param x The expression
791 * \return the address to the int expression,
792 * return nullptr, if x is not IntImm.
793 */
794inline const int64_t* as_const_int(const PrimExpr& x) {
795 if (!x.defined()) return nullptr;
796 if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
797 return &(op->value);
798 }
799
800 return nullptr;
801}
802
803/*!
804 * \brief Check whether x is a constant integer expression.
805 * \param x The input argument
806 * \param value the value to be compared against.
807 * \return whether x is constant expression.
808 */
809inline bool is_const_int(const PrimExpr& x, int64_t value);
810
811/*!
812 * \brief Check whether stmt is nop.
813 * \param stmt The input statement
814 * \return whether stmt is nop
815 */
816inline bool is_no_op(const tir::Stmt& stmt);
817
818/*!
819 * \brief Check whether x is a constant integer 1
820 * \param x The input argument.
821 * \note This only return true for integer types.
822 * \return whether x is constant 1
823 */
824inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
825
826/*!
827 * \brief Check whether x is a constant integer 0
828 * \param x The input argument
829 * \return whether x is constant 0
830 * \note This only return true for integer types.
831 */
832inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
833
834/*!
 * \brief Check whether x is an integer constant.
 * \param x The input argument.
 * \note This only returns true for integer types.
 * \return whether x is constant
838 */
839inline bool is_const_int(const PrimExpr& x);
840
841/*!
 * \brief Check whether x is an integer/float constant.
 * \param x The input argument.
 * \note Returns true for an IntImm, a FloatImm, or a Broadcast of an IntImm/FloatImm.
 * \return whether x is constant
845 */
846inline bool is_const_number(const PrimExpr& x);
847
848/*!
849 * \brief Left fold.
850 * \param freduce The reduction function.
851 * \param init_value The initial value.
852 * \param values The values to be folded.
853 * \param span The location of the fold in the source.
854 * \return The result.
855 * \tparam FReduce The type of the reduction.
856 */
857template <typename FReduce>
858inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values,
859 Span span = Span());
860
861/*!
862 * \brief Check whether x is a constant power of two
863 * If x is power of two, write the power to the shift.
864 *
865 * \param x The input expression.
866 * \param shift The output shift if x is power of two.
867 * \return whether x is constant power of two
868 */
869TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
870
871// Implementation details after this
872inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }
873
874inline bool is_const_number(const PrimExpr& x) {
875 if (x.as<tir::IntImmNode>()) {
876 return true;
877 } else if (x.as<tir::FloatImmNode>()) {
878 return true;
879 } else if (const auto* op = x.as<tir::BroadcastNode>()) {
880 return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
881 }
882 return false;
883}
884
885inline bool is_positive_const(const PrimExpr& a) {
886 const int64_t* as_int = as_const_int(a);
887 return as_int && (*as_int > 0);
888}
889
890inline bool is_negative_const(const PrimExpr& a) {
891 const int64_t* as_int = as_const_int(a);
892 return as_int && (*as_int < 0);
893}
894
895inline bool is_const_int(const PrimExpr& x, int64_t value) {
896 const int64_t* as_int = as_const_int(x);
897 return as_int && (*as_int == value);
898}
899
900inline bool is_no_op(const tir::Stmt& stmt) {
901 if (!stmt.defined()) return true;
902 if (const auto* op = stmt.as<tir::EvaluateNode>()) {
903 return is_const_int(op->value);
904 }
905 if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
906 return op->seq.size() == 0;
907 }
908 return false;
909}
910
911template <typename ValueType>
912inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
913 if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
914 if (t.is_uint()) {
915 // Use IntImm if it is a small integer
916 uint64_t uval = static_cast<uint64_t>(value);
917 if (value < static_cast<ValueType>(0)) {
918 LOG(FATAL) << "cannot make uint from negative value " << value;
919 } else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
920 return IntImm(t, static_cast<int64_t>(value), span);
921 } else {
922 uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
923 uint64_t low = uval & mask;
924 uint64_t high = uval >> 32U;
925 return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
926 }
927 }
928 if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span);
929 // For now, we store const scalar values of custom datatypes within doubles; later, during the
930 // datatypes lowering pass, we will lower the value to its true representation in the format
931 // specified by the datatype.
932 // TODO(gus) when do we need to start worrying about doubles not being precise enough?
933 if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
934 return FloatImm(t, static_cast<double>(value), span);
935 }
936 LOG(FATAL) << "cannot make const for type " << t;
937}
938
939template <>
940inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
941 return MakeConstScalar(t, static_cast<int>(value), span);
942}
943
944template <typename ValueType, typename>
945inline PrimExpr make_const(DataType t, ValueType value, Span span) {
946 if (t.lanes() == 1) {
947 return MakeConstScalar(t, value, span);
948 } else {
949 return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
950 }
951}
952
953inline PrimExpr make_zero(DataType t, Span span) {
954 if (t.is_handle()) {
955 return reinterpret(t, make_const(DataType::UInt(64), 0, span));
956 }
957 return make_const(t, 0, span);
958}
959
960template <typename FReduce>
961inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values,
962 Span span) {
963 for (PrimExpr val : values) {
964 init_value = freduce(init_value, val, span);
965 }
966 return init_value;
967}
968
969} // namespace tir
970
// additional const expression overloading
//
// TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) defines the compound-assignment
// operator `Name(PrimExpr& a, PrimExpr b)`. It stores OpFunc(a, b) into `a` and
// returns the new value of `a` by value (not by reference).
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
 inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
 a = OpFunc(a, b); \
 return a; \
 }
977
// TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) adds mixed PrimExpr/scalar overloads
// of the binary function or operator `Name`:
//  - a float operand (either side) is wrapped with PrimExpr(...);
//  - an int operand (either side) is turned into a constant via tir::make_const,
//    using the dtype of the PrimExpr operand;
//  - a double right-hand operand becomes a DataType::Float(64) constant.
// NOTE(review): there is no (double, PrimExpr) overload; confirm the asymmetry is intended.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
 inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
 inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
 inline PrimExpr Name(int a, const PrimExpr& b) { \
 return Name(tir::make_const(b.dtype(), a), b); \
 } \
 inline PrimExpr Name(const PrimExpr& a, int b) { \
 return Name(a, tir::make_const(a.dtype(), b)); \
 } \
 inline PrimExpr Name(const PrimExpr& a, double b) { \
 return Name(a, tir::make_const(DataType::Float(64), b)); \
 }
990
// TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) is the span-aware variant of
// TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD. Each overload takes a trailing
// `Span span = Span()` and forwards it to `Name(PrimExpr, PrimExpr, Span)`:
//  - a float operand (either side) is wrapped with PrimExpr(...);
//  - an int operand (either side) is turned into a constant via tir::make_const,
//    using the dtype of the PrimExpr operand;
//  - a double right-hand operand becomes a DataType::Float(64) constant.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
 inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
 return Name(a, PrimExpr(b), span); \
 } \
 inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
 return Name(PrimExpr(a), b, span); \
 } \
 inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
 return Name(tir::make_const(b.dtype(), a), b, span); \
 } \
 inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
 return Name(a, tir::make_const(a.dtype(), b), span); \
 } \
 inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
 return Name(a, tir::make_const(DataType::Float(64), b), span); \
 }
1007
// TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) adds (PrimExpr, bool) and
// (bool, PrimExpr) overloads of `Name`; the bool operand is wrapped with PrimExpr(...).
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
 inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
 inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }
1011
// TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) is the span-aware variant:
// it adds (PrimExpr, bool, Span) and (bool, PrimExpr, Span) overloads of `Name`.
// The bool operand is wrapped with PrimExpr(...), and the span (default Span()) is forwarded.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
 inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
 return Name(a, PrimExpr(b), span); \
 } \
 inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
 return Name(PrimExpr(a), b, span); \
 }
1019
// TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) adds (PrimExpr, int) and (int, PrimExpr)
// overloads of `Name`. The int operand is turned into a constant via tir::make_const,
// using the dtype of the PrimExpr operand.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
 inline PrimExpr Name(const PrimExpr& a, int b) { \
 return Name(a, tir::make_const(a.dtype(), b)); \
 } \
 inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); }
1025
// TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) is the span-aware variant:
// it adds (PrimExpr, int, Span) and (int, PrimExpr, Span) overloads of `Name`.
// The int operand is turned into a constant via tir::make_const, using the dtype of
// the PrimExpr operand, and the span (default Span()) is forwarded.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
 inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
 return Name(a, tir::make_const(a.dtype(), b), span); \
 } \
 inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
 return Name(tir::make_const(b.dtype(), a), b, span); \
 }
1033
// Compound assignment operators on PrimExpr.
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
// Arithmetic and comparison operators with int/float/double operands.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// Named, span-aware arithmetic and comparison functions with int/float/double operands.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(add);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(sub);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(mul);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater_equal);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less_equal);
// integer related ops: named, span-aware functions taking an int operand
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_and);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_or);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_xor);
// integer related ops: shift and bitwise operators taking an int operand
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops: operators and named, span-aware functions taking a bool operand
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_and);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_or);
1076
1077/*!
1078 * \brief Helper function to raise a compiler error about division ambiguity.
1079 * \note The call to this function will always results in a compiler error.
1080 * \tparam TA Any class type.
1081 */
1082template <typename TA>
1083inline void DivAmbiguityError(const TA& a) {
1084 constexpr bool div_ambiguity = !std::is_class<TA>::value;
1085 static_assert(div_ambiguity,
1086 "TVM supports multiple types of integer divisions, "
1087 "please call div, indexdiv/indexmod, "
1088 "floordiv/floormod or truncdiv/truncmod directly "
1089 "to avoid ambiguity in the code. "
1090 "Checkout these functions in tir/op.h.");
1091}
1092
// The following operators are not intended to be used in the codebase.
// Instead, they generate clear compiler errors that ask developers
// to use a specific division function.
// Templating on the type of the second operand (TB) is necessary so that
// each body is only instantiated, and the error only raised, when the
// operator is actually invoked.
// Trap for `PrimExpr / x`: any use fails to compile, because DivAmbiguityError is
// instantiated with the class type PrimExpr. `return a;` only satisfies the return type.
template <typename TB>
inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
 DivAmbiguityError(a);
 return a;
}
1103
// Trap for `PrimExpr /= x`: any use fails to compile, because DivAmbiguityError is
// instantiated with the class type PrimExpr. `return a;` only satisfies the return type.
template <typename TB>
inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
 DivAmbiguityError(a);
 return a;
}
1109
// Trap for `PrimExpr % x`: any use fails to compile, because DivAmbiguityError is
// instantiated with the class type PrimExpr. `return a;` only satisfies the return type.
template <typename TB>
inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
 DivAmbiguityError(a);
 return a;
}
1115} // namespace tvm
1116#endif // TVM_TIR_OP_H_
1117