1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./ir_comparator.h"
20
21namespace tvm {
22
23namespace tir {
24
25/******** Tensorize Comparator ********/
26
27class TensorIntrinMismatchError : public ScheduleError {
28 public:
29 explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt,
30 std::vector<std::string> error_messages)
31 : lhs_mod_(std::move(lhs_mod)),
32 lhs_stmt_(std::move(lhs_stmt)),
33 rhs_stmt_(std::move(rhs_stmt)),
34 error_messages_(std::move(error_messages)) {
35 ICHECK(lhs_stmt_->IsInstance<ForNode>() || lhs_stmt_->IsInstance<BlockNode>());
36 }
37
38 String FastErrorString() const final {
39 return "ScheduleError: The stmt doesn't match the tensor intrin.";
40 }
41
42 String DetailRenderTemplate() const final {
43 std::ostringstream os;
44 os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n"
45 << lhs_stmt_ << "\nDoes not match the tensorize description:\n"
46 << rhs_stmt_;
47 for (const auto& msg : error_messages_) {
48 os << msg << std::endl;
49 }
50 return os.str();
51 }
52
53 IRModule mod() const final { return lhs_mod_; }
54
55 Array<ObjectRef> LocationsOfInterest() const final { return {lhs_stmt_}; }
56
57 private:
58 IRModule lhs_mod_;
59 Stmt lhs_stmt_;
60 Stmt rhs_stmt_;
61 std::vector<std::string> error_messages_;
62};
63
64/* Override the dispatcher to make sure RHS is always valid */
65bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
66 bool equal = n.same_as(other) ||
67 ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other));
68 if (!equal && assert_mode_ && (n->IsInstance<ForNode>() || n->IsInstance<BlockNode>())) {
69 throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_));
70 }
71 return equal;
72}
73
74bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
75 bool equal = n.same_as(other) ||
76 ((n->type_index() == other->type_index()) &&
77 n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other));
78 if (!equal && assert_mode_) {
79 std::ostringstream os;
80 os << "Expression mismatch: " << n << " vs " << other;
81 EmitError(os.str());
82 }
83 return equal;
84}
85
86bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
87 const auto* rhs = other.as<ForNode>();
88 if (!DefEqual(op->loop_var, rhs->loop_var)) {
89 if (assert_mode_) {
90 std::ostringstream os;
91 os << "ForNode loop vars do not match: op->loop_var=" << op->loop_var
92 << " vs rhs->loop_var=" << rhs->loop_var;
93 EmitError(os.str());
94 }
95 return false;
96 }
97 if (!VisitExpr(op->min, rhs->min)) {
98 if (assert_mode_) {
99 std::ostringstream os;
100 os << "ForNode min values do not match: op->min=" << op->min << " vs rhs->min=" << rhs->min;
101 EmitError(os.str());
102 }
103 return false;
104 }
105 if (!VisitExpr(op->extent, rhs->extent)) {
106 if (assert_mode_) {
107 std::ostringstream os;
108 os << "ForNode extent values do not match: op->extent=" << op->extent
109 << " vs rhs->extent=" << rhs->extent;
110 EmitError(os.str());
111 }
112 return false;
113 }
114 if (op->thread_binding.defined() != rhs->thread_binding.defined()) {
115 if (assert_mode_) {
116 std::ostringstream os;
117 os << "ForNode thread_bindings do not match: op->thread_binding.defined()="
118 << op->thread_binding.defined()
119 << " vs rhs->thread_binding.defined()=" << rhs->thread_binding.defined();
120 EmitError(os.str());
121 }
122 return false;
123 }
124 if (op->thread_binding.defined() &&
125 !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) {
126 return false;
127 }
128 if (op->kind != rhs->kind) {
129 if (assert_mode_) {
130 std::ostringstream os;
131 os << "ForNode kinds do not match: op->kind=" << op->kind << " vs rhs->kind=" << rhs->kind;
132 EmitError(os.str());
133 }
134 return false;
135 }
136 if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
137 if (assert_mode_) {
138 std::ostringstream os;
139 os << "ForNode annotation maps do not match: op->annotations=" << op->annotations
140 << " vs rhs->annotations=" << rhs->annotations;
141 EmitError(os.str());
142 }
143 return false;
144 }
145 return VisitStmt(op->body, rhs->body);
146}
147
148bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) {
149 const auto* rhs = other.as<SeqStmtNode>();
150 return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt);
151}
152
153bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
154 const auto* rhs = other.as<BufferStoreNode>();
155 return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
156}
157
158bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) {
159 const auto* rhs = other.as<BlockRealizeNode>();
160 if (!is_scope_block) {
161 if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) {
162 if (assert_mode_) {
163 std::ostringstream os;
164 os << "BlockRealizeNode iter_values do not match: op->iter_values=" << op->iter_values
165 << " vs rhs->iter_values=" << rhs->iter_values;
166 EmitError(os.str());
167 }
168 return false;
169 }
170 }
171 return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block);
172}
173
174bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
175 const auto* rhs = other.as<BlockNode>();
176 // Check block equality.
177 // All iter vars and buffer regions including the order should match.
178 // When checking iter vars, DefEqual is used to remap variables.
179 if (!is_scope_block) {
180 if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) {
181 if (assert_mode_) {
182 std::ostringstream os;
183 os << "BlockNode iter_vars do not match: op->alloc_buffers=" << op->iter_vars
184 << " vs rhs->alloc_buffers=" << rhs->iter_vars;
185 EmitError(os.str());
186 }
187 return false;
188 }
189 if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) {
190 if (assert_mode_) {
191 std::ostringstream os;
192 os << "BlockNode alloc_buffers do not match: op->alloc_buffers=" << op->alloc_buffers
193 << " vs rhs->alloc_buffers=" << rhs->alloc_buffers;
194 EmitError(os.str());
195 }
196 return false;
197 }
198 }
199 if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) {
200 if (assert_mode_) {
201 std::ostringstream os;
202 os << "BlockNode write buffers do not match: op->writes=" << op->writes
203 << " vs rhs->writes=" << rhs->writes;
204 EmitError(os.str());
205 }
206 return false;
207 }
208 if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) {
209 if (assert_mode_) {
210 std::ostringstream os;
211 os << "BlockNode read buffers regions do not match: op->reads=" << op->reads
212 << " vs rhs->reads=" << rhs->reads;
213 EmitError(os.str());
214 }
215 return false;
216 }
217 is_scope_block = false;
218 return VisitStmt(op->body, rhs->body);
219}
220
221// Exprs
222#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \
223 bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \
224 const auto* rhs = other.as<OpName>(); \
225 return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
226 }
227
228TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode);
229TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode);
230TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode);
231TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode);
232TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode);
233TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode);
234TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode);
235TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode);
236TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode);
237TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode);
238TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode);
239TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode);
240TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode);
241TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode);
242TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode);
243TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode);
244TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode);
245
246bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
247 const auto* rhs = other.as<IntImmNode>();
248 if (op->value != rhs->value) {
249 if (assert_mode_) {
250 std::ostringstream os;
251 os << "IntImmNode values do not match: op->value=" << op->value
252 << " vs rhs->value=" << rhs->value;
253 EmitError(os.str());
254 }
255 return false;
256 }
257 return true;
258}
259
260bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
261 const auto* rhs = other.as<FloatImmNode>();
262 if (op->value != rhs->value) {
263 if (assert_mode_) {
264 std::ostringstream os;
265 os << "FloatImmNode values do not match: op->value=" << op->value
266 << " vs rhs->value=" << rhs->value;
267 EmitError(os.str());
268 }
269 return false;
270 }
271 return true;
272}
273
274bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) {
275 const auto* rhs = other.as<CastNode>();
276 return VisitExpr(op->value, rhs->value);
277}
278
279bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) {
280 const auto* rhs = other.as<VarNode>();
281 auto lhs = GetRef<Var>(op);
282 if (lhs.same_as(other)) return true;
283 if (op->dtype.code() != rhs->dtype.code()) {
284 if (assert_mode_) {
285 std::ostringstream os;
286 os << "VarNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
287 << " vs rhs->dtype.code()=" << rhs->dtype.code();
288 EmitError(os.str());
289 }
290 return false;
291 }
292 auto it = equal_map_.find(lhs);
293 return it != equal_map_.end() && it->second.same_as(other);
294}
295
296bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
297 const auto* rhs = other.as<BufferLoadNode>();
298 return CompareBufferAccess(op, rhs);
299}
300
301bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) {
302 const auto* rhs = other.as<SelectNode>();
303 return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) &&
304 VisitExpr(op->false_value, rhs->false_value);
305}
306
307bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {
308 if (lhs.same_as(rhs)) return true;
309 auto it = equal_map_.find(lhs);
310 // If there is already a mapping
311 if (it != equal_map_.end()) return it->second.same_as(rhs);
312 // Otherwise remap lhs to rhs
313 equal_map_[lhs] = rhs;
314 // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in
315 // the indices.
316 analyzer_.Bind(lhs, cast(lhs.dtype(), rhs));
317 return true;
318}
319
320bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
321 const std::pair<String, ObjectRef>& rhs) {
322 if (lhs.first != rhs.first) {
323 if (assert_mode_) {
324 std::ostringstream os;
325 os << "CompareAnnotation key mismatch: lhs.first=" << lhs.first
326 << " vs rhs.first=" << rhs.first;
327 EmitError(os.str());
328 }
329 return false;
330 }
331 return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second));
332}
333
334bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>& lhs,
335 const Map<String, ObjectRef>& rhs) {
336 if (lhs.same_as(rhs)) return true;
337 if (lhs.size() != rhs.size()) {
338 if (assert_mode_) {
339 std::ostringstream os;
340 os << "CompareAnnotationMap size mismatch: lhs.size()=" << lhs.size()
341 << " vs rhs.size()=" << rhs.size();
342 EmitError(os.str());
343 }
344 return false;
345 }
346
347 auto sort_map =
348 [](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String, ObjectRef>> {
349 std::vector<std::pair<String, ObjectRef>> ret(map.begin(), map.end());
350 sort(ret.begin(), ret.end());
351 return ret;
352 };
353
354 std::vector<std::pair<String, ObjectRef>> lhs_array = sort_map(lhs);
355 std::vector<std::pair<String, ObjectRef>> rhs_array = sort_map(rhs);
356
357 for (size_t i = 0; i < lhs.size(); ++i) {
358 if (!CompareAnnotation(lhs_array[i], rhs_array[i])) {
359 if (assert_mode_) {
360 std::ostringstream os;
361 os << "CompareAnnotationMap annotations mismatch within AnnotationMap.";
362 EmitError(os.str());
363 }
364 return false;
365 }
366 }
367 return true;
368}
369
370bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
371 if (lhs.same_as(rhs)) return true;
372 auto it = rhs_buffer_map_.find(rhs);
373 bool equal;
374 if (it != rhs_buffer_map_.end()) {
375 equal = (*it).second.same_as(lhs);
376 } else {
377 // Remap both buffer itself and buffer data, skip buffer shape
378 equal =
379 DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope();
380 if (equal) {
381 rhs_buffer_map_[rhs] = lhs;
382 } else {
383 if (assert_mode_) {
384 std::ostringstream os;
385 os << "CompareBuffer buffer mismatch. data: " << lhs->data << " vs " << rhs->data
386 << ", dtypes: " << lhs->dtype << " vs " << rhs->dtype << ", scope(): " << lhs.scope()
387 << " vs " << rhs.scope();
388 EmitError(os.str());
389 }
390 }
391 }
392 return equal;
393}
394
395bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) {
396 if (!CompareBuffer(lhs->buffer, rhs->buffer)) {
397 if (assert_mode_) {
398 std::ostringstream os;
399 os << "CompareBufferRegion returning false due to buffer mismatch: lhs->buffer="
400 << lhs->buffer << " vs rhs->buffer=" << rhs->buffer;
401 EmitError(os.str());
402 }
403 return false;
404 }
405 int offset = static_cast<int>(lhs->region.size()) - static_cast<int>(rhs->region.size());
406 // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS
407 if (offset < 0) {
408 if (assert_mode_) {
409 std::ostringstream os;
410 os << "CompareBufferRegion returning false because buffer region sizes do not match: "
411 "lhs->region.size()="
412 << lhs->region.size() << " vs rhs->region.size()=" << rhs->region.size();
413 EmitError(os.str());
414 }
415 return false;
416 }
417
418 auto it = buffer_indices_.find(lhs->buffer);
419 if (it == buffer_indices_.end()) {
420 // Update base indices for the buffer, this can only happen if it is visiting the scope block.
421 ICHECK(is_scope_block);
422 std::vector<PrimExpr> indices_base;
423 indices_base.reserve(lhs->region.size());
424 for (int i = 0; i < offset; i++) {
425 // High-dim region must be element-wise
426 if (!is_one(lhs->region[i]->extent)) {
427 if (assert_mode_) {
428 std::ostringstream os;
429 os << "CompareBufferRegion returning false because buffer extent high-dim region must be "
430 "element-wise. lhs->region[i]->extent="
431 << lhs->region[i]->extent;
432 EmitError(os.str());
433 }
434 return false;
435 }
436 indices_base.emplace_back(lhs->region[i]->min);
437 }
438 for (size_t i = 0; i < rhs->region.size(); i++) {
439 // save base index
440 indices_base.emplace_back(lhs->region[i + offset]->min);
441 // check extent match
442 if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) {
443 if (assert_mode_) {
444 std::ostringstream os;
445 os << "CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]="
446 << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
447 EmitError(os.str());
448 }
449 return false;
450 }
451 }
452 buffer_indices_.emplace(lhs->buffer, std::move(indices_base));
453 } else {
454 // Check the base indices are consistent.
455 const std::vector<PrimExpr>& indices_base = it->second;
456 for (int i = 0; i < offset; i++) {
457 // High-dim region must be element-wise
458 if (!is_one(lhs->region[i]->extent)) {
459 if (assert_mode_) {
460 std::ostringstream os;
461 os << "CompareBufferRegion returning false because buffer extent high-dim region must be "
462 "element-wise. lhs->region[i]->extent="
463 << lhs->region[i]->extent;
464 EmitError(os.str());
465 }
466 return false;
467 }
468 if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
469 if (assert_mode_) {
470 std::ostringstream os;
471 os << "Buffer base index consistency check failed due to unequal index base: "
472 "indices_base[i]="
473 << indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min;
474 EmitError(os.str());
475 }
476 return false;
477 }
478 }
479 for (size_t i = 0; i < rhs->region.size(); i++) {
480 // check extent match
481 if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) {
482 if (assert_mode_) {
483 std::ostringstream os;
484 os << "CompareBufferRegion buffer region extent mismatch. lhs->region[i + offset]="
485 << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
486 EmitError(os.str());
487 }
488 return false;
489 }
490 PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]);
491 if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
492 if (assert_mode_) {
493 std::ostringstream os;
494 os << "CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]="
495 << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
496 EmitError(os.str());
497 }
498 return false;
499 }
500 }
501 }
502 return true;
503}
504
505// Comparator for BufferStoreNode and BufferLoadNode
506template <typename T>
507bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
508 if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
509 int offset = static_cast<int>(lhs->indices.size()) - static_cast<int>(rhs->indices.size());
510 if (offset < 0) {
511 if (assert_mode_) {
512 std::ostringstream os;
513 os << "CompareBufferAccess returning false because buffer indices sizes do not match: "
514 "lhs->indices.size()="
515 << lhs->indices.size() << " vs rhs->indices.size()=" << rhs->indices.size();
516 EmitError(os.str());
517 }
518 return false;
519 }
520 auto it = buffer_indices_.find(lhs->buffer);
521 ICHECK(it != buffer_indices_.end());
522 const std::vector<PrimExpr>& indices_base = (*it).second;
523 ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset);
524 for (size_t i = 0; i < rhs->indices.size(); i++) {
525 PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset];
526 if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) {
527 if (assert_mode_) {
528 std::ostringstream os;
529 os << "CompareBufferAccess buffer indices mismatch. lhs->indices[i + offset]="
530 << lhs->indices[i + offset] << " vs rhs->indices[i]=" << rhs->indices[i];
531 EmitError(os.str());
532 }
533 return false;
534 }
535 }
536 return true;
537}
538
539template <typename T, typename Self, typename F>
540bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) {
541 if (lhs.same_as(rhs)) return true;
542 if (lhs.size() != rhs.size()) {
543 if (assert_mode_) {
544 std::ostringstream os;
545 os << "CompareArray array size mismatch. lhs.size()=" << lhs.size()
546 << " vs rhs.size()=" << rhs.size();
547 EmitError(os.str());
548 }
549 return false;
550 }
551 for (size_t i = 0; i < lhs.size(); ++i) {
552 if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;
553 }
554 return true;
555}
556
557bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) {
558 return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent);
559}
560
561bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) {
562 return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type;
563}
564
565void TensorizeComparator::EmitError(const std::string& error_message) {
566 error_messages_.push_back(error_message);
567}
568
569/******** AutoTensorize Extractor ********/
570
571bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const PrimExpr& other) {
572 return false;
573}
574
575bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& other) {
576 return false;
577}
578
579bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
580 const auto* rhs = other.as<BlockNode>();
581 // Check block equality.
582 // All iter vars and buffer regions including the order should match.
583 // When checking iter vars, DefEqual is used to remap variables.
584 if (!is_scope_block) {
585 if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeComparator::CompareIterVar)) {
586 return false;
587 }
588 if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
589 return false;
590 }
591 if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
592 &AutoTensorizeComparator::CompareBuffer)) {
593 return false;
594 }
595 for (const IterVar& block_iter : op->iter_vars) {
596 inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom));
597 }
598 } else {
599 auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) -> bool {
600 for (const auto& iter : op->iter_vars) {
601 analyzer_.Bind(iter->var, iter->dom);
602 if (iter->iter_type == IterVarType::kDataPar ||
603 iter->iter_type == IterVarType::kCommReduce) {
604 iters.push_back(iter);
605 } else {
606 return false;
607 }
608 }
609 return true;
610 };
611 if (!collect_iter(op, lhs_iters_)) {
612 return false;
613 }
614 if (!collect_iter(rhs, rhs_iters_)) {
615 return false;
616 }
617 }
618 is_scope_block = false;
619 return VisitStmt(op->body, rhs->body);
620}
621
622bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
623 if (lhs.same_as(rhs)) return true;
624 auto it = rhs_buffer_map_.find(rhs);
625 bool equal;
626 if (it != rhs_buffer_map_.end()) {
627 equal = (*it).second.same_as(lhs);
628 } else {
629 // Remap both buffer itself and buffer data, skip buffer shape and scope
630 equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
631 if (equal) {
632 rhs_buffer_map_[rhs] = lhs;
633 lhs_buffer_map_[lhs] = rhs;
634 }
635 }
636 return equal;
637}
638
639bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
640 const auto* rhs = other.as<BufferStoreNode>();
641 return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
642}
643
644bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
645 const auto* rhs = other.as<BufferLoadNode>();
646 return CompareBufferAccess(op, rhs);
647}
648
649template <typename T>
650bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
651 if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
652 auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer);
653 if (it_lhs == lhs_buffer_indices_map_.end()) {
654 if (rhs_buffer_indices_map_.find(rhs->buffer) != rhs_buffer_indices_map_.end()) {
655 return false;
656 }
657 std::vector<PrimExpr> lhs_indices;
658 for (const PrimExpr& index : lhs->indices) {
659 lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_));
660 }
661
662 auto is_scalar_access = [](const Array<PrimExpr>& indices, PrimExpr index) {
663 // Check if the indexing is of the form C[0]
664 if (indices.size() > 1) return false;
665 auto int_imm = index.template as<IntImmNode>();
666 if (int_imm && int_imm->value == 0) return true;
667 return false;
668 };
669
670 for (const auto& index : rhs->indices) {
671 if (!index.template as<VarNode>() && !is_scalar_access(rhs->indices, index)) return false;
672 }
673 lhs_buffer_indices_map_[lhs->buffer] = lhs_indices;
674 rhs_buffer_indices_map_[rhs->buffer] = rhs->indices;
675 } else {
676 auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer);
677 if (it_rhs == rhs_buffer_indices_map_.end()) {
678 return false;
679 }
680 auto indices_check = [&](const Array<PrimExpr>& indices,
681 const Array<PrimExpr>& old_indices) -> bool {
682 if (indices.size() != old_indices.size()) {
683 return false;
684 }
685 for (size_t i = 0; i < indices.size(); ++i) {
686 if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) {
687 return false;
688 }
689 }
690 return true;
691 };
692 if (!indices_check(lhs->indices, it_lhs->second)) return false;
693 if (!indices_check(rhs->indices, it_rhs->second)) return false;
694 }
695 return true;
696}
697
698} // namespace tir
699} // namespace tvm
700