1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include "./ir_comparator.h"

#include <tvm/node/structural_equal.h>
20 | |
21 | namespace tvm { |
22 | |
23 | namespace tir { |
24 | |
25 | /******** Tensorize Comparator ********/ |
26 | |
27 | class TensorIntrinMismatchError : public ScheduleError { |
28 | public: |
29 | explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt, |
30 | std::vector<std::string> error_messages) |
31 | : lhs_mod_(std::move(lhs_mod)), |
32 | lhs_stmt_(std::move(lhs_stmt)), |
33 | rhs_stmt_(std::move(rhs_stmt)), |
34 | error_messages_(std::move(error_messages)) { |
35 | ICHECK(lhs_stmt_->IsInstance<ForNode>() || lhs_stmt_->IsInstance<BlockNode>()); |
36 | } |
37 | |
38 | String FastErrorString() const final { |
39 | return "ScheduleError: The stmt doesn't match the tensor intrin." ; |
40 | } |
41 | |
42 | String DetailRenderTemplate() const final { |
43 | std::ostringstream os; |
44 | os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n" |
45 | << lhs_stmt_ << "\nDoes not match the tensorize description:\n" |
46 | << rhs_stmt_; |
47 | for (const auto& msg : error_messages_) { |
48 | os << msg << std::endl; |
49 | } |
50 | return os.str(); |
51 | } |
52 | |
53 | IRModule mod() const final { return lhs_mod_; } |
54 | |
55 | Array<ObjectRef> LocationsOfInterest() const final { return {lhs_stmt_}; } |
56 | |
57 | private: |
58 | IRModule lhs_mod_; |
59 | Stmt lhs_stmt_; |
60 | Stmt rhs_stmt_; |
61 | std::vector<std::string> error_messages_; |
62 | }; |
63 | |
64 | /* Override the dispatcher to make sure RHS is always valid */ |
65 | bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { |
66 | bool equal = n.same_as(other) || |
67 | ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other)); |
68 | if (!equal && assert_mode_ && (n->IsInstance<ForNode>() || n->IsInstance<BlockNode>())) { |
69 | throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_)); |
70 | } |
71 | return equal; |
72 | } |
73 | |
74 | bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { |
75 | bool equal = n.same_as(other) || |
76 | ((n->type_index() == other->type_index()) && |
77 | n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)); |
78 | if (!equal && assert_mode_) { |
79 | std::ostringstream os; |
80 | os << "Expression mismatch: " << n << " vs " << other; |
81 | EmitError(os.str()); |
82 | } |
83 | return equal; |
84 | } |
85 | |
86 | bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { |
87 | const auto* rhs = other.as<ForNode>(); |
88 | if (!DefEqual(op->loop_var, rhs->loop_var)) { |
89 | if (assert_mode_) { |
90 | std::ostringstream os; |
91 | os << "ForNode loop vars do not match: op->loop_var=" << op->loop_var |
92 | << " vs rhs->loop_var=" << rhs->loop_var; |
93 | EmitError(os.str()); |
94 | } |
95 | return false; |
96 | } |
97 | if (!VisitExpr(op->min, rhs->min)) { |
98 | if (assert_mode_) { |
99 | std::ostringstream os; |
100 | os << "ForNode min values do not match: op->min=" << op->min << " vs rhs->min=" << rhs->min; |
101 | EmitError(os.str()); |
102 | } |
103 | return false; |
104 | } |
105 | if (!VisitExpr(op->extent, rhs->extent)) { |
106 | if (assert_mode_) { |
107 | std::ostringstream os; |
108 | os << "ForNode extent values do not match: op->extent=" << op->extent |
109 | << " vs rhs->extent=" << rhs->extent; |
110 | EmitError(os.str()); |
111 | } |
112 | return false; |
113 | } |
114 | if (op->thread_binding.defined() != rhs->thread_binding.defined()) { |
115 | if (assert_mode_) { |
116 | std::ostringstream os; |
117 | os << "ForNode thread_bindings do not match: op->thread_binding.defined()=" |
118 | << op->thread_binding.defined() |
119 | << " vs rhs->thread_binding.defined()=" << rhs->thread_binding.defined(); |
120 | EmitError(os.str()); |
121 | } |
122 | return false; |
123 | } |
124 | if (op->thread_binding.defined() && |
125 | !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) { |
126 | return false; |
127 | } |
128 | if (op->kind != rhs->kind) { |
129 | if (assert_mode_) { |
130 | std::ostringstream os; |
131 | os << "ForNode kinds do not match: op->kind=" << op->kind << " vs rhs->kind=" << rhs->kind; |
132 | EmitError(os.str()); |
133 | } |
134 | return false; |
135 | } |
136 | if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { |
137 | if (assert_mode_) { |
138 | std::ostringstream os; |
139 | os << "ForNode annotation maps do not match: op->annotations=" << op->annotations |
140 | << " vs rhs->annotations=" << rhs->annotations; |
141 | EmitError(os.str()); |
142 | } |
143 | return false; |
144 | } |
145 | return VisitStmt(op->body, rhs->body); |
146 | } |
147 | |
148 | bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { |
149 | const auto* rhs = other.as<SeqStmtNode>(); |
150 | return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); |
151 | } |
152 | |
153 | bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { |
154 | const auto* rhs = other.as<BufferStoreNode>(); |
155 | return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); |
156 | } |
157 | |
158 | bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { |
159 | const auto* rhs = other.as<BlockRealizeNode>(); |
160 | if (!is_scope_block) { |
161 | if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { |
162 | if (assert_mode_) { |
163 | std::ostringstream os; |
164 | os << "BlockRealizeNode iter_values do not match: op->iter_values=" << op->iter_values |
165 | << " vs rhs->iter_values=" << rhs->iter_values; |
166 | EmitError(os.str()); |
167 | } |
168 | return false; |
169 | } |
170 | } |
171 | return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); |
172 | } |
173 | |
174 | bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { |
175 | const auto* rhs = other.as<BlockNode>(); |
176 | // Check block equality. |
177 | // All iter vars and buffer regions including the order should match. |
178 | // When checking iter vars, DefEqual is used to remap variables. |
179 | if (!is_scope_block) { |
180 | if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) { |
181 | if (assert_mode_) { |
182 | std::ostringstream os; |
183 | os << "BlockNode iter_vars do not match: op->alloc_buffers=" << op->iter_vars |
184 | << " vs rhs->alloc_buffers=" << rhs->iter_vars; |
185 | EmitError(os.str()); |
186 | } |
187 | return false; |
188 | } |
189 | if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { |
190 | if (assert_mode_) { |
191 | std::ostringstream os; |
192 | os << "BlockNode alloc_buffers do not match: op->alloc_buffers=" << op->alloc_buffers |
193 | << " vs rhs->alloc_buffers=" << rhs->alloc_buffers; |
194 | EmitError(os.str()); |
195 | } |
196 | return false; |
197 | } |
198 | } |
199 | if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { |
200 | if (assert_mode_) { |
201 | std::ostringstream os; |
202 | os << "BlockNode write buffers do not match: op->writes=" << op->writes |
203 | << " vs rhs->writes=" << rhs->writes; |
204 | EmitError(os.str()); |
205 | } |
206 | return false; |
207 | } |
208 | if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { |
209 | if (assert_mode_) { |
210 | std::ostringstream os; |
211 | os << "BlockNode read buffers regions do not match: op->reads=" << op->reads |
212 | << " vs rhs->reads=" << rhs->reads; |
213 | EmitError(os.str()); |
214 | } |
215 | return false; |
216 | } |
217 | is_scope_block = false; |
218 | return VisitStmt(op->body, rhs->body); |
219 | } |
220 | |
221 | // Exprs |
222 | #define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ |
223 | bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ |
224 | const auto* rhs = other.as<OpName>(); \ |
225 | return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ |
226 | } |
227 | |
228 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); |
229 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); |
230 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); |
231 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); |
232 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); |
233 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); |
234 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); |
235 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); |
236 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); |
237 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); |
238 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); |
239 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); |
240 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); |
241 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); |
242 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); |
243 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); |
244 | TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); |
245 | |
246 | bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { |
247 | const auto* rhs = other.as<IntImmNode>(); |
248 | if (op->value != rhs->value) { |
249 | if (assert_mode_) { |
250 | std::ostringstream os; |
251 | os << "IntImmNode values do not match: op->value=" << op->value |
252 | << " vs rhs->value=" << rhs->value; |
253 | EmitError(os.str()); |
254 | } |
255 | return false; |
256 | } |
257 | return true; |
258 | } |
259 | |
260 | bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { |
261 | const auto* rhs = other.as<FloatImmNode>(); |
262 | if (op->value != rhs->value) { |
263 | if (assert_mode_) { |
264 | std::ostringstream os; |
265 | os << "FloatImmNode values do not match: op->value=" << op->value |
266 | << " vs rhs->value=" << rhs->value; |
267 | EmitError(os.str()); |
268 | } |
269 | return false; |
270 | } |
271 | return true; |
272 | } |
273 | |
274 | bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { |
275 | const auto* rhs = other.as<CastNode>(); |
276 | return VisitExpr(op->value, rhs->value); |
277 | } |
278 | |
279 | bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { |
280 | const auto* rhs = other.as<VarNode>(); |
281 | auto lhs = GetRef<Var>(op); |
282 | if (lhs.same_as(other)) return true; |
283 | if (op->dtype.code() != rhs->dtype.code()) { |
284 | if (assert_mode_) { |
285 | std::ostringstream os; |
286 | os << "VarNode data type codes do not match: op->dtype.code()=" << op->dtype.code() |
287 | << " vs rhs->dtype.code()=" << rhs->dtype.code(); |
288 | EmitError(os.str()); |
289 | } |
290 | return false; |
291 | } |
292 | auto it = equal_map_.find(lhs); |
293 | return it != equal_map_.end() && it->second.same_as(other); |
294 | } |
295 | |
296 | bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { |
297 | const auto* rhs = other.as<BufferLoadNode>(); |
298 | return CompareBufferAccess(op, rhs); |
299 | } |
300 | |
301 | bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) { |
302 | const auto* rhs = other.as<SelectNode>(); |
303 | return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) && |
304 | VisitExpr(op->false_value, rhs->false_value); |
305 | } |
306 | |
307 | bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { |
308 | if (lhs.same_as(rhs)) return true; |
309 | auto it = equal_map_.find(lhs); |
310 | // If there is already a mapping |
311 | if (it != equal_map_.end()) return it->second.same_as(rhs); |
312 | // Otherwise remap lhs to rhs |
313 | equal_map_[lhs] = rhs; |
314 | // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in |
315 | // the indices. |
316 | analyzer_.Bind(lhs, cast(lhs.dtype(), rhs)); |
317 | return true; |
318 | } |
319 | |
320 | bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs, |
321 | const std::pair<String, ObjectRef>& rhs) { |
322 | if (lhs.first != rhs.first) { |
323 | if (assert_mode_) { |
324 | std::ostringstream os; |
325 | os << "CompareAnnotation key mismatch: lhs.first=" << lhs.first |
326 | << " vs rhs.first=" << rhs.first; |
327 | EmitError(os.str()); |
328 | } |
329 | return false; |
330 | } |
331 | return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second)); |
332 | } |
333 | |
334 | bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>& lhs, |
335 | const Map<String, ObjectRef>& rhs) { |
336 | if (lhs.same_as(rhs)) return true; |
337 | if (lhs.size() != rhs.size()) { |
338 | if (assert_mode_) { |
339 | std::ostringstream os; |
340 | os << "CompareAnnotationMap size mismatch: lhs.size()=" << lhs.size() |
341 | << " vs rhs.size()=" << rhs.size(); |
342 | EmitError(os.str()); |
343 | } |
344 | return false; |
345 | } |
346 | |
347 | auto sort_map = |
348 | [](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String, ObjectRef>> { |
349 | std::vector<std::pair<String, ObjectRef>> ret(map.begin(), map.end()); |
350 | sort(ret.begin(), ret.end()); |
351 | return ret; |
352 | }; |
353 | |
354 | std::vector<std::pair<String, ObjectRef>> lhs_array = sort_map(lhs); |
355 | std::vector<std::pair<String, ObjectRef>> rhs_array = sort_map(rhs); |
356 | |
357 | for (size_t i = 0; i < lhs.size(); ++i) { |
358 | if (!CompareAnnotation(lhs_array[i], rhs_array[i])) { |
359 | if (assert_mode_) { |
360 | std::ostringstream os; |
361 | os << "CompareAnnotationMap annotations mismatch within AnnotationMap." ; |
362 | EmitError(os.str()); |
363 | } |
364 | return false; |
365 | } |
366 | } |
367 | return true; |
368 | } |
369 | |
370 | bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { |
371 | if (lhs.same_as(rhs)) return true; |
372 | auto it = rhs_buffer_map_.find(rhs); |
373 | bool equal; |
374 | if (it != rhs_buffer_map_.end()) { |
375 | equal = (*it).second.same_as(lhs); |
376 | } else { |
377 | // Remap both buffer itself and buffer data, skip buffer shape |
378 | equal = |
379 | DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); |
380 | if (equal) { |
381 | rhs_buffer_map_[rhs] = lhs; |
382 | } else { |
383 | if (assert_mode_) { |
384 | std::ostringstream os; |
385 | os << "CompareBuffer buffer mismatch. data: " << lhs->data << " vs " << rhs->data |
386 | << ", dtypes: " << lhs->dtype << " vs " << rhs->dtype << ", scope(): " << lhs.scope() |
387 | << " vs " << rhs.scope(); |
388 | EmitError(os.str()); |
389 | } |
390 | } |
391 | } |
392 | return equal; |
393 | } |
394 | |
395 | bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { |
396 | if (!CompareBuffer(lhs->buffer, rhs->buffer)) { |
397 | if (assert_mode_) { |
398 | std::ostringstream os; |
399 | os << "CompareBufferRegion returning false due to buffer mismatch: lhs->buffer=" |
400 | << lhs->buffer << " vs rhs->buffer=" << rhs->buffer; |
401 | EmitError(os.str()); |
402 | } |
403 | return false; |
404 | } |
405 | int offset = static_cast<int>(lhs->region.size()) - static_cast<int>(rhs->region.size()); |
406 | // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS |
407 | if (offset < 0) { |
408 | if (assert_mode_) { |
409 | std::ostringstream os; |
410 | os << "CompareBufferRegion returning false because buffer region sizes do not match: " |
411 | "lhs->region.size()=" |
412 | << lhs->region.size() << " vs rhs->region.size()=" << rhs->region.size(); |
413 | EmitError(os.str()); |
414 | } |
415 | return false; |
416 | } |
417 | |
418 | auto it = buffer_indices_.find(lhs->buffer); |
419 | if (it == buffer_indices_.end()) { |
420 | // Update base indices for the buffer, this can only happen if it is visiting the scope block. |
421 | ICHECK(is_scope_block); |
422 | std::vector<PrimExpr> indices_base; |
423 | indices_base.reserve(lhs->region.size()); |
424 | for (int i = 0; i < offset; i++) { |
425 | // High-dim region must be element-wise |
426 | if (!is_one(lhs->region[i]->extent)) { |
427 | if (assert_mode_) { |
428 | std::ostringstream os; |
429 | os << "CompareBufferRegion returning false because buffer extent high-dim region must be " |
430 | "element-wise. lhs->region[i]->extent=" |
431 | << lhs->region[i]->extent; |
432 | EmitError(os.str()); |
433 | } |
434 | return false; |
435 | } |
436 | indices_base.emplace_back(lhs->region[i]->min); |
437 | } |
438 | for (size_t i = 0; i < rhs->region.size(); i++) { |
439 | // save base index |
440 | indices_base.emplace_back(lhs->region[i + offset]->min); |
441 | // check extent match |
442 | if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { |
443 | if (assert_mode_) { |
444 | std::ostringstream os; |
445 | os << "CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=" |
446 | << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; |
447 | EmitError(os.str()); |
448 | } |
449 | return false; |
450 | } |
451 | } |
452 | buffer_indices_.emplace(lhs->buffer, std::move(indices_base)); |
453 | } else { |
454 | // Check the base indices are consistent. |
455 | const std::vector<PrimExpr>& indices_base = it->second; |
456 | for (int i = 0; i < offset; i++) { |
457 | // High-dim region must be element-wise |
458 | if (!is_one(lhs->region[i]->extent)) { |
459 | if (assert_mode_) { |
460 | std::ostringstream os; |
461 | os << "CompareBufferRegion returning false because buffer extent high-dim region must be " |
462 | "element-wise. lhs->region[i]->extent=" |
463 | << lhs->region[i]->extent; |
464 | EmitError(os.str()); |
465 | } |
466 | return false; |
467 | } |
468 | if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { |
469 | if (assert_mode_) { |
470 | std::ostringstream os; |
471 | os << "Buffer base index consistency check failed due to unequal index base: " |
472 | "indices_base[i]=" |
473 | << indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min; |
474 | EmitError(os.str()); |
475 | } |
476 | return false; |
477 | } |
478 | } |
479 | for (size_t i = 0; i < rhs->region.size(); i++) { |
480 | // check extent match |
481 | if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { |
482 | if (assert_mode_) { |
483 | std::ostringstream os; |
484 | os << "CompareBufferRegion buffer region extent mismatch. lhs->region[i + offset]=" |
485 | << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; |
486 | EmitError(os.str()); |
487 | } |
488 | return false; |
489 | } |
490 | PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); |
491 | if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { |
492 | if (assert_mode_) { |
493 | std::ostringstream os; |
494 | os << "CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=" |
495 | << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; |
496 | EmitError(os.str()); |
497 | } |
498 | return false; |
499 | } |
500 | } |
501 | } |
502 | return true; |
503 | } |
504 | |
505 | // Comparator for BufferStoreNode and BufferLoadNode |
506 | template <typename T> |
507 | bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { |
508 | if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; |
509 | int offset = static_cast<int>(lhs->indices.size()) - static_cast<int>(rhs->indices.size()); |
510 | if (offset < 0) { |
511 | if (assert_mode_) { |
512 | std::ostringstream os; |
513 | os << "CompareBufferAccess returning false because buffer indices sizes do not match: " |
514 | "lhs->indices.size()=" |
515 | << lhs->indices.size() << " vs rhs->indices.size()=" << rhs->indices.size(); |
516 | EmitError(os.str()); |
517 | } |
518 | return false; |
519 | } |
520 | auto it = buffer_indices_.find(lhs->buffer); |
521 | ICHECK(it != buffer_indices_.end()); |
522 | const std::vector<PrimExpr>& indices_base = (*it).second; |
523 | ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); |
524 | for (size_t i = 0; i < rhs->indices.size(); i++) { |
525 | PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; |
526 | if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { |
527 | if (assert_mode_) { |
528 | std::ostringstream os; |
529 | os << "CompareBufferAccess buffer indices mismatch. lhs->indices[i + offset]=" |
530 | << lhs->indices[i + offset] << " vs rhs->indices[i]=" << rhs->indices[i]; |
531 | EmitError(os.str()); |
532 | } |
533 | return false; |
534 | } |
535 | } |
536 | return true; |
537 | } |
538 | |
539 | template <typename T, typename Self, typename F> |
540 | bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) { |
541 | if (lhs.same_as(rhs)) return true; |
542 | if (lhs.size() != rhs.size()) { |
543 | if (assert_mode_) { |
544 | std::ostringstream os; |
545 | os << "CompareArray array size mismatch. lhs.size()=" << lhs.size() |
546 | << " vs rhs.size()=" << rhs.size(); |
547 | EmitError(os.str()); |
548 | } |
549 | return false; |
550 | } |
551 | for (size_t i = 0; i < lhs.size(); ++i) { |
552 | if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false; |
553 | } |
554 | return true; |
555 | } |
556 | |
557 | bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { |
558 | return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); |
559 | } |
560 | |
561 | bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) { |
562 | return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type; |
563 | } |
564 | |
565 | void TensorizeComparator::EmitError(const std::string& error_message) { |
566 | error_messages_.push_back(error_message); |
567 | } |
568 | |
569 | /******** AutoTensorize Extractor ********/ |
570 | |
571 | bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const PrimExpr& other) { |
572 | return false; |
573 | } |
574 | |
575 | bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& other) { |
576 | return false; |
577 | } |
578 | |
579 | bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { |
580 | const auto* rhs = other.as<BlockNode>(); |
581 | // Check block equality. |
582 | // All iter vars and buffer regions including the order should match. |
583 | // When checking iter vars, DefEqual is used to remap variables. |
584 | if (!is_scope_block) { |
585 | if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeComparator::CompareIterVar)) { |
586 | return false; |
587 | } |
588 | if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { |
589 | return false; |
590 | } |
591 | if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, |
592 | &AutoTensorizeComparator::CompareBuffer)) { |
593 | return false; |
594 | } |
595 | for (const IterVar& block_iter : op->iter_vars) { |
596 | inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom)); |
597 | } |
598 | } else { |
599 | auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) -> bool { |
600 | for (const auto& iter : op->iter_vars) { |
601 | analyzer_.Bind(iter->var, iter->dom); |
602 | if (iter->iter_type == IterVarType::kDataPar || |
603 | iter->iter_type == IterVarType::kCommReduce) { |
604 | iters.push_back(iter); |
605 | } else { |
606 | return false; |
607 | } |
608 | } |
609 | return true; |
610 | }; |
611 | if (!collect_iter(op, lhs_iters_)) { |
612 | return false; |
613 | } |
614 | if (!collect_iter(rhs, rhs_iters_)) { |
615 | return false; |
616 | } |
617 | } |
618 | is_scope_block = false; |
619 | return VisitStmt(op->body, rhs->body); |
620 | } |
621 | |
622 | bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { |
623 | if (lhs.same_as(rhs)) return true; |
624 | auto it = rhs_buffer_map_.find(rhs); |
625 | bool equal; |
626 | if (it != rhs_buffer_map_.end()) { |
627 | equal = (*it).second.same_as(lhs); |
628 | } else { |
629 | // Remap both buffer itself and buffer data, skip buffer shape and scope |
630 | equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype; |
631 | if (equal) { |
632 | rhs_buffer_map_[rhs] = lhs; |
633 | lhs_buffer_map_[lhs] = rhs; |
634 | } |
635 | } |
636 | return equal; |
637 | } |
638 | |
639 | bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { |
640 | const auto* rhs = other.as<BufferStoreNode>(); |
641 | return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); |
642 | } |
643 | |
644 | bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { |
645 | const auto* rhs = other.as<BufferLoadNode>(); |
646 | return CompareBufferAccess(op, rhs); |
647 | } |
648 | |
649 | template <typename T> |
650 | bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { |
651 | if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; |
652 | auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer); |
653 | if (it_lhs == lhs_buffer_indices_map_.end()) { |
654 | if (rhs_buffer_indices_map_.find(rhs->buffer) != rhs_buffer_indices_map_.end()) { |
655 | return false; |
656 | } |
657 | std::vector<PrimExpr> lhs_indices; |
658 | for (const PrimExpr& index : lhs->indices) { |
659 | lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_)); |
660 | } |
661 | |
662 | auto is_scalar_access = [](const Array<PrimExpr>& indices, PrimExpr index) { |
663 | // Check if the indexing is of the form C[0] |
664 | if (indices.size() > 1) return false; |
665 | auto int_imm = index.template as<IntImmNode>(); |
666 | if (int_imm && int_imm->value == 0) return true; |
667 | return false; |
668 | }; |
669 | |
670 | for (const auto& index : rhs->indices) { |
671 | if (!index.template as<VarNode>() && !is_scalar_access(rhs->indices, index)) return false; |
672 | } |
673 | lhs_buffer_indices_map_[lhs->buffer] = lhs_indices; |
674 | rhs_buffer_indices_map_[rhs->buffer] = rhs->indices; |
675 | } else { |
676 | auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer); |
677 | if (it_rhs == rhs_buffer_indices_map_.end()) { |
678 | return false; |
679 | } |
680 | auto indices_check = [&](const Array<PrimExpr>& indices, |
681 | const Array<PrimExpr>& old_indices) -> bool { |
682 | if (indices.size() != old_indices.size()) { |
683 | return false; |
684 | } |
685 | for (size_t i = 0; i < indices.size(); ++i) { |
686 | if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) { |
687 | return false; |
688 | } |
689 | } |
690 | return true; |
691 | }; |
692 | if (!indices_check(lhs->indices, it_lhs->second)) return false; |
693 | if (!indices_check(rhs->indices, it_rhs->second)) return false; |
694 | } |
695 | return true; |
696 | } |
697 | |
698 | } // namespace tir |
699 | } // namespace tvm |
700 | |