1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file message_passing.cc
22 * \brief The message passing domain.
23 */
24#include "message_passing.h"
25
26#include <tvm/arith/analyzer.h>
27#include <tvm/tir/expr.h>
28
29namespace tvm {
30namespace te {
31
32using namespace tir;
33
34void Update(std::unordered_map<IterVar, Range>* p_state, const IterVar& iv, Range r,
35 arith::Analyzer* analyzer) {
36 auto it = p_state->find(iv);
37 if (it == p_state->end()) {
38 (*p_state)[iv] = r;
39 analyzer->Bind(iv->var, r);
40 } else {
41 bool match =
42 is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0);
43 ICHECK(match) << iv << " domain already inferred,"
44 << " cannot prove their extents are the same " << it->second->extent << " vs "
45 << r->extent;
46 }
47}
48
49/*!
50 * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
51 * a thread.
52 *
53 * \param stage The stage to operate on.
54 * \param p_state The propagation result of each IterVar.
55 */
56void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
57 auto bound_to_thread = [&stage](const IterVar& iv) {
58 bool bound = false;
59 auto it = stage->iter_var_attrs.find(iv);
60 if (it != stage->iter_var_attrs.end()) {
61 bound = (*it).second->bind_thread.defined();
62 }
63 return bound;
64 };
65
66 auto& state = *p_state;
67 // Fill p_state with leaf itervars
68 for (const IterVar& iv : stage->leaf_iter_vars) {
69 state[iv] = bound_to_thread(iv);
70 }
71 // Traverse the graph bottom-up to propagate thread binding information
72 for (size_t i = stage->relations.size(); i != 0; --i) {
73 IterVarRelation rel = stage->relations[i - 1];
74 if (const SplitNode* s = rel.as<SplitNode>()) {
75 state[s->parent] = state[s->inner] || state[s->outer];
76 } else if (const FuseNode* s = rel.as<FuseNode>()) {
77 state[s->inner] = state[s->fused];
78 state[s->outer] = state[s->fused];
79 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
80 state[s->parent] = state[s->rebased];
81 } else if (rel.as<SingletonNode>()) {
82 } else if (const TransformNode* s = rel.as<TransformNode>()) {
83 // Currently, this marks all original iter vars as deriving from
84 // a thread bind if any of the transformed variables are bound,
85 // even if the inverse expression for that iter var doesn't
86 // depend on the bound variable.
87
88 // TODO(Lunderberg): For each of original variable, check
89 // whether any variable in the inverse expression for it has a
90 // thread binding.
91 bool is_thread_binding = false;
92 for (const auto& iter_var : s->transformed_variables) {
93 is_thread_binding = is_thread_binding || state[iter_var];
94 }
95 for (const auto& iter_var : s->original_variables) {
96 state[iter_var] = is_thread_binding;
97 }
98 } else {
99 LOG(FATAL) << "unknown relation type";
100 }
101 }
102}
103
104void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
105 arith::Analyzer* actx, bool allow_missing) {
106 auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
107 if (actx->CanProve(indexmod(a, b) == 0)) {
108 return actx->Simplify(indexdiv(a, b));
109 }
110 return actx->Simplify(indexdiv(a + (b - 1), b));
111 };
112
113 auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
114 if (actx->CanProve(a < b)) {
115 return actx->Simplify(a);
116 }
117 return actx->Simplify(b);
118 };
119
120 std::unordered_map<IterVar, bool> dominating_thread;
121 PassUpThreadBinding(stage, &dominating_thread);
122
123 auto& state = *p_state;
124 // forwar iteration on relations
125 for (IterVarRelation rel : stage->relations) {
126 if (const SplitNode* r = rel.as<SplitNode>()) {
127 if (!state.count(r->parent)) {
128 ICHECK(allow_missing);
129 continue;
130 }
131 ICHECK(!state.count(r->inner));
132 const Range& range_parent = state.at(r->parent);
133 // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
134 // following conditions are met:
135 // 1. No leaf IterVar derived from iv binds to any thread. People may use split
136 // to force an IterVar extent to match the number of allocated threads to fuse stages
137 // that require different number of threads. We don't want to change these extents.
138 // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
139 // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an
140 // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
141 // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one
142 // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent
143 // IterVar. We don't touch it.
144 auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
145 return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
146 ? factor_or_nparts
147 : minimum_or_later(range_parent->extent, factor_or_nparts);
148 };
149 if (r->factor.defined()) {
150 Update(p_state, r->inner,
151 Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
152 resolve_min_extent_for_split(r->inner, r->factor))),
153 actx);
154 Update(p_state, r->outer,
155 Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
156 } else {
157 Update(p_state, r->outer,
158 Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
159 resolve_min_extent_for_split(r->outer, r->nparts))),
160 actx);
161 Update(p_state, r->inner,
162 Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
163 }
164 } else if (const FuseNode* r = rel.as<FuseNode>()) {
165 if (!state.count(r->outer) || !state.count(r->inner)) {
166 ICHECK(allow_missing);
167 continue;
168 }
169 const Range& range_outer = state.at(r->outer);
170 const Range& range_inner = state.at(r->inner);
171 state[r->fused] = Range::FromMinExtent(0, range_outer->extent * range_inner->extent);
172 } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
173 if (!state.count(r->parent)) {
174 ICHECK(allow_missing);
175 continue;
176 }
177 Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
178 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
179 Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
180 } else if (const TransformNode* s = rel.as<TransformNode>()) {
181 bool missing_originals = false;
182 for (const auto& iter_var : s->original_variables) {
183 if (!state.count(iter_var)) {
184 ICHECK(allow_missing);
185 missing_originals = true;
186 }
187 }
188 if (missing_originals) {
189 continue;
190 }
191
192 Array<Range> original_ranges;
193 for (const auto& iter_var : s->original_variables) {
194 original_ranges.push_back(state[iter_var]);
195 }
196 Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);
197
198 ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
199 for (size_t i = 0; i < updated_ranges.size(); i++) {
200 Update(p_state, s->transformed_variables[i], updated_ranges[i], actx);
201 }
202
203 } else {
204 LOG(FATAL) << "unknown relation type";
205 }
206 }
207 // update the extents of binded threads.
208 for (auto kv : stage->iter_var_attrs) {
209 if (kv.second->bind_thread.defined()) {
210 ICHECK(state.count(kv.first));
211 Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
212 }
213 }
214}
215
216void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
217 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
218 auto& state = *p_state;
219 for (size_t i = stage->relations.size(); i != 0; --i) {
220 IterVarRelation rel = stage->relations[i - 1];
221 if (const SplitNode* s = rel.as<SplitNode>()) {
222 if (!state.count(s->outer) || !state.count(s->inner)) {
223 ICHECK(allow_missing);
224 continue;
225 }
226 PrimExpr outer = state.at(s->outer);
227 PrimExpr inner = state.at(s->inner);
228 PrimExpr factor = dom_map.at(s->inner)->extent;
229 PrimExpr parent_min = dom_map.at(s->parent)->min;
230 state[s->parent] = inner + outer * factor;
231 // add min if they exist
232 if (!is_zero(parent_min)) {
233 state[s->parent] = state[s->parent] + parent_min;
234 }
235 } else if (const FuseNode* s = rel.as<FuseNode>()) {
236 if (!state.count(s->fused)) {
237 ICHECK(allow_missing);
238 continue;
239 }
240 PrimExpr value = state.at(s->fused);
241 PrimExpr factor = dom_map.at(s->inner)->extent;
242 PrimExpr outer_min = dom_map.at(s->outer)->min;
243 PrimExpr inner_min = dom_map.at(s->inner)->min;
244 state[s->outer] = indexdiv(value, factor);
245 state[s->inner] = indexmod(value, factor);
246 // add min if they exist
247 if (!is_zero(outer_min)) {
248 state[s->outer] = state[s->outer] + outer_min;
249 }
250 if (!is_zero(inner_min)) {
251 state[s->inner] = state[s->inner] + inner_min;
252 }
253 // s->fused, s->outer and s->inner may be of different dtype,
254 // so we cast the `state` back to its original dtype
255 state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]);
256 state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]);
257 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
258 if (!state.count(s->rebased)) {
259 ICHECK(allow_missing);
260 continue;
261 }
262 PrimExpr value = state.at(s->rebased);
263 PrimExpr parent_min = dom_map.at(s->parent)->min;
264 // add min if they exist
265 if (!is_zero(parent_min)) {
266 state[s->parent] = value + parent_min;
267 } else {
268 state[s->parent] = value;
269 }
270 } else if (rel.as<SingletonNode>()) {
271 } else if (const TransformNode* s = rel.as<TransformNode>()) {
272 bool missing_transformed = false;
273 for (const auto& iter_var : s->transformed_variables) {
274 if (!state.count(iter_var)) {
275 ICHECK(allow_missing);
276 missing_transformed = true;
277 }
278 }
279 if (missing_transformed) {
280 continue;
281 }
282
283 Array<PrimExpr> transformed_indices;
284 for (const auto& iter_var : s->transformed_variables) {
285 transformed_indices.push_back(state[iter_var]);
286 }
287 Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices);
288
289 ICHECK_EQ(original_indices.size(), s->original_variables.size());
290 for (size_t i = 0; i < original_indices.size(); i++) {
291 state[s->original_variables[i]] = original_indices[i];
292 }
293
294 } else {
295 LOG(FATAL) << "unknown relation type";
296 }
297 }
298}
299
300void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
301 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
302 auto& state = *p_state;
303 for (IterVarRelation rel : stage->relations) {
304 if (const SplitNode* s = rel.as<SplitNode>()) {
305 if (!state.count(s->parent)) {
306 ICHECK(allow_missing);
307 continue;
308 }
309 Range r = dom_map.at(s->inner);
310 ICHECK(is_zero(r->min));
311 PrimExpr parent = state.at(s->parent);
312 PrimExpr factor = r->extent;
313 state[s->outer] = indexdiv(parent, factor);
314 state[s->inner] = indexmod(parent, factor);
315 } else if (const FuseNode* s = rel.as<FuseNode>()) {
316 if (!state.count(s->inner) && !state.count(s->outer)) {
317 ICHECK(allow_missing);
318 continue;
319 }
320 PrimExpr factor = dom_map.at(s->inner)->extent;
321 PrimExpr outer_min = dom_map.at(s->outer)->min;
322 PrimExpr inner_min = dom_map.at(s->inner)->min;
323 PrimExpr inner = state.at(s->inner);
324 PrimExpr outer = state.at(s->outer);
325 ICHECK(is_zero(outer_min));
326 ICHECK(is_zero(inner_min));
327 state[s->fused] = outer * factor + inner;
328 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
329 if (!state.count(s->rebased)) {
330 ICHECK(allow_missing);
331 continue;
332 }
333 PrimExpr value = state.at(s->parent);
334 PrimExpr parent_min = dom_map.at(s->parent)->min;
335 ICHECK(is_zero(parent_min));
336 state[s->rebased] = value;
337 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
338 state[s->iter] = make_zero(s->iter->var.dtype());
339 } else if (const TransformNode* s = rel.as<TransformNode>()) {
340 bool missing_originals = false;
341 for (const auto& iter_var : s->original_variables) {
342 if (!state.count(iter_var)) {
343 ICHECK(allow_missing);
344 missing_originals = true;
345 }
346 }
347 if (missing_originals) {
348 continue;
349 }
350
351 Array<PrimExpr> original_indices;
352 for (const auto& iter_var : s->original_variables) {
353 original_indices.push_back(state[iter_var]);
354 }
355 Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices);
356
357 ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
358 for (size_t i = 0; i < transformed_indices.size(); i++) {
359 state[s->transformed_variables[i]] = transformed_indices[i];
360 }
361 } else {
362 LOG(FATAL) << "unknown relation type";
363 }
364 }
365}
366
367// Domain message passing.
368void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map,
369 const IntSet& outer, const IntSet& inner, IntSet* parent) {
370 if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) &&
371 outer.MatchRange(dom_map.at(s->outer)) && inner.MatchRange(dom_map.at(s->inner))) {
372 *parent = IntSet::FromRange(dom_map.at(s->parent));
373 return;
374 }
375 PrimExpr factor = dom_map.at(s->inner)->extent;
376 PrimExpr parent_min = dom_map.at(s->parent)->min;
377 ICHECK(outer.defined());
378 ICHECK(inner.defined());
379 ICHECK(factor.defined());
380 *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min,
381 {{s->outer, outer}, {s->inner, inner}});
382}
383
384void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
385 const IntSet& fused, IntSet* outer, IntSet* inner) {
386 ICHECK(dom_map.count(s->outer));
387 ICHECK(dom_map.count(s->inner));
388 ICHECK(dom_map.count(s->fused));
389 arith::Analyzer ana;
390
391 if (fused.MatchRange(dom_map.at(s->fused))) {
392 *outer = IntSet::FromRange(dom_map.at(s->outer));
393 *inner = IntSet::FromRange(dom_map.at(s->inner));
394 return;
395 }
396 PrimExpr outer_min = dom_map.at(s->outer)->min;
397 PrimExpr inner_min = dom_map.at(s->inner)->min;
398
399 if (fused.IsSinglePoint()) {
400 PrimExpr value = fused.PointValue();
401 PrimExpr factor = dom_map.at(s->inner)->extent;
402 PrimExpr v_outer = indexdiv(value, factor);
403 PrimExpr v_inner = indexmod(value, factor);
404 if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
405 if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
406 *outer = IntSet::SinglePoint(v_outer);
407 *inner = IntSet::SinglePoint(v_inner);
408 } else {
409 PrimExpr fused_extent = (fused.max() - fused.min() + 1);
410 PrimExpr inner_extent = dom_map.at(s->inner)->extent;
411 *outer = IntSet::Interval(outer_min + indexdiv(fused.min(), inner_extent),
412 outer_min + indexdiv(fused.max(), inner_extent));
413 if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
414 is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
415 // fused never spans multiple rows, make a tight bounding box
416 // there may be other cases when bounding box could be tightened
417 *inner = IntSet::Interval(inner_min + indexmod(fused.min(), inner_extent),
418 inner_min + indexmod(fused.max(), inner_extent));
419 } else { // fused may span multiple rows, use full row widths
420 if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
421 !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
422 LOG(WARNING)
423 << "fused and original axes are not aligned, this may cause redundant computations";
424 }
425 *inner = IntSet::FromRange(dom_map.at(s->inner));
426 }
427 return;
428 }
429}
430
431void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
432 const IntSet& rebased, IntSet* parent) {
433 ICHECK(dom_map.count(s->parent));
434 if (rebased.MatchRange(dom_map.at(s->rebased))) {
435 *parent = IntSet::FromRange(dom_map.at(s->parent));
436 return;
437 }
438 PrimExpr parent_min = dom_map.at(s->parent)->min;
439 *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
440}
441
442Array<IntSet> PassUpDomain(const TransformNode* s,
443 const std::unordered_map<IterVar, Range>& dom_map,
444 const Map<IterVar, IntSet>& transformed_domains) {
445 Array<IntSet> output;
446
447 Array<PrimExpr> transformed_indices;
448 for (const auto& iter_var : s->transformed_variables) {
449 transformed_indices.push_back(iter_var->var);
450 }
451
452 Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices);
453
454 ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
455 for (size_t i = 0; i < transformed_exprs.size(); i++) {
456 output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains));
457 }
458
459 return output;
460}
461
462void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
463 std::unordered_map<IterVar, IntSet>* p_state) {
464 auto& state = *p_state;
465 for (size_t i = stage->relations.size(); i != 0; --i) {
466 IterVarRelation rel = stage->relations[i - 1];
467 if (const SplitNode* r = rel.as<SplitNode>()) {
468 IntSet parent;
469 PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent);
470 state[r->parent] = parent;
471 } else if (const FuseNode* r = rel.as<FuseNode>()) {
472 IntSet outer, inner;
473 PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner);
474 state[r->outer] = outer;
475 state[r->inner] = inner;
476 } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
477 IntSet parent;
478 PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
479 state[r->parent] = parent;
480 } else if (rel.as<SingletonNode>()) {
481 } else if (const TransformNode* r = rel.as<TransformNode>()) {
482 Map<IterVar, IntSet> transformed_domains;
483 for (const auto& var : r->transformed_variables) {
484 transformed_domains.Set(var, state.at(var));
485 }
486 auto original_ranges = PassUpDomain(r, dom_map, transformed_domains);
487 ICHECK_EQ(original_ranges.size(), r->original_variables.size());
488 for (size_t i = 0; i < original_ranges.size(); i++) {
489 state[r->original_variables[i]] = original_ranges[i];
490 }
491 } else {
492 LOG(FATAL) << "unknown relation type";
493 }
494 }
495}
496
497// Pass up bit mask with or relation.
498void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
499 bool allow_missing) {
500 auto& state = *p_state;
501 for (size_t i = stage->relations.size(); i != 0; --i) {
502 IterVarRelation rel = stage->relations[i - 1];
503 if (const SplitNode* s = rel.as<SplitNode>()) {
504 if (!state.count(s->inner) && !state.count(s->outer)) {
505 ICHECK(allow_missing);
506 continue;
507 }
508 int res = 0;
509 if (state.count(s->parent)) res |= state[s->parent];
510 if (state.count(s->inner)) res |= state[s->inner];
511 if (state.count(s->outer)) res |= state[s->outer];
512 state[s->parent] = res;
513 } else if (const FuseNode* s = rel.as<FuseNode>()) {
514 if (!state.count(s->fused)) {
515 ICHECK(allow_missing);
516 continue;
517 }
518 if (!state.count(s->outer)) {
519 state[s->outer] = state[s->fused];
520 } else {
521 state[s->outer] |= state[s->fused];
522 }
523 if (!state.count(s->inner)) {
524 state[s->inner] = state[s->fused];
525 } else {
526 state[s->inner] |= state[s->fused];
527 }
528 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
529 if (!state.count(s->rebased)) {
530 ICHECK(allow_missing);
531 continue;
532 }
533 if (!state.count(s->parent)) {
534 state[s->parent] = state[s->rebased];
535 } else {
536 state[s->parent] |= state[s->rebased];
537 }
538 } else if (const TransformNode* s = rel.as<TransformNode>()) {
539 for (const auto& original_var : s->original_variables) {
540 for (const auto& transformed_var : s->transformed_variables) {
541 if (!state.count(transformed_var)) {
542 ICHECK(allow_missing);
543 continue;
544 }
545 state[original_var] |= state[transformed_var];
546 }
547 }
548
549 } else if (rel.as<SingletonNode>()) {
550 } else {
551 LOG(FATAL) << "unknown relation type";
552 }
553 }
554}
555
556void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
557 bool allow_missing) {
558 auto& state = *p_state;
559 for (IterVarRelation rel : stage->relations) {
560 if (const SplitNode* s = rel.as<SplitNode>()) {
561 if (!state.count(s->parent)) {
562 ICHECK(allow_missing);
563 continue;
564 }
565 if (!state.count(s->outer)) {
566 state[s->outer] = state.at(s->parent);
567 } else {
568 state[s->outer] |= state.at(s->parent);
569 }
570 if (!state.count(s->inner)) {
571 state[s->inner] = state.at(s->parent);
572 } else {
573 state[s->inner] |= state.at(s->parent);
574 }
575 } else if (const FuseNode* s = rel.as<FuseNode>()) {
576 if (!state.count(s->outer) && !state.count(s->inner)) {
577 ICHECK(allow_missing);
578 continue;
579 }
580 int res = 0;
581 if (state.count(s->outer)) res |= state.at(s->outer);
582 if (state.count(s->inner)) res |= state.at(s->inner);
583 if (state.count(s->fused)) res |= state.at(s->fused);
584 state[s->fused] = res;
585 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
586 if (!state.count(s->parent)) {
587 ICHECK(allow_missing);
588 continue;
589 }
590 if (!state.count(s->rebased)) {
591 state[s->rebased] = state.at(s->parent);
592 } else {
593 state[s->rebased] |= state.at(s->parent);
594 }
595 } else if (const TransformNode* s = rel.as<TransformNode>()) {
596 for (const auto& original_var : s->original_variables) {
597 for (const auto& transformed_var : s->transformed_variables) {
598 if (!state.count(original_var)) {
599 ICHECK(allow_missing);
600 continue;
601 }
602 state[transformed_var] |= state[original_var];
603 }
604 }
605 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
606 state[s->iter] = 0;
607 } else {
608 LOG(FATAL) << "unknown relation type";
609 }
610 }
611}
612
613/*!
614 * \brief message passing to find if boundary checking on IterVar is needed.
615 * \param s The stage to be used.
616 * \param p_state The message passing state
617 * IterVar->flag
618 */
619void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
620 std::unordered_map<IterVar, bool>* p_state, arith::Analyzer* analyzer) {
621 auto& state = *p_state;
622 for (size_t i = s->relations.size(); i != 0; --i) {
623 IterVarRelation rel = s->relations[i - 1];
624 if (const SplitNode* s = rel.as<SplitNode>()) {
625 bool outer = state.at(s->outer);
626 bool inner = state.at(s->inner);
627
628 if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
629 PrimExpr factor = dom_map.at(s->inner)->extent;
630 PrimExpr step = dom_map.at(s->outer)->extent;
631 if (outer || inner) {
632 state[s->parent] = true;
633 } else {
634 if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
635 state[s->parent] = false;
636 } else {
637 state[s->parent] = true;
638 }
639 }
640 } else {
641 state[s->parent] = true;
642 }
643 } else if (const FuseNode* s = rel.as<FuseNode>()) {
644 bool fused = state.at(s->fused);
645 state[s->outer] = fused;
646 state[s->inner] = fused;
647 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
648 state[s->parent] = state.at(s->rebased);
649 } else if (rel.as<SingletonNode>()) {
650 // nop
651 } else if (const TransformNode* s = rel.as<TransformNode>()) {
652 // Currently, this marks all original iter vars as requiring
653 // bounds checks if any of the transformed variables require
654 // bounds checks, even if the inverse expression for that iter
655 // var doesn't depend on the bound variable.
656
657 // TODO(Lunderberg): For each of original variable, check
658 // whether any variable in the inverse expression for it
659 // requires bounds checking.
660 bool needs_bounds_check = false;
661 for (const auto& iter_var : s->transformed_variables) {
662 needs_bounds_check = needs_bounds_check || state[iter_var];
663 }
664 for (const auto& iter_var : s->original_variables) {
665 state[iter_var] = needs_bounds_check;
666 }
667 } else {
668 LOG(FATAL) << "unknown relation type";
669 }
670 }
671}
672
673bool IsRangeSame(const Range input_1, const Range input_2) {
674 arith::Analyzer analyzer;
675 if (input_1.same_as(input_2)) return true;
676
677 return (analyzer.CanProve(input_1->min == input_2->min) &&
678 analyzer.CanProve(input_1->extent == input_2->extent));
679}
680
681std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
682 const std::unordered_map<IterVar, PrimExpr>& value_map,
683 bool skip_ivar_domain,
684 const std::unordered_set<IterVar>& skip_iter) {
685 arith::Analyzer analyzer;
686
687 std::unordered_map<IterVar, bool> bound_state;
688 for (IterVar iv : stage->leaf_iter_vars) {
689 bound_state[iv] = false;
690 }
691 PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
692
693 std::vector<PrimExpr> preds;
694 Map<Var, IntSet> iset_dmap;
695
696 // setup domain map for set analysis
697 for (const auto& kv : dom_map) {
698 iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
699 }
700
701 for (auto entry : dom_map) {
702 analyzer.Bind(entry.first->var, entry.second);
703 }
704
705 for (const IterVar& iv : stage->all_iter_vars) {
706 if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
707 if (bound_state.at(iv)) {
708 Range dom = dom_map.at(iv);
709 PrimExpr value = value_map.at(iv) - dom->min;
710 PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
711 if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
712 preds.emplace_back(value < dom->extent);
713 }
714 }
715 }
716 for (const IterVar& iv : stage->op->root_iter_vars()) {
717 if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
718 Range dom = dom_map.at(iv);
719 ICHECK(iv->dom.defined());
720 if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
721 PrimExpr value = value_map.at(iv) - iv->dom->min;
722 IntSet s = analyzer.int_set(value, iset_dmap);
723 PrimExpr vmin = s.min();
724 PrimExpr vmax = s.max();
725 // The range of `value` resides in [vmin, vmax]
726 if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
727 preds.emplace_back(value >= 0);
728 }
729 if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
730 preds.emplace_back(value < iv->dom->extent);
731 }
732 }
733 }
734 return preds;
735}
736} // namespace te
737} // namespace tvm
738