1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file data_type_rewriter.cc |
22 | * \brief Rewrite the data type of expressions. |
23 | */ |
24 | |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/data_type_rewriter.h> |
27 | #include <tvm/tir/op.h> |
28 | |
29 | #include "./functor_common.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { |
35 | Stmt s = StmtExprMutator::VisitStmt_(op); |
36 | op = s.as<ForNode>(); |
37 | ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); |
38 | PrimExpr e = VisitExpr(op->loop_var); |
39 | Var var = Downcast<Var>(e); |
40 | return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, |
41 | op->thread_binding, op->annotations); |
42 | } |
43 | |
44 | Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { |
45 | BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op)); |
46 | Array<PrimExpr> new_iter_values; |
47 | bool changed = false; |
48 | for (int i = 0; i < static_cast<int>(op->iter_values.size()); ++i) { |
49 | auto dtype = realize->block->iter_vars[i]->var->dtype; |
50 | if (op->iter_values[i]->dtype != dtype) { |
51 | new_iter_values.push_back(cast(dtype, realize->iter_values[i])); |
52 | changed = true; |
53 | } else { |
54 | new_iter_values.push_back(realize->iter_values[i]); |
55 | } |
56 | } |
57 | if (changed) { |
58 | realize.CopyOnWrite()->iter_values = std::move(new_iter_values); |
59 | } |
60 | return std::move(realize); |
61 | } |
62 | |
63 | Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { |
64 | Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
65 | Array<IterVar> new_iter_vars = MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { |
66 | auto dtype = iter->var.dtype(); |
67 | if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { |
68 | IterVar new_iter = iter; |
69 | new_iter.CopyOnWrite()->dom = |
70 | Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); |
71 | return new_iter; |
72 | } else { |
73 | return iter; |
74 | } |
75 | }); |
76 | if (!op->iter_vars.same_as(new_iter_vars)) { |
77 | new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); |
78 | } |
79 | return std::move(new_block); |
80 | } |
81 | |
82 | Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) { |
83 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { |
84 | Stmt s = StmtExprMutator::VisitStmt_(op); |
85 | op = s.as<AttrStmtNode>(); |
86 | ICHECK(op != nullptr) << "Expected type to be AttrStmtNode" |
87 | << ", but get " << s->GetTypeKey(); |
88 | const IterVarNode* iv = op->node.as<IterVarNode>(); |
89 | ICHECK(iv != nullptr) << "Expected type to be IterVarNode" |
90 | << ", but get " << op->node->GetTypeKey(); |
91 | PrimExpr e = VisitExpr(iv->var); |
92 | Var var = Downcast<Var>(e); |
93 | if (ivmap_.find(iv) == ivmap_.end()) { |
94 | Range dom = iv->dom; |
95 | if (dom.defined()) { |
96 | PrimExpr extend = dom->extent; |
97 | ICHECK(extend.dtype().is_int() && var.dtype().is_int()); |
98 | if (var.dtype().bits() != extend.dtype().bits()) { |
99 | DataType dtype = var.dtype(); |
100 | dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span); |
101 | } |
102 | } |
103 | ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag); |
104 | } |
105 | return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); |
106 | } |
107 | return StmtExprMutator::VisitStmt_(op); |
108 | } |
109 | |
110 | Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { |
111 | PrimExpr value = this->VisitExpr(op->value); |
112 | auto new_var = op->var.copy_with_dtype(value.dtype()); |
113 | |
114 | if (value.dtype() != op->var->dtype) { |
115 | var_remap_[op->var.get()] = new_var; |
116 | } |
117 | |
118 | Stmt new_body = this->VisitStmt(op->body); |
119 | |
120 | if (value.same_as(op->value) && new_body.same_as(op->body)) { |
121 | return GetRef<Stmt>(op); |
122 | } else if (value.dtype() == op->var->dtype) { |
123 | auto n = CopyOnWrite(op); |
124 | n->value = std::move(value); |
125 | n->body = std::move(new_body); |
126 | return Stmt(n); |
127 | } else { |
128 | return LetStmt(new_var, value, new_body, op->span); |
129 | } |
130 | } |
131 | |
132 | PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) { |
133 | if (auto it = var_remap_.find(op); it != var_remap_.end()) { |
134 | return it->second; |
135 | } |
136 | return GetRef<Var>(op); |
137 | } |
138 | |
139 | PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { |
140 | PrimExpr condition = this->VisitExpr(op->condition); |
141 | PrimExpr true_value = this->VisitExpr(op->true_value); |
142 | PrimExpr false_value = this->VisitExpr(op->false_value); |
143 | if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && |
144 | false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { |
145 | return GetRef<PrimExpr>(op); |
146 | } else { |
147 | int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); |
148 | DataType dtype = true_value.dtype().with_bits(bits); |
149 | if (true_value.dtype() != dtype) true_value = cast(dtype, true_value); |
150 | if (false_value.dtype() != dtype) false_value = cast(dtype, false_value); |
151 | return Select(condition, true_value, false_value); |
152 | } |
153 | } |
154 | |
155 | PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { |
156 | PrimExpr base = VisitExpr(op->base); |
157 | PrimExpr stride = VisitExpr(op->stride); |
158 | if (base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) { |
159 | return GetRef<PrimExpr>(op); |
160 | } else { |
161 | ICHECK(base.dtype().is_int() && stride.dtype().is_int()); |
162 | int bits = std::max(base.dtype().bits(), stride.dtype().bits()); |
163 | DataType dtype = base.dtype().with_bits(bits); |
164 | if (base.dtype() != dtype) base = cast(dtype, base); |
165 | if (stride.dtype() != dtype) stride = cast(dtype, stride); |
166 | return Ramp(base, stride, op->lanes); |
167 | } |
168 | } |
169 | |
170 | PrimExpr DataTypeLegalizer::VisitExpr_(const CastNode* op) { |
171 | return StmtExprMutator::VisitExpr_(op); |
172 | } |
173 | |
174 | #define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ |
175 | PrimExpr DataTypeLegalizer::VisitExpr_(const OP* op) { \ |
176 | PrimExpr a = this->VisitExpr(op->a); \ |
177 | PrimExpr b = this->VisitExpr(op->b); \ |
178 | if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ |
179 | return GetRef<PrimExpr>(op); \ |
180 | } else { \ |
181 | return FUNC(a, b); \ |
182 | } \ |
183 | } |
184 | |
185 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); |
186 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); |
187 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); |
188 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); |
189 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); |
190 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); |
191 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); |
192 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); |
193 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); |
194 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); |
195 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); |
196 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); |
197 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) |
198 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) |
199 | TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); |
200 | |
201 | #undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH |
202 | |
203 | PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { |
204 | PrimExpr e = StmtExprMutator::VisitExpr_(op); |
205 | op = e.as<CallNode>(); |
206 | static const Op& builtin_pow_ = Op::Get("tir.pow" ); |
207 | ICHECK(op != nullptr) << "Expected type to be CallNode" |
208 | << ", but get " << e->GetTypeKey(); |
209 | if (op->op.same_as(builtin::shift_right())) { |
210 | return op->args[0] >> op->args[1]; |
211 | } else if (op->op.same_as(builtin::shift_left())) { |
212 | return op->args[0] << op->args[1]; |
213 | } else if (op->op.same_as(builtin::bitwise_and())) { |
214 | return op->args[0] & op->args[1]; |
215 | } else if (op->op.same_as(builtin::bitwise_or())) { |
216 | return op->args[0] | op->args[1]; |
217 | } else if (op->op.same_as(builtin::bitwise_xor())) { |
218 | return op->args[0] ^ op->args[1]; |
219 | } else if (op->op.same_as(builtin_pow_)) { |
220 | return pow(op->args[0], op->args[1]); |
221 | } else if (op->op.same_as(builtin::if_then_else())) { |
222 | return if_then_else(op->args[0], op->args[1], op->args[2]); |
223 | } |
224 | return e; |
225 | } |
226 | |
227 | Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { |
228 | bool is_enabled = is_enabled_; |
229 | is_enabled_ = true; |
230 | auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); |
231 | auto new_cond = VisitExpr(op->condition); |
232 | is_enabled_ = is_enabled; |
233 | auto new_body = this->VisitStmt(op->body); |
234 | if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) || |
235 | !new_body.same_as(op->body)) { |
236 | Allocate new_allocate = GetRef<Allocate>(op); |
237 | auto* n = new_allocate.CopyOnWrite(); |
238 | n->extents = std::move(new_extents); |
239 | n->condition = std::move(new_cond); |
240 | n->body = std::move(new_body); |
241 | return std::move(new_allocate); |
242 | } else { |
243 | return GetRef<Stmt>(op); |
244 | } |
245 | } |
246 | |
247 | Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) { |
248 | Buffer new_buffer = VisitBuffer(op->buffer); |
249 | DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op)); |
250 | if (!new_buffer.same_as(op->buffer)) { |
251 | decl_buffer.CopyOnWrite()->buffer = new_buffer; |
252 | } |
253 | return std::move(decl_buffer); |
254 | } |
255 | |
256 | Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { |
257 | bool is_condition = is_condition_; |
258 | is_condition_ = true; |
259 | auto new_predicate = VisitExpr(op->predicate); |
260 | is_condition_ = is_condition; |
261 | |
262 | bool is_enabled = is_enabled_; |
263 | is_enabled_ = true; |
264 | auto new_iter_values = |
265 | op->iter_values.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); |
266 | is_enabled_ = is_enabled; |
267 | Block new_body = Downcast<Block>(this->VisitStmt(op->block)); |
268 | if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) || |
269 | !new_body.same_as(op->block)) { |
270 | BlockRealize new_block_realize = GetRef<BlockRealize>(op); |
271 | auto* n = new_block_realize.CopyOnWrite(); |
272 | n->predicate = std::move(new_predicate); |
273 | n->iter_values = std::move(new_iter_values); |
274 | n->block = std::move(new_body); |
275 | return std::move(new_block_realize); |
276 | } else { |
277 | return GetRef<Stmt>(op); |
278 | } |
279 | } |
280 | |
281 | Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { |
282 | Array<Buffer> new_alloc_buffers = |
283 | op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); }); |
284 | Array<MatchBufferRegion> new_match_buffers = |
285 | op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) { |
286 | Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer); |
287 | BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source); |
288 | if (!new_buffer.same_as(match_buffer_region->buffer) || |
289 | !new_buffer_region.same_as(match_buffer_region->source)) { |
290 | return MatchBufferRegion(new_buffer, new_buffer_region); |
291 | } else { |
292 | return match_buffer_region; |
293 | } |
294 | }); |
295 | Array<BufferRegion> new_reads = op->reads.Map( |
296 | [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); |
297 | Array<BufferRegion> new_writes = op->writes.Map( |
298 | [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); |
299 | Array<IterVar> new_iter_vars = |
300 | op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); }); |
301 | Optional<Stmt> new_init = NullOpt; |
302 | if (op->init.defined()) { |
303 | new_init = this->VisitStmt(op->init.value()); |
304 | } |
305 | Map<String, ObjectRef> new_annotations = VisitBlockAnnotations(op->annotations); |
306 | Stmt new_body = this->VisitStmt(op->body); |
307 | |
308 | if (!new_init.same_as(op->init) || !new_body.same_as(op->body) || |
309 | !new_alloc_buffers.same_as(op->alloc_buffers) || |
310 | !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || |
311 | !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars) || |
312 | !new_annotations.same_as(op->annotations)) { |
313 | Block new_block = GetRef<Block>(op); |
314 | BlockNode* n = new_block.CopyOnWrite(); |
315 | n->alloc_buffers = std::move(new_alloc_buffers); |
316 | n->match_buffers = std::move(new_match_buffers); |
317 | n->reads = std::move(new_reads); |
318 | n->writes = std::move(new_writes); |
319 | n->iter_vars = std::move(new_iter_vars); |
320 | n->init = std::move(new_init); |
321 | n->annotations = std::move(new_annotations); |
322 | n->body = std::move(new_body); |
323 | return std::move(new_block); |
324 | } |
325 | return GetRef<Stmt>(op); |
326 | } |
327 | |
328 | Map<String, ObjectRef> IndexDataTypeRewriter::VisitBlockAnnotations( |
329 | const Map<String, ObjectRef>& annotations) { |
330 | auto new_annotations = annotations; |
331 | |
332 | std::function<ObjectRef(const ObjectRef&)> f_mutate_obj = |
333 | [this, &f_mutate_obj](const ObjectRef& obj) -> ObjectRef { |
334 | if (!obj.defined()) { |
335 | return obj; |
336 | } |
337 | if (obj->IsInstance<BufferNode>()) { |
338 | Buffer buffer = Downcast<Buffer>(obj); |
339 | if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) { |
340 | return new_buffer; |
341 | } |
342 | } else if (obj->IsInstance<ArrayNode>()) { |
343 | return Downcast<Array<ObjectRef>>(obj).Map(f_mutate_obj); |
344 | } |
345 | return obj; |
346 | }; |
347 | for (const auto& [key, value] : annotations) { |
348 | auto new_value = f_mutate_obj(value); |
349 | if (!new_value.same_as(value)) { |
350 | new_annotations.Set(key, new_value); |
351 | } |
352 | } |
353 | return new_annotations; |
354 | } |
355 | |
356 | Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) { |
357 | if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) { |
358 | return (*it).second; |
359 | } |
360 | return buffer; |
361 | } |
362 | |
363 | IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) { |
364 | bool is_enabled = is_enabled_; |
365 | is_enabled_ = true; |
366 | Var new_var = Downcast<Var>(VisitExpr(iter_var->var)); |
367 | PrimExpr min = VisitExpr(iter_var->dom->min); |
368 | PrimExpr extent = VisitExpr(iter_var->dom->extent); |
369 | is_enabled_ = is_enabled; |
370 | if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) || |
371 | !extent.same_as(iter_var->dom->extent)) { |
372 | IterVar new_iter_var = iter_var; |
373 | IterVarNode* n = new_iter_var.CopyOnWrite(); |
374 | n->var = std::move(new_var); |
375 | n->dom = Range(min, extent); |
376 | return new_iter_var; |
377 | } |
378 | return iter_var; |
379 | } |
380 | |
381 | Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) { |
382 | bool is_enabled = is_enabled_; |
383 | |
384 | is_enabled_ = true; |
385 | Array<PrimExpr> new_shape = |
386 | buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); |
387 | Array<PrimExpr> new_strides = |
388 | buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); |
389 | auto new_elem_offset = VisitExpr(buffer->elem_offset); |
390 | is_enabled_ = is_enabled; |
391 | |
392 | if (!buffer->shape.same_as(new_shape) || !buffer->strides.same_as(new_strides) || |
393 | !buffer->elem_offset.same_as(new_elem_offset)) { |
394 | Buffer new_buffer = buffer; |
395 | BufferNode* new_buffer_node = new_buffer.CopyOnWrite(); |
396 | new_buffer_node->shape = std::move(new_shape); |
397 | new_buffer_node->strides = std::move(new_strides); |
398 | new_buffer_node->elem_offset = std::move(new_elem_offset); |
399 | buffer_remap_.Set(buffer, new_buffer); |
400 | return new_buffer; |
401 | } else { |
402 | return buffer; |
403 | } |
404 | } |
405 | |
406 | BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) { |
407 | Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer); |
408 | |
409 | bool is_enabled = is_enabled_; |
410 | is_enabled_ = true; |
411 | auto new_region = buffer_region->region.Map([&](const Range& range) { |
412 | return Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent)); |
413 | }); |
414 | is_enabled_ = is_enabled; |
415 | |
416 | if (!remapped_buffer.same_as(buffer_region->buffer) || |
417 | !new_region.same_as(buffer_region->region)) { |
418 | return BufferRegion(remapped_buffer, new_region); |
419 | } else { |
420 | return buffer_region; |
421 | } |
422 | } |
423 | |
424 | Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { |
425 | BufferStore store = GetRef<BufferStore>(op); |
426 | |
427 | Buffer new_buffer = GetRemappedBuffer(op->buffer); |
428 | auto value = this->VisitExpr(op->value); |
429 | if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) { |
430 | value = cast(new_buffer->dtype, value); |
431 | } |
432 | auto indices = VisitIndices(op->indices); |
433 | |
434 | if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) || |
435 | !indices.same_as(op->indices)) { |
436 | auto writer = store.CopyOnWrite(); |
437 | writer->buffer = new_buffer; |
438 | writer->value = value; |
439 | writer->indices = indices; |
440 | } |
441 | |
442 | return std::move(store); |
443 | } |
444 | |
445 | PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { |
446 | BufferLoad load = GetRef<BufferLoad>(op); |
447 | |
448 | Buffer new_buffer = GetRemappedBuffer(op->buffer); |
449 | auto indices = VisitIndices(op->indices); |
450 | |
451 | if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) { |
452 | auto writer = load.CopyOnWrite(); |
453 | writer->indices = indices; |
454 | writer->buffer = new_buffer; |
455 | } |
456 | |
457 | return std::move(load); |
458 | } |
459 | |
460 | Array<PrimExpr> IndexDataTypeRewriter::VisitIndices(Array<PrimExpr> indices) { |
461 | bool is_enabled = is_enabled_; |
462 | is_enabled_ = true; |
463 | |
464 | auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; |
465 | indices.MutateByApply(fmutate); |
466 | |
467 | is_enabled_ = is_enabled; |
468 | |
469 | return indices; |
470 | } |
471 | |
472 | Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { |
473 | bool is_condition = is_condition_; |
474 | is_condition_ = true; |
475 | PrimExpr cond = VisitExpr(op->condition); |
476 | is_condition_ = is_condition; |
477 | |
478 | Stmt then_case = VisitStmt(op->then_case); |
479 | Optional<Stmt> else_case = |
480 | op->else_case.defined() ? Optional<Stmt>{VisitStmt(op->else_case.value())} : NullOpt; |
481 | if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) || |
482 | !else_case.same_as(op->else_case)) { |
483 | IfThenElse new_stmt = GetRef<IfThenElse>(op); |
484 | auto* n = new_stmt.CopyOnWrite(); |
485 | n->condition = std::move(cond); |
486 | n->then_case = std::move(then_case); |
487 | n->else_case = std::move(else_case); |
488 | return std::move(new_stmt); |
489 | } |
490 | return GetRef<Stmt>(op); |
491 | } |
492 | |
493 | Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { |
494 | bool is_enabled = is_enabled_; |
495 | is_enabled_ = true; |
496 | Var new_loop_var = Downcast<Var>(VisitExpr(op->loop_var)); |
497 | PrimExpr min = VisitExpr(op->min); |
498 | PrimExpr extent = VisitExpr(op->extent); |
499 | is_enabled_ = is_enabled; |
500 | |
501 | Stmt new_body = VisitStmt(op->body); |
502 | |
503 | if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) || |
504 | !new_body.same_as(op->body)) { |
505 | For new_for = GetRef<For>(op); |
506 | auto* n = new_for.CopyOnWrite(); |
507 | n->loop_var = new_loop_var; |
508 | n->min = cast(new_loop_var.dtype(), min); |
509 | n->extent = cast(new_loop_var.dtype(), extent); |
510 | n->body = new_body; |
511 | return std::move(new_for); |
512 | } else { |
513 | return GetRef<Stmt>(op); |
514 | } |
515 | } |
516 | |
517 | #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ |
518 | PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \ |
519 | bool is_enabled = is_enabled_; \ |
520 | is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \ |
521 | auto result = Parent::VisitExpr_(op); \ |
522 | is_enabled_ = is_enabled; \ |
523 | return std::move(result); \ |
524 | } |
525 | |
526 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); |
527 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); |
528 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); |
529 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) |
530 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) |
531 | TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); |
532 | |
533 | PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) { |
534 | // handle if_then_else condition |
535 | if (op->op.same_as(builtin::if_then_else())) { |
536 | bool is_condition = is_condition_; |
537 | is_condition_ = true; |
538 | PrimExpr cond = VisitExpr(op->args[0]); |
539 | is_condition_ = is_condition; |
540 | return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2])); |
541 | } |
542 | return Parent::VisitExpr_(op); |
543 | } |
544 | |
545 | #undef TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH |
546 | |
547 | IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type) |
548 | : target_data_type_(std::move(target_data_type)) {} |
549 | |
550 | PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) { |
551 | Map<Var, Buffer> new_buffer_map = func->buffer_map; |
552 | for (const auto& [var, buffer] : func->buffer_map) { |
553 | new_buffer_map.Set(var, VisitBuffer(buffer)); |
554 | } |
555 | PrimFuncNode* new_func = func.CopyOnWrite(); |
556 | new_func->buffer_map = std::move(new_buffer_map); |
557 | new_func->body = VisitStmt(std::move(new_func->body)); |
558 | return func; |
559 | } |
560 | |
561 | PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { |
562 | if (is_enabled_) { |
563 | ICHECK_LE(op->value, Downcast<Integer>(max_value(target_data_type_))->value); |
564 | return cast(target_data_type_, GetRef<IntImm>(op)); |
565 | } |
566 | return GetRef<IntImm>(op); |
567 | } |
568 | |
569 | PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { |
570 | if (is_enabled_ && op->dtype != target_data_type_ && !var_remap_.count(op)) { |
571 | var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_); |
572 | } |
573 | return DataTypeLegalizer::VisitExpr_(op); |
574 | } |
575 | |
576 | PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) { |
577 | // Unwrap the cast only when the dtype of this cast is integer dtype. |
578 | // When the dtype of this cast is not integer dtype, it means that this cast |
579 | // has some other purpose, and we should not unwrap the cast. |
580 | if (is_enabled_ && op->dtype.is_int()) { |
581 | PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value); |
582 | return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value); |
583 | } |
584 | return IndexDataTypeRewriter::VisitExpr_(op); |
585 | } |
586 | |
587 | } // namespace tir |
588 | } // namespace tvm |
589 | |