1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file message_passing.cc |
22 | * \brief The message passing domain. |
23 | */ |
24 | #include "message_passing.h" |
25 | |
26 | #include <tvm/arith/analyzer.h> |
27 | #include <tvm/tir/expr.h> |
28 | |
29 | namespace tvm { |
30 | namespace te { |
31 | |
32 | using namespace tir; |
33 | |
34 | void Update(std::unordered_map<IterVar, Range>* p_state, const IterVar& iv, Range r, |
35 | arith::Analyzer* analyzer) { |
36 | auto it = p_state->find(iv); |
37 | if (it == p_state->end()) { |
38 | (*p_state)[iv] = r; |
39 | analyzer->Bind(iv->var, r); |
40 | } else { |
41 | bool match = |
42 | is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0); |
43 | ICHECK(match) << iv << " domain already inferred," |
44 | << " cannot prove their extents are the same " << it->second->extent << " vs " |
45 | << r->extent; |
46 | } |
47 | } |
48 | |
49 | /*! |
50 | * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to |
51 | * a thread. |
52 | * |
53 | * \param stage The stage to operate on. |
54 | * \param p_state The propagation result of each IterVar. |
55 | */ |
56 | void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) { |
57 | auto bound_to_thread = [&stage](const IterVar& iv) { |
58 | bool bound = false; |
59 | auto it = stage->iter_var_attrs.find(iv); |
60 | if (it != stage->iter_var_attrs.end()) { |
61 | bound = (*it).second->bind_thread.defined(); |
62 | } |
63 | return bound; |
64 | }; |
65 | |
66 | auto& state = *p_state; |
67 | // Fill p_state with leaf itervars |
68 | for (const IterVar& iv : stage->leaf_iter_vars) { |
69 | state[iv] = bound_to_thread(iv); |
70 | } |
71 | // Traverse the graph bottom-up to propagate thread binding information |
72 | for (size_t i = stage->relations.size(); i != 0; --i) { |
73 | IterVarRelation rel = stage->relations[i - 1]; |
74 | if (const SplitNode* s = rel.as<SplitNode>()) { |
75 | state[s->parent] = state[s->inner] || state[s->outer]; |
76 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
77 | state[s->inner] = state[s->fused]; |
78 | state[s->outer] = state[s->fused]; |
79 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
80 | state[s->parent] = state[s->rebased]; |
81 | } else if (rel.as<SingletonNode>()) { |
82 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
83 | // Currently, this marks all original iter vars as deriving from |
84 | // a thread bind if any of the transformed variables are bound, |
85 | // even if the inverse expression for that iter var doesn't |
86 | // depend on the bound variable. |
87 | |
88 | // TODO(Lunderberg): For each of original variable, check |
89 | // whether any variable in the inverse expression for it has a |
90 | // thread binding. |
91 | bool is_thread_binding = false; |
92 | for (const auto& iter_var : s->transformed_variables) { |
93 | is_thread_binding = is_thread_binding || state[iter_var]; |
94 | } |
95 | for (const auto& iter_var : s->original_variables) { |
96 | state[iter_var] = is_thread_binding; |
97 | } |
98 | } else { |
99 | LOG(FATAL) << "unknown relation type" ; |
100 | } |
101 | } |
102 | } |
103 | |
/*!
 * \brief Propagate IterVar ranges downward, from parent/original IterVars to the IterVars
 *  derived from them, by visiting the stage's relations in forward order.
 *
 * After all relations are processed, each IterVar in stage->iter_var_attrs with a defined
 * bind_thread also propagates its range to that thread IterVar.
 *
 * \param stage The stage whose relations are traversed.
 * \param p_state Map from IterVar to inferred range. Parents are read from it. Derived
 *  IterVars are added to it, mostly through Update, which also checks consistency with any
 *  range already present.
 * \param actx Analyzer used to simplify and prove expressions; Update also binds newly
 *  inferred ranges in it.
 * \param allow_missing If false, a relation whose input IterVars have no range is an ICHECK
 *  failure; if true, such relations are skipped.
 */
void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
                    arith::Analyzer* actx, bool allow_missing) {
  // ceil(a / b). Emits a plain division when b provably divides a.
  auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
    if (actx->CanProve(indexmod(a, b) == 0)) {
      return actx->Simplify(indexdiv(a, b));
    }
    return actx->Simplify(indexdiv(a + (b - 1), b));
  };

  // Returns a when a < b is provable, otherwise b (b is kept whenever the order is unknown).
  auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
    if (actx->CanProve(a < b)) {
      return actx->Simplify(a);
    }
    return actx->Simplify(b);
  };

  // dominating_thread[iv]: whether iv derives at least one leaf IterVar bound to a thread.
  std::unordered_map<IterVar, bool> dominating_thread;
  PassUpThreadBinding(stage, &dominating_thread);

  auto& state = *p_state;
  // forward iteration on relations
  for (IterVarRelation rel : stage->relations) {
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (!state.count(r->parent)) {
        ICHECK(allow_missing);
        continue;
      }
      ICHECK(!state.count(r->inner));
      // This reference into the unordered_map stays valid across the inserts done by Update
      // below: insertion may invalidate iterators, but not references to elements.
      const Range& range_parent = state.at(r->parent);
      // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
      // following conditions are met:
      // 1. No leaf IterVar derived from iv binds to any thread. People may use split
      // to force an IterVar extent to match the number of allocated threads to fuse stages
      // that require different number of threads. We don't want to change these extents.
      // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
      // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an
      // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
      // 3. range_parent's extent is not 0. At least one Topi test has a case where a tensor has one
      // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent
      // IterVar. We don't touch it.
      // Note: `a || b || c ? x : y` parses as `(a || b || c) ? x : y`.
      auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
        return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
                   ? factor_or_nparts
                   : minimum_or_later(range_parent->extent, factor_or_nparts);
      };
      if (r->factor.defined()) {
        // Split by factor: inner extent = (possibly tightened) factor, cast to the parent's
        // dtype; outer extent = ceil(parent_extent / factor).
        Update(p_state, r->inner,
               Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
                                            resolve_min_extent_for_split(r->inner, r->factor))),
               actx);
        Update(p_state, r->outer,
               Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
      } else {
        // Split by nparts: outer extent = (possibly tightened) nparts, cast to the parent's
        // dtype; inner extent = ceil(parent_extent / nparts).
        Update(p_state, r->outer,
               Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
                                            resolve_min_extent_for_split(r->outer, r->nparts))),
               actx);
        Update(p_state, r->inner,
               Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
      }
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (!state.count(r->outer) || !state.count(r->inner)) {
        ICHECK(allow_missing);
        continue;
      }
      const Range& range_outer = state.at(r->outer);
      const Range& range_inner = state.at(r->inner);
      // The fused extent is the product of the two extents. It is assigned directly rather
      // than through Update, so any existing entry is overwritten and nothing is bound in actx.
      state[r->fused] = Range::FromMinExtent(0, range_outer->extent * range_inner->extent);
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (!state.count(r->parent)) {
        ICHECK(allow_missing);
        continue;
      }
      // The rebased IterVar starts at 0 and keeps the parent's extent.
      Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
      Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
    } else if (const TransformNode* s = rel.as<TransformNode>()) {
      // Every original IterVar must have a range; otherwise skip (or fail if !allow_missing).
      bool missing_originals = false;
      for (const auto& iter_var : s->original_variables) {
        if (!state.count(iter_var)) {
          ICHECK(allow_missing);
          missing_originals = true;
        }
      }
      if (missing_originals) {
        continue;
      }

      // Map the original ranges through the forward transformation, giving one range per
      // transformed IterVar.
      Array<Range> original_ranges;
      for (const auto& iter_var : s->original_variables) {
        original_ranges.push_back(state[iter_var]);
      }
      Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);

      ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
      for (size_t i = 0; i < updated_ranges.size(); i++) {
        Update(p_state, s->transformed_variables[i], updated_ranges[i], actx);
      }

    } else {
      LOG(FATAL) << "unknown relation type" ;
    }
  }
  // update the extents of bound threads: each thread IterVar receives the range of the
  // IterVar bound to it.
  for (auto kv : stage->iter_var_attrs) {
    if (kv.second->bind_thread.defined()) {
      ICHECK(state.count(kv.first));
      Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
    }
  }
}
215 | |
216 | void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map, |
217 | std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) { |
218 | auto& state = *p_state; |
219 | for (size_t i = stage->relations.size(); i != 0; --i) { |
220 | IterVarRelation rel = stage->relations[i - 1]; |
221 | if (const SplitNode* s = rel.as<SplitNode>()) { |
222 | if (!state.count(s->outer) || !state.count(s->inner)) { |
223 | ICHECK(allow_missing); |
224 | continue; |
225 | } |
226 | PrimExpr outer = state.at(s->outer); |
227 | PrimExpr inner = state.at(s->inner); |
228 | PrimExpr factor = dom_map.at(s->inner)->extent; |
229 | PrimExpr parent_min = dom_map.at(s->parent)->min; |
230 | state[s->parent] = inner + outer * factor; |
231 | // add min if they exist |
232 | if (!is_zero(parent_min)) { |
233 | state[s->parent] = state[s->parent] + parent_min; |
234 | } |
235 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
236 | if (!state.count(s->fused)) { |
237 | ICHECK(allow_missing); |
238 | continue; |
239 | } |
240 | PrimExpr value = state.at(s->fused); |
241 | PrimExpr factor = dom_map.at(s->inner)->extent; |
242 | PrimExpr outer_min = dom_map.at(s->outer)->min; |
243 | PrimExpr inner_min = dom_map.at(s->inner)->min; |
244 | state[s->outer] = indexdiv(value, factor); |
245 | state[s->inner] = indexmod(value, factor); |
246 | // add min if they exist |
247 | if (!is_zero(outer_min)) { |
248 | state[s->outer] = state[s->outer] + outer_min; |
249 | } |
250 | if (!is_zero(inner_min)) { |
251 | state[s->inner] = state[s->inner] + inner_min; |
252 | } |
253 | // s->fused, s->outer and s->inner may be of different dtype, |
254 | // so we cast the `state` back to its original dtype |
255 | state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]); |
256 | state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]); |
257 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
258 | if (!state.count(s->rebased)) { |
259 | ICHECK(allow_missing); |
260 | continue; |
261 | } |
262 | PrimExpr value = state.at(s->rebased); |
263 | PrimExpr parent_min = dom_map.at(s->parent)->min; |
264 | // add min if they exist |
265 | if (!is_zero(parent_min)) { |
266 | state[s->parent] = value + parent_min; |
267 | } else { |
268 | state[s->parent] = value; |
269 | } |
270 | } else if (rel.as<SingletonNode>()) { |
271 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
272 | bool missing_transformed = false; |
273 | for (const auto& iter_var : s->transformed_variables) { |
274 | if (!state.count(iter_var)) { |
275 | ICHECK(allow_missing); |
276 | missing_transformed = true; |
277 | } |
278 | } |
279 | if (missing_transformed) { |
280 | continue; |
281 | } |
282 | |
283 | Array<PrimExpr> transformed_indices; |
284 | for (const auto& iter_var : s->transformed_variables) { |
285 | transformed_indices.push_back(state[iter_var]); |
286 | } |
287 | Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices); |
288 | |
289 | ICHECK_EQ(original_indices.size(), s->original_variables.size()); |
290 | for (size_t i = 0; i < original_indices.size(); i++) { |
291 | state[s->original_variables[i]] = original_indices[i]; |
292 | } |
293 | |
294 | } else { |
295 | LOG(FATAL) << "unknown relation type" ; |
296 | } |
297 | } |
298 | } |
299 | |
300 | void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map, |
301 | std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) { |
302 | auto& state = *p_state; |
303 | for (IterVarRelation rel : stage->relations) { |
304 | if (const SplitNode* s = rel.as<SplitNode>()) { |
305 | if (!state.count(s->parent)) { |
306 | ICHECK(allow_missing); |
307 | continue; |
308 | } |
309 | Range r = dom_map.at(s->inner); |
310 | ICHECK(is_zero(r->min)); |
311 | PrimExpr parent = state.at(s->parent); |
312 | PrimExpr factor = r->extent; |
313 | state[s->outer] = indexdiv(parent, factor); |
314 | state[s->inner] = indexmod(parent, factor); |
315 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
316 | if (!state.count(s->inner) && !state.count(s->outer)) { |
317 | ICHECK(allow_missing); |
318 | continue; |
319 | } |
320 | PrimExpr factor = dom_map.at(s->inner)->extent; |
321 | PrimExpr outer_min = dom_map.at(s->outer)->min; |
322 | PrimExpr inner_min = dom_map.at(s->inner)->min; |
323 | PrimExpr inner = state.at(s->inner); |
324 | PrimExpr outer = state.at(s->outer); |
325 | ICHECK(is_zero(outer_min)); |
326 | ICHECK(is_zero(inner_min)); |
327 | state[s->fused] = outer * factor + inner; |
328 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
329 | if (!state.count(s->rebased)) { |
330 | ICHECK(allow_missing); |
331 | continue; |
332 | } |
333 | PrimExpr value = state.at(s->parent); |
334 | PrimExpr parent_min = dom_map.at(s->parent)->min; |
335 | ICHECK(is_zero(parent_min)); |
336 | state[s->rebased] = value; |
337 | } else if (const SingletonNode* s = rel.as<SingletonNode>()) { |
338 | state[s->iter] = make_zero(s->iter->var.dtype()); |
339 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
340 | bool missing_originals = false; |
341 | for (const auto& iter_var : s->original_variables) { |
342 | if (!state.count(iter_var)) { |
343 | ICHECK(allow_missing); |
344 | missing_originals = true; |
345 | } |
346 | } |
347 | if (missing_originals) { |
348 | continue; |
349 | } |
350 | |
351 | Array<PrimExpr> original_indices; |
352 | for (const auto& iter_var : s->original_variables) { |
353 | original_indices.push_back(state[iter_var]); |
354 | } |
355 | Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices); |
356 | |
357 | ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size()); |
358 | for (size_t i = 0; i < transformed_indices.size(); i++) { |
359 | state[s->transformed_variables[i]] = transformed_indices[i]; |
360 | } |
361 | } else { |
362 | LOG(FATAL) << "unknown relation type" ; |
363 | } |
364 | } |
365 | } |
366 | |
367 | // Domain message passing. |
368 | void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map, |
369 | const IntSet& outer, const IntSet& inner, IntSet* parent) { |
370 | if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) && |
371 | outer.MatchRange(dom_map.at(s->outer)) && inner.MatchRange(dom_map.at(s->inner))) { |
372 | *parent = IntSet::FromRange(dom_map.at(s->parent)); |
373 | return; |
374 | } |
375 | PrimExpr factor = dom_map.at(s->inner)->extent; |
376 | PrimExpr parent_min = dom_map.at(s->parent)->min; |
377 | ICHECK(outer.defined()); |
378 | ICHECK(inner.defined()); |
379 | ICHECK(factor.defined()); |
380 | *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min, |
381 | {{s->outer, outer}, {s->inner, inner}}); |
382 | } |
383 | |
/*!
 * \brief Domain message passing for a fuse: derive the outer and inner IntSets from the
 *  fused IterVar's IntSet.
 *
 * \param s The fuse relation; s->outer, s->inner and s->fused must all be in \p dom_map.
 * \param dom_map Ranges of the IterVars.
 * \param fused The IntSet of the fused IterVar.
 * \param outer Output: the IntSet of the outer IterVar.
 * \param inner Output: the IntSet of the inner IterVar.
 */
void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
                  const IntSet& fused, IntSet* outer, IntSet* inner) {
  ICHECK(dom_map.count(s->outer));
  ICHECK(dom_map.count(s->inner));
  ICHECK(dom_map.count(s->fused));
  arith::Analyzer ana;

  // fused covers its whole range: outer and inner cover their whole ranges too.
  if (fused.MatchRange(dom_map.at(s->fused))) {
    *outer = IntSet::FromRange(dom_map.at(s->outer));
    *inner = IntSet::FromRange(dom_map.at(s->inner));
    return;
  }
  PrimExpr outer_min = dom_map.at(s->outer)->min;
  PrimExpr inner_min = dom_map.at(s->inner)->min;

  if (fused.IsSinglePoint()) {
    // A single fused value decomposes exactly into (value / inner_extent, value % inner_extent),
    // each shifted by its IterVar's min when that min is non-zero.
    PrimExpr value = fused.PointValue();
    PrimExpr factor = dom_map.at(s->inner)->extent;
    PrimExpr v_outer = indexdiv(value, factor);
    PrimExpr v_inner = indexmod(value, factor);
    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
    *outer = IntSet::SinglePoint(v_outer);
    *inner = IntSet::SinglePoint(v_inner);
  } else {
    PrimExpr fused_extent = (fused.max() - fused.min() + 1);
    PrimExpr inner_extent = dom_map.at(s->inner)->extent;
    // outer spans the rows containing fused.min() through fused.max().
    *outer = IntSet::Interval(outer_min + indexdiv(fused.min(), inner_extent),
                              outer_min + indexdiv(fused.max(), inner_extent));
    // Tight inner bound only when both remainders simplify to zero: fused_extent divides
    // inner_extent, and fused.min() is a multiple of fused_extent.
    if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
        is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
      // fused never spans multiple rows, make a tight bounding box
      // there may be other cases when bounding box could be tightened
      *inner = IntSet::Interval(inner_min + indexmod(fused.min(), inner_extent),
                                inner_min + indexmod(fused.max(), inner_extent));
    } else {  // fused may span multiple rows, use full row widths
      // Warn unless fused_extent is a multiple of inner_extent and fused.min() starts a row
      // (both remainders simplify to zero).
      if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
          !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
        LOG(WARNING)
            << "fused and original axes are not aligned, this may cause redundant computations" ;
      }
      *inner = IntSet::FromRange(dom_map.at(s->inner));
    }
    return;
  }
}
430 | |
431 | void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map, |
432 | const IntSet& rebased, IntSet* parent) { |
433 | ICHECK(dom_map.count(s->parent)); |
434 | if (rebased.MatchRange(dom_map.at(s->rebased))) { |
435 | *parent = IntSet::FromRange(dom_map.at(s->parent)); |
436 | return; |
437 | } |
438 | PrimExpr parent_min = dom_map.at(s->parent)->min; |
439 | *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); |
440 | } |
441 | |
442 | Array<IntSet> PassUpDomain(const TransformNode* s, |
443 | const std::unordered_map<IterVar, Range>& dom_map, |
444 | const Map<IterVar, IntSet>& transformed_domains) { |
445 | Array<IntSet> output; |
446 | |
447 | Array<PrimExpr> transformed_indices; |
448 | for (const auto& iter_var : s->transformed_variables) { |
449 | transformed_indices.push_back(iter_var->var); |
450 | } |
451 | |
452 | Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices); |
453 | |
454 | ICHECK_EQ(transformed_exprs.size(), s->original_variables.size()); |
455 | for (size_t i = 0; i < transformed_exprs.size(); i++) { |
456 | output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains)); |
457 | } |
458 | |
459 | return output; |
460 | } |
461 | |
462 | void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
463 | std::unordered_map<IterVar, IntSet>* p_state) { |
464 | auto& state = *p_state; |
465 | for (size_t i = stage->relations.size(); i != 0; --i) { |
466 | IterVarRelation rel = stage->relations[i - 1]; |
467 | if (const SplitNode* r = rel.as<SplitNode>()) { |
468 | IntSet parent; |
469 | PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent); |
470 | state[r->parent] = parent; |
471 | } else if (const FuseNode* r = rel.as<FuseNode>()) { |
472 | IntSet outer, inner; |
473 | PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner); |
474 | state[r->outer] = outer; |
475 | state[r->inner] = inner; |
476 | } else if (const RebaseNode* r = rel.as<RebaseNode>()) { |
477 | IntSet parent; |
478 | PassUpDomain(r, dom_map, state.at(r->rebased), &parent); |
479 | state[r->parent] = parent; |
480 | } else if (rel.as<SingletonNode>()) { |
481 | } else if (const TransformNode* r = rel.as<TransformNode>()) { |
482 | Map<IterVar, IntSet> transformed_domains; |
483 | for (const auto& var : r->transformed_variables) { |
484 | transformed_domains.Set(var, state.at(var)); |
485 | } |
486 | auto original_ranges = PassUpDomain(r, dom_map, transformed_domains); |
487 | ICHECK_EQ(original_ranges.size(), r->original_variables.size()); |
488 | for (size_t i = 0; i < original_ranges.size(); i++) { |
489 | state[r->original_variables[i]] = original_ranges[i]; |
490 | } |
491 | } else { |
492 | LOG(FATAL) << "unknown relation type" ; |
493 | } |
494 | } |
495 | } |
496 | |
497 | // Pass up bit mask with or relation. |
498 | void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, |
499 | bool allow_missing) { |
500 | auto& state = *p_state; |
501 | for (size_t i = stage->relations.size(); i != 0; --i) { |
502 | IterVarRelation rel = stage->relations[i - 1]; |
503 | if (const SplitNode* s = rel.as<SplitNode>()) { |
504 | if (!state.count(s->inner) && !state.count(s->outer)) { |
505 | ICHECK(allow_missing); |
506 | continue; |
507 | } |
508 | int res = 0; |
509 | if (state.count(s->parent)) res |= state[s->parent]; |
510 | if (state.count(s->inner)) res |= state[s->inner]; |
511 | if (state.count(s->outer)) res |= state[s->outer]; |
512 | state[s->parent] = res; |
513 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
514 | if (!state.count(s->fused)) { |
515 | ICHECK(allow_missing); |
516 | continue; |
517 | } |
518 | if (!state.count(s->outer)) { |
519 | state[s->outer] = state[s->fused]; |
520 | } else { |
521 | state[s->outer] |= state[s->fused]; |
522 | } |
523 | if (!state.count(s->inner)) { |
524 | state[s->inner] = state[s->fused]; |
525 | } else { |
526 | state[s->inner] |= state[s->fused]; |
527 | } |
528 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
529 | if (!state.count(s->rebased)) { |
530 | ICHECK(allow_missing); |
531 | continue; |
532 | } |
533 | if (!state.count(s->parent)) { |
534 | state[s->parent] = state[s->rebased]; |
535 | } else { |
536 | state[s->parent] |= state[s->rebased]; |
537 | } |
538 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
539 | for (const auto& original_var : s->original_variables) { |
540 | for (const auto& transformed_var : s->transformed_variables) { |
541 | if (!state.count(transformed_var)) { |
542 | ICHECK(allow_missing); |
543 | continue; |
544 | } |
545 | state[original_var] |= state[transformed_var]; |
546 | } |
547 | } |
548 | |
549 | } else if (rel.as<SingletonNode>()) { |
550 | } else { |
551 | LOG(FATAL) << "unknown relation type" ; |
552 | } |
553 | } |
554 | } |
555 | |
556 | void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, |
557 | bool allow_missing) { |
558 | auto& state = *p_state; |
559 | for (IterVarRelation rel : stage->relations) { |
560 | if (const SplitNode* s = rel.as<SplitNode>()) { |
561 | if (!state.count(s->parent)) { |
562 | ICHECK(allow_missing); |
563 | continue; |
564 | } |
565 | if (!state.count(s->outer)) { |
566 | state[s->outer] = state.at(s->parent); |
567 | } else { |
568 | state[s->outer] |= state.at(s->parent); |
569 | } |
570 | if (!state.count(s->inner)) { |
571 | state[s->inner] = state.at(s->parent); |
572 | } else { |
573 | state[s->inner] |= state.at(s->parent); |
574 | } |
575 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
576 | if (!state.count(s->outer) && !state.count(s->inner)) { |
577 | ICHECK(allow_missing); |
578 | continue; |
579 | } |
580 | int res = 0; |
581 | if (state.count(s->outer)) res |= state.at(s->outer); |
582 | if (state.count(s->inner)) res |= state.at(s->inner); |
583 | if (state.count(s->fused)) res |= state.at(s->fused); |
584 | state[s->fused] = res; |
585 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
586 | if (!state.count(s->parent)) { |
587 | ICHECK(allow_missing); |
588 | continue; |
589 | } |
590 | if (!state.count(s->rebased)) { |
591 | state[s->rebased] = state.at(s->parent); |
592 | } else { |
593 | state[s->rebased] |= state.at(s->parent); |
594 | } |
595 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
596 | for (const auto& original_var : s->original_variables) { |
597 | for (const auto& transformed_var : s->transformed_variables) { |
598 | if (!state.count(original_var)) { |
599 | ICHECK(allow_missing); |
600 | continue; |
601 | } |
602 | state[transformed_var] |= state[original_var]; |
603 | } |
604 | } |
605 | } else if (const SingletonNode* s = rel.as<SingletonNode>()) { |
606 | state[s->iter] = 0; |
607 | } else { |
608 | LOG(FATAL) << "unknown relation type" ; |
609 | } |
610 | } |
611 | } |
612 | |
613 | /*! |
614 | * \brief message passing to find if boundary checking on IterVar is needed. |
615 | * \param s The stage to be used. |
616 | * \param p_state The message passing state |
617 | * IterVar->flag |
618 | */ |
619 | void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map, |
620 | std::unordered_map<IterVar, bool>* p_state, arith::Analyzer* analyzer) { |
621 | auto& state = *p_state; |
622 | for (size_t i = s->relations.size(); i != 0; --i) { |
623 | IterVarRelation rel = s->relations[i - 1]; |
624 | if (const SplitNode* s = rel.as<SplitNode>()) { |
625 | bool outer = state.at(s->outer); |
626 | bool inner = state.at(s->inner); |
627 | |
628 | if (dom_map.count(s->inner) && dom_map.count(s->outer)) { |
629 | PrimExpr factor = dom_map.at(s->inner)->extent; |
630 | PrimExpr step = dom_map.at(s->outer)->extent; |
631 | if (outer || inner) { |
632 | state[s->parent] = true; |
633 | } else { |
634 | if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { |
635 | state[s->parent] = false; |
636 | } else { |
637 | state[s->parent] = true; |
638 | } |
639 | } |
640 | } else { |
641 | state[s->parent] = true; |
642 | } |
643 | } else if (const FuseNode* s = rel.as<FuseNode>()) { |
644 | bool fused = state.at(s->fused); |
645 | state[s->outer] = fused; |
646 | state[s->inner] = fused; |
647 | } else if (const RebaseNode* s = rel.as<RebaseNode>()) { |
648 | state[s->parent] = state.at(s->rebased); |
649 | } else if (rel.as<SingletonNode>()) { |
650 | // nop |
651 | } else if (const TransformNode* s = rel.as<TransformNode>()) { |
652 | // Currently, this marks all original iter vars as requiring |
653 | // bounds checks if any of the transformed variables require |
654 | // bounds checks, even if the inverse expression for that iter |
655 | // var doesn't depend on the bound variable. |
656 | |
657 | // TODO(Lunderberg): For each of original variable, check |
658 | // whether any variable in the inverse expression for it |
659 | // requires bounds checking. |
660 | bool needs_bounds_check = false; |
661 | for (const auto& iter_var : s->transformed_variables) { |
662 | needs_bounds_check = needs_bounds_check || state[iter_var]; |
663 | } |
664 | for (const auto& iter_var : s->original_variables) { |
665 | state[iter_var] = needs_bounds_check; |
666 | } |
667 | } else { |
668 | LOG(FATAL) << "unknown relation type" ; |
669 | } |
670 | } |
671 | } |
672 | |
673 | bool IsRangeSame(const Range input_1, const Range input_2) { |
674 | arith::Analyzer analyzer; |
675 | if (input_1.same_as(input_2)) return true; |
676 | |
677 | return (analyzer.CanProve(input_1->min == input_2->min) && |
678 | analyzer.CanProve(input_1->extent == input_2->extent)); |
679 | } |
680 | |
681 | std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map, |
682 | const std::unordered_map<IterVar, PrimExpr>& value_map, |
683 | bool skip_ivar_domain, |
684 | const std::unordered_set<IterVar>& skip_iter) { |
685 | arith::Analyzer analyzer; |
686 | |
687 | std::unordered_map<IterVar, bool> bound_state; |
688 | for (IterVar iv : stage->leaf_iter_vars) { |
689 | bound_state[iv] = false; |
690 | } |
691 | PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); |
692 | |
693 | std::vector<PrimExpr> preds; |
694 | Map<Var, IntSet> iset_dmap; |
695 | |
696 | // setup domain map for set analysis |
697 | for (const auto& kv : dom_map) { |
698 | iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second)); |
699 | } |
700 | |
701 | for (auto entry : dom_map) { |
702 | analyzer.Bind(entry.first->var, entry.second); |
703 | } |
704 | |
705 | for (const IterVar& iv : stage->all_iter_vars) { |
706 | if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; |
707 | if (bound_state.at(iv)) { |
708 | Range dom = dom_map.at(iv); |
709 | PrimExpr value = value_map.at(iv) - dom->min; |
710 | PrimExpr vmax = analyzer.int_set(value, iset_dmap).max(); |
711 | if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { |
712 | preds.emplace_back(value < dom->extent); |
713 | } |
714 | } |
715 | } |
716 | for (const IterVar& iv : stage->op->root_iter_vars()) { |
717 | if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; |
718 | Range dom = dom_map.at(iv); |
719 | ICHECK(iv->dom.defined()); |
720 | if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { |
721 | PrimExpr value = value_map.at(iv) - iv->dom->min; |
722 | IntSet s = analyzer.int_set(value, iset_dmap); |
723 | PrimExpr vmin = s.min(); |
724 | PrimExpr vmax = s.max(); |
725 | // The range of `value` resides in [vmin, vmax] |
726 | if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) { |
727 | preds.emplace_back(value >= 0); |
728 | } |
729 | if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) { |
730 | preds.emplace_back(value < iv->dom->extent); |
731 | } |
732 | } |
733 | } |
734 | return preds; |
735 | } |
736 | } // namespace te |
737 | } // namespace tvm |
738 | |