1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file data_type_rewriter.cc
22 * \brief Rewrite the data type of expressions.
23 */
24
25#include <tvm/tir/builtin.h>
26#include <tvm/tir/data_type_rewriter.h>
27#include <tvm/tir/op.h>
28
29#include "./functor_common.h"
30
31namespace tvm {
32namespace tir {
33
34Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) {
35 Stmt s = StmtExprMutator::VisitStmt_(op);
36 op = s.as<ForNode>();
37 ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey();
38 PrimExpr e = VisitExpr(op->loop_var);
39 Var var = Downcast<Var>(e);
40 return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body,
41 op->thread_binding, op->annotations);
42}
43
44Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) {
45 BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
46 Array<PrimExpr> new_iter_values;
47 bool changed = false;
48 for (int i = 0; i < static_cast<int>(op->iter_values.size()); ++i) {
49 auto dtype = realize->block->iter_vars[i]->var->dtype;
50 if (op->iter_values[i]->dtype != dtype) {
51 new_iter_values.push_back(cast(dtype, realize->iter_values[i]));
52 changed = true;
53 } else {
54 new_iter_values.push_back(realize->iter_values[i]);
55 }
56 }
57 if (changed) {
58 realize.CopyOnWrite()->iter_values = std::move(new_iter_values);
59 }
60 return std::move(realize);
61}
62
63Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) {
64 Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
65 Array<IterVar> new_iter_vars = MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) {
66 auto dtype = iter->var.dtype();
67 if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) {
68 IterVar new_iter = iter;
69 new_iter.CopyOnWrite()->dom =
70 Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent));
71 return new_iter;
72 } else {
73 return iter;
74 }
75 });
76 if (!op->iter_vars.same_as(new_iter_vars)) {
77 new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars);
78 }
79 return std::move(new_block);
80}
81
82Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
83 if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
84 Stmt s = StmtExprMutator::VisitStmt_(op);
85 op = s.as<AttrStmtNode>();
86 ICHECK(op != nullptr) << "Expected type to be AttrStmtNode"
87 << ", but get " << s->GetTypeKey();
88 const IterVarNode* iv = op->node.as<IterVarNode>();
89 ICHECK(iv != nullptr) << "Expected type to be IterVarNode"
90 << ", but get " << op->node->GetTypeKey();
91 PrimExpr e = VisitExpr(iv->var);
92 Var var = Downcast<Var>(e);
93 if (ivmap_.find(iv) == ivmap_.end()) {
94 Range dom = iv->dom;
95 if (dom.defined()) {
96 PrimExpr extend = dom->extent;
97 ICHECK(extend.dtype().is_int() && var.dtype().is_int());
98 if (var.dtype().bits() != extend.dtype().bits()) {
99 DataType dtype = var.dtype();
100 dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span);
101 }
102 }
103 ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag);
104 }
105 return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
106 }
107 return StmtExprMutator::VisitStmt_(op);
108}
109
110Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
111 PrimExpr value = this->VisitExpr(op->value);
112 auto new_var = op->var.copy_with_dtype(value.dtype());
113
114 if (value.dtype() != op->var->dtype) {
115 var_remap_[op->var.get()] = new_var;
116 }
117
118 Stmt new_body = this->VisitStmt(op->body);
119
120 if (value.same_as(op->value) && new_body.same_as(op->body)) {
121 return GetRef<Stmt>(op);
122 } else if (value.dtype() == op->var->dtype) {
123 auto n = CopyOnWrite(op);
124 n->value = std::move(value);
125 n->body = std::move(new_body);
126 return Stmt(n);
127 } else {
128 return LetStmt(new_var, value, new_body, op->span);
129 }
130}
131
132PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) {
133 if (auto it = var_remap_.find(op); it != var_remap_.end()) {
134 return it->second;
135 }
136 return GetRef<Var>(op);
137}
138
139PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) {
140 PrimExpr condition = this->VisitExpr(op->condition);
141 PrimExpr true_value = this->VisitExpr(op->true_value);
142 PrimExpr false_value = this->VisitExpr(op->false_value);
143 if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
144 false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) {
145 return GetRef<PrimExpr>(op);
146 } else {
147 int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits());
148 DataType dtype = true_value.dtype().with_bits(bits);
149 if (true_value.dtype() != dtype) true_value = cast(dtype, true_value);
150 if (false_value.dtype() != dtype) false_value = cast(dtype, false_value);
151 return Select(condition, true_value, false_value);
152 }
153}
154
155PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) {
156 PrimExpr base = VisitExpr(op->base);
157 PrimExpr stride = VisitExpr(op->stride);
158 if (base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) {
159 return GetRef<PrimExpr>(op);
160 } else {
161 ICHECK(base.dtype().is_int() && stride.dtype().is_int());
162 int bits = std::max(base.dtype().bits(), stride.dtype().bits());
163 DataType dtype = base.dtype().with_bits(bits);
164 if (base.dtype() != dtype) base = cast(dtype, base);
165 if (stride.dtype() != dtype) stride = cast(dtype, stride);
166 return Ramp(base, stride, op->lanes);
167 }
168}
169
170PrimExpr DataTypeLegalizer::VisitExpr_(const CastNode* op) {
171 return StmtExprMutator::VisitExpr_(op);
172}
173
174#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
175 PrimExpr DataTypeLegalizer::VisitExpr_(const OP* op) { \
176 PrimExpr a = this->VisitExpr(op->a); \
177 PrimExpr b = this->VisitExpr(op->b); \
178 if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \
179 return GetRef<PrimExpr>(op); \
180 } else { \
181 return FUNC(a, b); \
182 } \
183 }
184
185TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
186TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
187TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
188TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
189TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
190TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
191TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
192TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
193TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
194TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
195TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
196TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
197TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*)
198TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*)
199TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
200
201#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
202
203PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
204 PrimExpr e = StmtExprMutator::VisitExpr_(op);
205 op = e.as<CallNode>();
206 static const Op& builtin_pow_ = Op::Get("tir.pow");
207 ICHECK(op != nullptr) << "Expected type to be CallNode"
208 << ", but get " << e->GetTypeKey();
209 if (op->op.same_as(builtin::shift_right())) {
210 return op->args[0] >> op->args[1];
211 } else if (op->op.same_as(builtin::shift_left())) {
212 return op->args[0] << op->args[1];
213 } else if (op->op.same_as(builtin::bitwise_and())) {
214 return op->args[0] & op->args[1];
215 } else if (op->op.same_as(builtin::bitwise_or())) {
216 return op->args[0] | op->args[1];
217 } else if (op->op.same_as(builtin::bitwise_xor())) {
218 return op->args[0] ^ op->args[1];
219 } else if (op->op.same_as(builtin_pow_)) {
220 return pow(op->args[0], op->args[1]);
221 } else if (op->op.same_as(builtin::if_then_else())) {
222 return if_then_else(op->args[0], op->args[1], op->args[2]);
223 }
224 return e;
225}
226
227Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
228 bool is_enabled = is_enabled_;
229 is_enabled_ = true;
230 auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); });
231 auto new_cond = VisitExpr(op->condition);
232 is_enabled_ = is_enabled;
233 auto new_body = this->VisitStmt(op->body);
234 if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) ||
235 !new_body.same_as(op->body)) {
236 Allocate new_allocate = GetRef<Allocate>(op);
237 auto* n = new_allocate.CopyOnWrite();
238 n->extents = std::move(new_extents);
239 n->condition = std::move(new_cond);
240 n->body = std::move(new_body);
241 return std::move(new_allocate);
242 } else {
243 return GetRef<Stmt>(op);
244 }
245}
246
247Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
248 Buffer new_buffer = VisitBuffer(op->buffer);
249 DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
250 if (!new_buffer.same_as(op->buffer)) {
251 decl_buffer.CopyOnWrite()->buffer = new_buffer;
252 }
253 return std::move(decl_buffer);
254}
255
256Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) {
257 bool is_condition = is_condition_;
258 is_condition_ = true;
259 auto new_predicate = VisitExpr(op->predicate);
260 is_condition_ = is_condition;
261
262 bool is_enabled = is_enabled_;
263 is_enabled_ = true;
264 auto new_iter_values =
265 op->iter_values.Map([this](const PrimExpr& e) { return this->VisitExpr(e); });
266 is_enabled_ = is_enabled;
267 Block new_body = Downcast<Block>(this->VisitStmt(op->block));
268 if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) ||
269 !new_body.same_as(op->block)) {
270 BlockRealize new_block_realize = GetRef<BlockRealize>(op);
271 auto* n = new_block_realize.CopyOnWrite();
272 n->predicate = std::move(new_predicate);
273 n->iter_values = std::move(new_iter_values);
274 n->block = std::move(new_body);
275 return std::move(new_block_realize);
276 } else {
277 return GetRef<Stmt>(op);
278 }
279}
280
281Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) {
282 Array<Buffer> new_alloc_buffers =
283 op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); });
284 Array<MatchBufferRegion> new_match_buffers =
285 op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) {
286 Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
287 BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source);
288 if (!new_buffer.same_as(match_buffer_region->buffer) ||
289 !new_buffer_region.same_as(match_buffer_region->source)) {
290 return MatchBufferRegion(new_buffer, new_buffer_region);
291 } else {
292 return match_buffer_region;
293 }
294 });
295 Array<BufferRegion> new_reads = op->reads.Map(
296 [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
297 Array<BufferRegion> new_writes = op->writes.Map(
298 [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
299 Array<IterVar> new_iter_vars =
300 op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); });
301 Optional<Stmt> new_init = NullOpt;
302 if (op->init.defined()) {
303 new_init = this->VisitStmt(op->init.value());
304 }
305 Map<String, ObjectRef> new_annotations = VisitBlockAnnotations(op->annotations);
306 Stmt new_body = this->VisitStmt(op->body);
307
308 if (!new_init.same_as(op->init) || !new_body.same_as(op->body) ||
309 !new_alloc_buffers.same_as(op->alloc_buffers) ||
310 !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) ||
311 !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars) ||
312 !new_annotations.same_as(op->annotations)) {
313 Block new_block = GetRef<Block>(op);
314 BlockNode* n = new_block.CopyOnWrite();
315 n->alloc_buffers = std::move(new_alloc_buffers);
316 n->match_buffers = std::move(new_match_buffers);
317 n->reads = std::move(new_reads);
318 n->writes = std::move(new_writes);
319 n->iter_vars = std::move(new_iter_vars);
320 n->init = std::move(new_init);
321 n->annotations = std::move(new_annotations);
322 n->body = std::move(new_body);
323 return std::move(new_block);
324 }
325 return GetRef<Stmt>(op);
326}
327
328Map<String, ObjectRef> IndexDataTypeRewriter::VisitBlockAnnotations(
329 const Map<String, ObjectRef>& annotations) {
330 auto new_annotations = annotations;
331
332 std::function<ObjectRef(const ObjectRef&)> f_mutate_obj =
333 [this, &f_mutate_obj](const ObjectRef& obj) -> ObjectRef {
334 if (!obj.defined()) {
335 return obj;
336 }
337 if (obj->IsInstance<BufferNode>()) {
338 Buffer buffer = Downcast<Buffer>(obj);
339 if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) {
340 return new_buffer;
341 }
342 } else if (obj->IsInstance<ArrayNode>()) {
343 return Downcast<Array<ObjectRef>>(obj).Map(f_mutate_obj);
344 }
345 return obj;
346 };
347 for (const auto& [key, value] : annotations) {
348 auto new_value = f_mutate_obj(value);
349 if (!new_value.same_as(value)) {
350 new_annotations.Set(key, new_value);
351 }
352 }
353 return new_annotations;
354}
355
356Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
357 if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
358 return (*it).second;
359 }
360 return buffer;
361}
362
363IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
364 bool is_enabled = is_enabled_;
365 is_enabled_ = true;
366 Var new_var = Downcast<Var>(VisitExpr(iter_var->var));
367 PrimExpr min = VisitExpr(iter_var->dom->min);
368 PrimExpr extent = VisitExpr(iter_var->dom->extent);
369 is_enabled_ = is_enabled;
370 if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) ||
371 !extent.same_as(iter_var->dom->extent)) {
372 IterVar new_iter_var = iter_var;
373 IterVarNode* n = new_iter_var.CopyOnWrite();
374 n->var = std::move(new_var);
375 n->dom = Range(min, extent);
376 return new_iter_var;
377 }
378 return iter_var;
379}
380
381Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
382 bool is_enabled = is_enabled_;
383
384 is_enabled_ = true;
385 Array<PrimExpr> new_shape =
386 buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
387 Array<PrimExpr> new_strides =
388 buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
389 auto new_elem_offset = VisitExpr(buffer->elem_offset);
390 is_enabled_ = is_enabled;
391
392 if (!buffer->shape.same_as(new_shape) || !buffer->strides.same_as(new_strides) ||
393 !buffer->elem_offset.same_as(new_elem_offset)) {
394 Buffer new_buffer = buffer;
395 BufferNode* new_buffer_node = new_buffer.CopyOnWrite();
396 new_buffer_node->shape = std::move(new_shape);
397 new_buffer_node->strides = std::move(new_strides);
398 new_buffer_node->elem_offset = std::move(new_elem_offset);
399 buffer_remap_.Set(buffer, new_buffer);
400 return new_buffer;
401 } else {
402 return buffer;
403 }
404}
405
406BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) {
407 Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
408
409 bool is_enabled = is_enabled_;
410 is_enabled_ = true;
411 auto new_region = buffer_region->region.Map([&](const Range& range) {
412 return Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent));
413 });
414 is_enabled_ = is_enabled;
415
416 if (!remapped_buffer.same_as(buffer_region->buffer) ||
417 !new_region.same_as(buffer_region->region)) {
418 return BufferRegion(remapped_buffer, new_region);
419 } else {
420 return buffer_region;
421 }
422}
423
424Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
425 BufferStore store = GetRef<BufferStore>(op);
426
427 Buffer new_buffer = GetRemappedBuffer(op->buffer);
428 auto value = this->VisitExpr(op->value);
429 if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
430 value = cast(new_buffer->dtype, value);
431 }
432 auto indices = VisitIndices(op->indices);
433
434 if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
435 !indices.same_as(op->indices)) {
436 auto writer = store.CopyOnWrite();
437 writer->buffer = new_buffer;
438 writer->value = value;
439 writer->indices = indices;
440 }
441
442 return std::move(store);
443}
444
445PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
446 BufferLoad load = GetRef<BufferLoad>(op);
447
448 Buffer new_buffer = GetRemappedBuffer(op->buffer);
449 auto indices = VisitIndices(op->indices);
450
451 if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
452 auto writer = load.CopyOnWrite();
453 writer->indices = indices;
454 writer->buffer = new_buffer;
455 }
456
457 return std::move(load);
458}
459
460Array<PrimExpr> IndexDataTypeRewriter::VisitIndices(Array<PrimExpr> indices) {
461 bool is_enabled = is_enabled_;
462 is_enabled_ = true;
463
464 auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
465 indices.MutateByApply(fmutate);
466
467 is_enabled_ = is_enabled;
468
469 return indices;
470}
471
472Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) {
473 bool is_condition = is_condition_;
474 is_condition_ = true;
475 PrimExpr cond = VisitExpr(op->condition);
476 is_condition_ = is_condition;
477
478 Stmt then_case = VisitStmt(op->then_case);
479 Optional<Stmt> else_case =
480 op->else_case.defined() ? Optional<Stmt>{VisitStmt(op->else_case.value())} : NullOpt;
481 if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) ||
482 !else_case.same_as(op->else_case)) {
483 IfThenElse new_stmt = GetRef<IfThenElse>(op);
484 auto* n = new_stmt.CopyOnWrite();
485 n->condition = std::move(cond);
486 n->then_case = std::move(then_case);
487 n->else_case = std::move(else_case);
488 return std::move(new_stmt);
489 }
490 return GetRef<Stmt>(op);
491}
492
493Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
494 bool is_enabled = is_enabled_;
495 is_enabled_ = true;
496 Var new_loop_var = Downcast<Var>(VisitExpr(op->loop_var));
497 PrimExpr min = VisitExpr(op->min);
498 PrimExpr extent = VisitExpr(op->extent);
499 is_enabled_ = is_enabled;
500
501 Stmt new_body = VisitStmt(op->body);
502
503 if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) ||
504 !new_body.same_as(op->body)) {
505 For new_for = GetRef<For>(op);
506 auto* n = new_for.CopyOnWrite();
507 n->loop_var = new_loop_var;
508 n->min = cast(new_loop_var.dtype(), min);
509 n->extent = cast(new_loop_var.dtype(), extent);
510 n->body = new_body;
511 return std::move(new_for);
512 } else {
513 return GetRef<Stmt>(op);
514 }
515}
516
517#define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
518 PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \
519 bool is_enabled = is_enabled_; \
520 is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \
521 auto result = Parent::VisitExpr_(op); \
522 is_enabled_ = is_enabled; \
523 return std::move(result); \
524 }
525
526TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
527TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
528TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
529TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*)
530TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*)
531TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
532
533PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {
534 // handle if_then_else condition
535 if (op->op.same_as(builtin::if_then_else())) {
536 bool is_condition = is_condition_;
537 is_condition_ = true;
538 PrimExpr cond = VisitExpr(op->args[0]);
539 is_condition_ = is_condition;
540 return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2]));
541 }
542 return Parent::VisitExpr_(op);
543}
544
545#undef TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH
546
547IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
548 : target_data_type_(std::move(target_data_type)) {}
549
550PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
551 Map<Var, Buffer> new_buffer_map = func->buffer_map;
552 for (const auto& [var, buffer] : func->buffer_map) {
553 new_buffer_map.Set(var, VisitBuffer(buffer));
554 }
555 PrimFuncNode* new_func = func.CopyOnWrite();
556 new_func->buffer_map = std::move(new_buffer_map);
557 new_func->body = VisitStmt(std::move(new_func->body));
558 return func;
559}
560
561PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
562 if (is_enabled_) {
563 ICHECK_LE(op->value, Downcast<Integer>(max_value(target_data_type_))->value);
564 return cast(target_data_type_, GetRef<IntImm>(op));
565 }
566 return GetRef<IntImm>(op);
567}
568
569PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
570 if (is_enabled_ && op->dtype != target_data_type_ && !var_remap_.count(op)) {
571 var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
572 }
573 return DataTypeLegalizer::VisitExpr_(op);
574}
575
576PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
577 // Unwrap the cast only when the dtype of this cast is integer dtype.
578 // When the dtype of this cast is not integer dtype, it means that this cast
579 // has some other purpose, and we should not unwrap the cast.
580 if (is_enabled_ && op->dtype.is_int()) {
581 PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
582 return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
583 }
584 return IndexDataTypeRewriter::VisitExpr_(op);
585}
586
587} // namespace tir
588} // namespace tvm
589