1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/op.h |
22 | * \brief Common operators defined for Expr. |
23 | * |
 * \note Most of the operators defined here perform simple constant folding
 *       when the type is int32 or int64, to simplify index expressions.
26 | */ |
27 | // Acknowledgement: Most operator APIs originate from Halide. |
28 | #ifndef TVM_TIR_OP_H_ |
29 | #define TVM_TIR_OP_H_ |
30 | |
31 | #include <tvm/ir/expr.h> |
32 | #include <tvm/ir/op.h> |
33 | #include <tvm/ir/type.h> |
34 | #include <tvm/tir/expr.h> |
35 | #include <tvm/tir/stmt.h> |
36 | |
37 | #include <algorithm> |
38 | #include <limits> |
39 | #include <type_traits> |
40 | |
41 | namespace tvm { |
42 | |
// Register the operator "tir.<OpName>" and attach OpName itself as its
// "TScriptPrinterName" attribute. OpName must be a string literal, since it is
// concatenated with the "tir." literal at preprocessing time.
#define TVM_TIR_REGISTER_OP(OpName) \
  TVM_REGISTER_OP("tir." OpName)    \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)
45 | |
// Most common operators are overloaded on the argument type (PrimExpr),
// so they are placed directly in the root namespace.
//
// More developer-oriented APIs, such as make_const and is_const, live under
// tir because they are specific to the tir namespace.
51 | |
/*!
 * \brief Get the type of the expression under the unified type system.
 *
 * This function may return a more refined type than
 * the runtime type provided by expr->dtype.
 *
 * \param expr The input expression.
 * \return The result type.
 *
 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
 */
TVM_DLL Type GetType(const PrimExpr& expr);

/*!
 * \brief Get the type corresponding to a runtime DataType.
 * \param dtype The data type.
 * \return The result type.
 *
 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
 */
TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype);

/*!
 * \brief Get the implied DataType for storing values of the given type at runtime.
 *
 * \param type The input type.
 * \return The result runtime::DataType.
 *
 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
 */
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
83 | |
/*!
 * \brief Create an expression that returns value.
 *
 * \param value The returned value.
 * \param span The location of this operation in the source.
 * \return The return expression.
 */
TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());

/*!
 * \brief Query the maximum possible value of dtype.
 * \param dtype The data type.
 * \param span The location of this operation in the source.
 * \return The maximum possible value in this format.
 */
TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span());

/*!
 * \brief Query the minimum possible value of dtype.
 * \param dtype The data type.
 * \param span The location of this operation in the source.
 * \return The minimum possible value in this format.
 */
TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span());

/*!
 * \brief Get the value of infinity for dtype.
 * \param dtype The data type.
 * \param span The location of this operation in the source.
 * \return The infinity value in this format.
 */
TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span());
116 | |
/*!
 * \brief Cast value to type t.
 *
 * \param t The target type.
 * \param value The value.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This function may return value unchanged if its type is already t.
 */
TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());
/*!
 * \brief Reinterpret the bits of value as type t (reinterpret cast).
 *
 * \param t The target type.
 * \param value The value.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This function may return value unchanged if its type is already t.
 */
TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span());
/*!
 * \brief Addition operator (a + b).
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief Subtraction operator (a - b).
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief Negation (-a).
 *
 * \param a The input operand.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
/*!
 * \brief Multiplication operator (a * b).
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief Left shift operator (a << b).
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief Right shift operator (a >> b).
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief greater (a > b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief greater_equal (a >= b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief less (a < b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief less_equal (a <= b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief equal (a == b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief not_equal (a != b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief logical and (a && b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief logical or (a || b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief logical not (!a)
 *
 * \param a The input operand.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
/*!
 * \brief compute division in C semantics.
 *
 * a / b as in C/C++.
 *
 * When operands are integers, it directly corresponds to truncdiv.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute trunc(a / b)
 *
 * This is the default integer division behavior in C.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute the remainder of truncdiv
 *
 * This is the default integer remainder (%) behavior in C.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute floor(a / b) where a and b are non-negative.
 *
 * Use this function for index split calculation.
 *
 * This function might take advantage of the fact
 * that a and b are non-negative.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute ceil(a / b) where a and b are non-negative.
 *
 * Use this function for shape split calculation.
 *
 * This function might take advantage of the fact
 * that a and b are non-negative.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       integer types (int32, int64) when possible.
 */
TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute the remainder of floor(a / b) where a and b are non-negative.
 *
 * Use this function for index split calculation.
 * This function might take advantage of the fact
 * that a and b are non-negative.
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute floor(a / b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute ceil(a / b)
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute the remainder of floordiv
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take maximum of two values
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take minimum of two values
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take bitwise and of two values
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take bitwise or of two values
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take bitwise xor of two values
 *
 * \param a left operand
 * \param b right operand
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief take bitwise negation of a value
 *
 * \param a the input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
/*!
 * \brief Conditional expression.
 *
 * \param cond The condition
 * \param true_value The value when results are true.
 * \param false_value The value when results are false.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
                              Span span = Span());
/*!
 * \brief Mark condition as likely.
 * \param cond The condition
 * \param span The location of this operation in the source.
 * \return The marked expression.
 */
TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
/*!
 * \brief Calculate power(x, y)
 * \param x The left operand.
 * \param y The right operand.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span());
/*!
 * \brief Calculate absolute value of x.
 * \param x The input data
 * \param span The location of this operation in the source.
 *
 * \return The absolute value of input data x
 */
TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span());
/*!
 * \brief Check if x is NaN.
 * \param x The input data
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());

/*!
 * \brief Check if x is finite.
 * \param x The input data
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

/*!
 * \brief Check if x is infinite.
 * \param x The input data
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span());
541 | |
/*!
 * \brief sum of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                     Span span = Span());

/*!
 * \brief logical And of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                     Span span = Span());

/*!
 * \brief logical Or of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                     Span span = Span());

/*!
 * \brief max of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                     Span span = Span());

/*!
 * \brief min of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                     Span span = Span());

/*!
 * \brief product of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \param span The location of this operation in the source.
 * \return The result.
 */
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                      Span span = Span());
606 | |
/*!
 * \brief Calculate floor(x)
 * \param x The input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());

/*!
 * \brief Calculate ceil(x)
 * \param x The input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());

/*!
 * \brief Calculate round(x)
 * \param x The input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());

/*!
 * \brief Calculate std::nearbyint(x)
 * \param x The input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 * \note This is a faster alternative to round.
 */
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());

/*!
 * \brief Calculate trunc(x)
 * \param x The input expression.
 * \param span The location of this operation in the source.
 * \return The result expression.
 */
TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span());

/*!
 * \brief Construct a large uint constant from its low 32 bits and high 32 bits.
 * \param dtype The final data type.
 * \param low The lower 32 bits.
 * \param high The higher 32 bits.
 * \param span The location of this operation in the source.
 * \return The constructed expression.
 */
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span());
657 | |
/*!
 * \brief Execute a multiplication between two Q-numbers x and y
 * followed by a right shift s. The mathematical expression is:
 *
 *    out = round(x*y*2^-s)
 *
 * Please note that the two Q-numbers x and y are supposed to have
 * the same number of fractional bits q.
 *
 * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
 *
 * The rounding rule is to the nearest value, rounding half up
 * (i.e., round(x.1) = x and round(x.5) = x+1).
 * \param x first Q-number
 * \param y second Q-number
 * \param q number of fractional bits in x and y. Needs to be > 0
 * \param s integer right shift
 * \param span The location of this operation in the source.
 * \return The constructed expression.
 */
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                  Span span = Span());
680 | |
681 | // Intrinsic operators |
682 | #define TVM_DECLARE_INTRIN_UNARY(OpName) \ |
683 | inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ |
684 | static const Op& op = Op::Get("tir." #OpName); \ |
685 | if (x.dtype().is_bfloat16()) { \ |
686 | DataType bf16_dtype = x.dtype(); \ |
687 | DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ |
688 | PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ |
689 | PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ |
690 | return tir::Cast(bf16_dtype, {result_fp32}, span); \ |
691 | } else { \ |
692 | return tir::Call(x.dtype(), op, {x}, span); \ |
693 | } \ |
694 | } |
695 | |
696 | TVM_DECLARE_INTRIN_UNARY(exp); |
697 | TVM_DECLARE_INTRIN_UNARY(exp2); |
698 | TVM_DECLARE_INTRIN_UNARY(exp10); |
699 | TVM_DECLARE_INTRIN_UNARY(erf); |
700 | TVM_DECLARE_INTRIN_UNARY(tanh); |
701 | TVM_DECLARE_INTRIN_UNARY(sigmoid); |
702 | TVM_DECLARE_INTRIN_UNARY(sqrt); |
703 | TVM_DECLARE_INTRIN_UNARY(rsqrt); |
704 | TVM_DECLARE_INTRIN_UNARY(log); |
705 | TVM_DECLARE_INTRIN_UNARY(log2); |
706 | TVM_DECLARE_INTRIN_UNARY(log10); |
707 | TVM_DECLARE_INTRIN_UNARY(log1p); |
708 | TVM_DECLARE_INTRIN_UNARY(popcount); |
709 | TVM_DECLARE_INTRIN_UNARY(tan); |
710 | TVM_DECLARE_INTRIN_UNARY(cos); |
711 | TVM_DECLARE_INTRIN_UNARY(cosh); |
712 | TVM_DECLARE_INTRIN_UNARY(sin); |
713 | TVM_DECLARE_INTRIN_UNARY(sinh); |
714 | TVM_DECLARE_INTRIN_UNARY(asin); |
715 | TVM_DECLARE_INTRIN_UNARY(acos); |
716 | TVM_DECLARE_INTRIN_UNARY(atan); |
717 | TVM_DECLARE_INTRIN_UNARY(acosh); |
718 | TVM_DECLARE_INTRIN_UNARY(asinh); |
719 | TVM_DECLARE_INTRIN_UNARY(atanh); |
720 | TVM_DECLARE_INTRIN_UNARY(clz); |
721 | |
722 | #define TVM_DECLARE_INTRIN_BINARY(OpName) \ |
723 | inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ |
724 | static const Op& op = Op::Get("tir." #OpName); \ |
725 | return tir::Call(x.dtype(), op, {x, y}, span); \ |
726 | } |
727 | |
728 | TVM_DECLARE_INTRIN_BINARY(atan2); |
729 | TVM_DECLARE_INTRIN_BINARY(nextafter); |
730 | TVM_DECLARE_INTRIN_BINARY(copysign); |
731 | TVM_DECLARE_INTRIN_BINARY(hypot); |
732 | TVM_DECLARE_INTRIN_BINARY(ldexp); |
733 | |
734 | namespace tir { |
735 | |
736 | /*! |
737 | * \brief Check if type is a pointer to a runtime element type. |
738 | * \param type The type to be checked. |
739 | * \param element_type The corresponding element type. |
740 | * \return The check results |
741 | */ |
742 | inline bool IsPointerType(const Type& type, const DataType& element_type) { |
743 | if (!type.defined()) return false; |
744 | if (const auto* ptr_type = type.as<PointerTypeNode>()) { |
745 | if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) { |
746 | return prim_type->dtype == element_type; |
747 | } |
748 | } |
749 | return false; |
750 | } |
751 | |
/*!
 * \brief Make a const value with certain data type.
 * \tparam ValueType The constant value type; restricted to POD types by the
 *         enable_if default template argument.
 * \param t The target type.
 * \param value The input value
 * \param span The location of this operation in the source.
 * \return the result expression.
 */
template <typename ValueType,
          typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
/*!
 * \brief Make a const zero expr.
 * \param t The target type.
 * \param span The location of this operation in the source.
 * \return the result expression.
 */
inline PrimExpr make_zero(DataType t, Span span = Span());
770 | /*! |
771 | * \brief Make a constant true expression. |
772 | * \param lanes The number of lanes in the bool |
773 | * \param span The location of this operation in the source. |
774 | * \return The result expression. |
775 | */ |
776 | inline PrimExpr const_true(int lanes = 1, Span span = Span()) { |
777 | return make_const(DataType::UInt(1, lanes), 1); |
778 | } |
779 | /*! |
780 | * \brief Make a constant false expression. |
781 | * \param lanes The number of lanes in the bool |
782 | * \param span The location of this operation in the source. |
783 | * \return The result expression. |
784 | */ |
785 | inline PrimExpr const_false(int lanes = 1, Span span = Span()) { |
786 | return make_const(DataType::UInt(1, lanes), 0); |
787 | } |
788 | /*! |
789 | * \brief Get x as constant int expression. |
790 | * \param x The expression |
791 | * \return the address to the int expression, |
792 | * return nullptr, if x is not IntImm. |
793 | */ |
794 | inline const int64_t* as_const_int(const PrimExpr& x) { |
795 | if (!x.defined()) return nullptr; |
796 | if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) { |
797 | return &(op->value); |
798 | } |
799 | |
800 | return nullptr; |
801 | } |
802 | |
803 | /*! |
804 | * \brief Check whether x is a constant integer expression. |
805 | * \param x The input argument |
806 | * \param value the value to be compared against. |
807 | * \return whether x is constant expression. |
808 | */ |
809 | inline bool is_const_int(const PrimExpr& x, int64_t value); |
810 | |
811 | /*! |
812 | * \brief Check whether stmt is nop. |
813 | * \param stmt The input statement |
814 | * \return whether stmt is nop |
815 | */ |
816 | inline bool is_no_op(const tir::Stmt& stmt); |
817 | |
818 | /*! |
819 | * \brief Check whether x is a constant integer 1 |
820 | * \param x The input argument. |
821 | * \note This only return true for integer types. |
822 | * \return whether x is constant 1 |
823 | */ |
824 | inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } |
825 | |
826 | /*! |
827 | * \brief Check whether x is a constant integer 0 |
828 | * \param x The input argument |
829 | * \return whether x is constant 0 |
830 | * \note This only return true for integer types. |
831 | */ |
832 | inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } |
833 | |
834 | /*! |
835 | * \brief Check whether x is an integer constant. |
836 | * \note This only return true for integer types. |
837 | * \return whether x is constant |
838 | */ |
839 | inline bool is_const_int(const PrimExpr& x); |
840 | |
841 | /*! |
842 | * \brief Check whether x is an integer/float constant. |
843 | * \note This only return true for integer types. |
844 | * \return whether x is constant |
845 | */ |
846 | inline bool is_const_number(const PrimExpr& x); |
847 | |
848 | /*! |
849 | * \brief Left fold. |
850 | * \param freduce The reduction function. |
851 | * \param init_value The initial value. |
852 | * \param values The values to be folded. |
853 | * \param span The location of the fold in the source. |
854 | * \return The result. |
855 | * \tparam FReduce The type of the reduction. |
856 | */ |
857 | template <typename FReduce> |
858 | inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values, |
859 | Span span = Span()); |
860 | |
861 | /*! |
862 | * \brief Check whether x is a constant power of two |
863 | * If x is power of two, write the power to the shift. |
864 | * |
865 | * \param x The input expression. |
866 | * \param shift The output shift if x is power of two. |
867 | * \return whether x is constant power of two |
868 | */ |
869 | TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift); |
870 | |
871 | // Implementation details after this |
872 | inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); } |
873 | |
874 | inline bool is_const_number(const PrimExpr& x) { |
875 | if (x.as<tir::IntImmNode>()) { |
876 | return true; |
877 | } else if (x.as<tir::FloatImmNode>()) { |
878 | return true; |
879 | } else if (const auto* op = x.as<tir::BroadcastNode>()) { |
880 | return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>()); |
881 | } |
882 | return false; |
883 | } |
884 | |
885 | inline bool is_positive_const(const PrimExpr& a) { |
886 | const int64_t* as_int = as_const_int(a); |
887 | return as_int && (*as_int > 0); |
888 | } |
889 | |
890 | inline bool is_negative_const(const PrimExpr& a) { |
891 | const int64_t* as_int = as_const_int(a); |
892 | return as_int && (*as_int < 0); |
893 | } |
894 | |
895 | inline bool is_const_int(const PrimExpr& x, int64_t value) { |
896 | const int64_t* as_int = as_const_int(x); |
897 | return as_int && (*as_int == value); |
898 | } |
899 | |
900 | inline bool is_no_op(const tir::Stmt& stmt) { |
901 | if (!stmt.defined()) return true; |
902 | if (const auto* op = stmt.as<tir::EvaluateNode>()) { |
903 | return is_const_int(op->value); |
904 | } |
905 | if (const auto* op = stmt.as<tir::SeqStmtNode>()) { |
906 | return op->seq.size() == 0; |
907 | } |
908 | return false; |
909 | } |
910 | |
911 | template <typename ValueType> |
912 | inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) { |
913 | if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span); |
914 | if (t.is_uint()) { |
915 | // Use IntImm if it is a small integer |
916 | uint64_t uval = static_cast<uint64_t>(value); |
917 | if (value < static_cast<ValueType>(0)) { |
918 | LOG(FATAL) << "cannot make uint from negative value " << value; |
919 | } else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) { |
920 | return IntImm(t, static_cast<int64_t>(value), span); |
921 | } else { |
922 | uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U; |
923 | uint64_t low = uval & mask; |
924 | uint64_t high = uval >> 32U; |
925 | return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span); |
926 | } |
927 | } |
928 | if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span); |
929 | // For now, we store const scalar values of custom datatypes within doubles; later, during the |
930 | // datatypes lowering pass, we will lower the value to its true representation in the format |
931 | // specified by the datatype. |
932 | // TODO(gus) when do we need to start worrying about doubles not being precise enough? |
933 | if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) { |
934 | return FloatImm(t, static_cast<double>(value), span); |
935 | } |
936 | LOG(FATAL) << "cannot make const for type " << t; |
937 | } |
938 | |
939 | template <> |
940 | inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) { |
941 | return MakeConstScalar(t, static_cast<int>(value), span); |
942 | } |
943 | |
944 | template <typename ValueType, typename> |
945 | inline PrimExpr make_const(DataType t, ValueType value, Span span) { |
946 | if (t.lanes() == 1) { |
947 | return MakeConstScalar(t, value, span); |
948 | } else { |
949 | return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); |
950 | } |
951 | } |
952 | |
953 | inline PrimExpr make_zero(DataType t, Span span) { |
954 | if (t.is_handle()) { |
955 | return reinterpret(t, make_const(DataType::UInt(64), 0, span)); |
956 | } |
957 | return make_const(t, 0, span); |
958 | } |
959 | |
960 | template <typename FReduce> |
961 | inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values, |
962 | Span span) { |
963 | for (PrimExpr val : values) { |
964 | init_value = freduce(init_value, val, span); |
965 | } |
966 | return init_value; |
967 | } |
968 | |
969 | } // namespace tir |
970 | |
971 | // additional const expression overloading |
/* a op= b: compute OpFunc(a, b), store it back into a, and return the updated value. */
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) {   \
    return a = OpFunc(a, b);                        \
  }
977 | |
// Overloads that let a PrimExpr be combined with an int, float, or double constant on either
// side. An int operand becomes a constant of the PrimExpr operand's dtype. A float operand goes
// through PrimExpr's float constructor. A double operand becomes a float64 constant.
// The (double, PrimExpr) overload mirrors (PrimExpr, double). Without it, a double on the
// left-hand side would be ambiguous between the (float, PrimExpr) and (int, PrimExpr) overloads.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)                                   \
  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
  inline PrimExpr Name(int a, const PrimExpr& b) {                                  \
    return Name(tir::make_const(b.dtype(), a), b);                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, int b) {                                  \
    return Name(a, tir::make_const(a.dtype(), b));                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, double b) {                               \
    return Name(a, tir::make_const(DataType::Float(64), b));                        \
  }                                                                                 \
  inline PrimExpr Name(double a, const PrimExpr& b) {                               \
    return Name(tir::make_const(DataType::Float(64), a), b);                        \
  }
990 | |
// Span-aware variant: a PrimExpr combined with an int, float, or double constant on either
// side. The span is forwarded to the underlying (PrimExpr, PrimExpr, Span) overload.
// The (double, PrimExpr) overload mirrors (PrimExpr, double). Without it, a double on the
// left-hand side would be ambiguous between the (float, PrimExpr) and (int, PrimExpr) overloads.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name)                 \
  inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) {  \
    return Name(a, PrimExpr(b), span);                                    \
  }                                                                       \
  inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) {  \
    return Name(PrimExpr(a), b, span);                                    \
  }                                                                       \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) {    \
    return Name(tir::make_const(b.dtype(), a), b, span);                  \
  }                                                                       \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) {    \
    return Name(a, tir::make_const(a.dtype(), b), span);                  \
  }                                                                       \
  inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
    return Name(a, tir::make_const(DataType::Float(64), b), span);        \
  }                                                                       \
  inline PrimExpr Name(double a, const PrimExpr& b, Span span = Span()) { \
    return Name(tir::make_const(DataType::Float(64), a), b, span);        \
  }
1007 | |
// Logical operators mixing a PrimExpr with a bool constant. The bool is built explicitly as a
// 1-bit unsigned (boolean) constant, the same type const_true/const_false use, instead of
// relying on PrimExpr's implicit converting constructors to choose a type for it.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)     \
  inline PrimExpr Name(const PrimExpr& a, bool b) {        \
    return Name(a, tir::make_const(DataType::UInt(1), b)); \
  }                                                        \
  inline PrimExpr Name(bool a, const PrimExpr& b) {        \
    return Name(tir::make_const(DataType::UInt(1), a), b); \
  }
1011 | |
// Span-aware logical operators mixing a PrimExpr with a bool constant. The bool is built
// explicitly as a 1-bit unsigned (boolean) constant, the same type const_true/const_false use,
// instead of relying on PrimExpr's implicit converting constructors to choose a type for it.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name)       \
  inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
    return Name(a, tir::make_const(DataType::UInt(1), b), span);     \
  }                                                                  \
  inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
    return Name(tir::make_const(DataType::UInt(1), a), b, span);     \
  }
1019 | |
/* Integer-only overloads: the int operand becomes a constant of the PrimExpr operand's dtype. */
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b) { \
    PrimExpr lhs = tir::make_const(b.dtype(), a);  \
    return Name(lhs, b);                           \
  }                                                \
  inline PrimExpr Name(const PrimExpr& a, int b) { \
    PrimExpr rhs = tir::make_const(a.dtype(), b);  \
    return Name(a, rhs);                           \
  }
1025 | |
/* Span-aware integer-only overloads; the int operand takes the PrimExpr operand's dtype. */
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name)              \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) {  \
    PrimExpr lhs = tir::make_const(b.dtype(), a);                       \
    return Name(lhs, b, span);                                          \
  }                                                                     \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) {  \
    PrimExpr rhs = tir::make_const(a.dtype(), b);                       \
    return Name(a, rhs, span);                                          \
  }
1033 | |
// Compound assignment (+=, -=, *=): rebind the left operand to the combined result.
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
// Arithmetic and comparison operators taking one PrimExpr and one numeric constant.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// Named arithmetic/comparison functions. These overloads also take an optional Span.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(add);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(sub);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(mul);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater_equal);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less_equal);
// integer related ops: only int constants are accepted; each becomes a constant of the
// PrimExpr operand's dtype.
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_and);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_or);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_xor);
// Operator spellings of the integer ops above (no Span parameter).
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops: a PrimExpr combined with a bool constant.
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_and);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_or);
1076 | |
1077 | /*! |
1078 | * \brief Helper function to raise a compiler error about division ambiguity. |
1079 | * \note The call to this function will always results in a compiler error. |
1080 | * \tparam TA Any class type. |
1081 | */ |
1082 | template <typename TA> |
1083 | inline void DivAmbiguityError(const TA& a) { |
1084 | constexpr bool div_ambiguity = !std::is_class<TA>::value; |
1085 | static_assert(div_ambiguity, |
1086 | "TVM supports multiple types of integer divisions, " |
1087 | "please call div, indexdiv/indexmod, " |
1088 | "floordiv/floormod or truncdiv/truncmod directly " |
1089 | "to avoid ambiguity in the code. " |
1090 | "Checkout these functions in tir/op.h." ); |
1091 | } |
1092 | |
1093 | // The following code are not intended to be used in the codebase. |
1094 | // Instead, they generate clear compiler errors that ask developers |
1095 | // to use the specific division function. |
1096 | // The second template argument is necessary to make sure the |
1097 | // code compiles lazily by the compiler during invocation. |
1098 | template <typename TB> |
1099 | inline PrimExpr operator/(const PrimExpr& a, const TB& b) { |
1100 | DivAmbiguityError(a); |
1101 | return a; |
1102 | } |
1103 | |
1104 | template <typename TB> |
1105 | inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { |
1106 | DivAmbiguityError(a); |
1107 | return a; |
1108 | } |
1109 | |
1110 | template <typename TB> |
1111 | inline PrimExpr operator%(const PrimExpr& a, const TB& b) { |
1112 | DivAmbiguityError(a); |
1113 | return a; |
1114 | } |
1115 | } // namespace tvm |
1116 | #endif // TVM_TIR_OP_H_ |
1117 | |