1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file control_flow_graph.cc
 * \brief Utility for building a control-flow graph and tracking buffer touch points
 */
24 | |
25 | #include "control_flow_graph.h" |
26 | |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/tir/analysis.h> |
29 | #include <tvm/tir/builtin.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/op.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | |
34 | #include <algorithm> |
35 | #include <numeric> |
36 | #include <optional> |
37 | #include <queue> |
38 | #include <set> |
39 | #include <sstream> |
40 | #include <unordered_set> |
41 | |
42 | #include "../../arith/conjunctive_normal_form.h" |
43 | #include "../../arith/constraint_extract.h" |
44 | #include "../../arith/ir_mutator_with_analyzer.h" |
45 | #include "../../arith/ir_visitor_with_analyzer.h" |
46 | #include "../../arith/narrow_predicate_expression.h" |
47 | #include "../../arith/unwrap_vector_expr.h" |
48 | |
49 | namespace tvm { |
50 | namespace tir { |
51 | |
52 | using namespace arith; |
53 | |
54 | namespace { |
55 | bool HasBufferLoad(PrimExpr expr) { |
56 | struct Visitor : public ExprVisitor { |
57 | void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; } |
58 | bool found_buffer_load{false}; |
59 | }; |
60 | |
61 | Visitor visitor; |
62 | visitor(expr); |
63 | return visitor.found_buffer_load; |
64 | } |
65 | |
66 | Optional<PrimExpr> SubstituteParamValues(const Array<Var>& param_vars, |
67 | const Array<PrimExpr>& param_values, |
68 | const PrimExpr& expr) { |
69 | ICHECK_EQ(param_vars.size(), param_values.size()) |
70 | << "Expression was defined as having " << param_vars.size() << " parameters, but received " |
71 | << param_values.size() << " arguments." ; |
72 | |
73 | Map<tir::Var, PrimExpr> var_map; |
74 | for (size_t i = 0; i < param_values.size(); i++) { |
75 | var_map.Set(param_vars[i], param_values[i]); |
76 | } |
77 | |
78 | return Substitute(expr, var_map); |
79 | } |
80 | } // namespace |
81 | |
82 | PrimExpr BufferTouch::BeforeLoopIteration() const { |
83 | PrimExpr loop_predicate = Bool(true); |
84 | for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
85 | const Var& loop_var = it->first; |
86 | const PrimExpr& loop_expr = it->second; |
87 | loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate); |
88 | } |
89 | return loop_predicate; |
90 | } |
91 | |
92 | PrimExpr BufferTouch::AtLoopIteration() const { |
93 | PrimExpr loop_predicate = Bool(true); |
94 | for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
95 | const Var& loop_var = it->first; |
96 | const PrimExpr& loop_expr = it->second; |
97 | loop_predicate = (loop_var == loop_expr) && loop_predicate; |
98 | } |
99 | return loop_predicate; |
100 | } |
101 | |
102 | PrimExpr BufferTouch::AfterLoopIteration() const { |
103 | PrimExpr loop_predicate = Bool(true); |
104 | for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
105 | const Var& loop_var = it->first; |
106 | const PrimExpr& loop_expr = it->second; |
107 | loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate); |
108 | } |
109 | return loop_predicate; |
110 | } |
111 | |
112 | bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const { |
113 | if (this->buffer.same_as(other.buffer)) { |
114 | With<ConstraintContext> constraint(analyzer, predicate); |
115 | |
116 | return analyzer->CanProve(other.predicate); |
117 | } else { |
118 | return false; |
119 | } |
120 | } |
121 | |
122 | bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const { |
123 | if (this->buffer.same_as(other.buffer)) { |
124 | With<ConstraintContext> constraint(analyzer, predicate); |
125 | |
126 | return analyzer->CanProve(!other.predicate); |
127 | } else { |
128 | return true; |
129 | } |
130 | } |
131 | |
132 | std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) { |
133 | auto touch_type = [&]() { |
134 | if (tp.touch_type == BufferTouch::AccessType::Read) { |
135 | return "read" ; |
136 | } else if (tp.touch_type == BufferTouch::AccessType::Write) { |
137 | return "write" ; |
138 | } else if (tp.touch_type == BufferTouch::AccessType::Assume) { |
139 | return "assume" ; |
140 | } else { |
141 | return "???" ; |
142 | } |
143 | }(); |
144 | |
145 | os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate |
146 | << ", value = " << tp.value << ")" ; |
147 | return os; |
148 | } |
149 | |
150 | class BufferConstraintApply : public IRMutatorWithAnalyzer { |
151 | public: |
152 | using Parent = IRMutatorWithAnalyzer; |
153 | |
154 | BufferConstraintApply(const Map<Buffer, Array<Var>>& axis_var_lookup, |
155 | const std::vector<BufferTouch>& knowns, Analyzer* analyzer) |
156 | : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} |
157 | |
158 | using Parent::VisitExpr_; |
159 | |
160 | PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
161 | for (const auto& known : knowns_) { |
162 | if (!op->buffer.same_as(known.buffer)) { |
163 | continue; |
164 | } |
165 | |
166 | Optional<Var> lane_var = NullOpt; |
167 | IntImm num_lanes; |
168 | |
169 | Array<PrimExpr> indices = op->indices.Map([&](const auto& index) { |
170 | if (index.dtype().lanes() == 1) { |
171 | return index; |
172 | } else { |
173 | ICHECK(!lane_var) << "Multiple indices found with non-scalar values" ; |
174 | lane_var = Var("lane" , index.dtype().element_of()); |
175 | num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); |
176 | return UnwrapVectorExpr(index, lane_var.value()); |
177 | } |
178 | }); |
179 | |
180 | auto axis_vars = axis_var_lookup_.at(op->buffer); |
181 | PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value(); |
182 | |
183 | std::optional<With<ConstraintContext>> context; |
184 | if (lane_var.defined()) { |
185 | Var lanes = lane_var.value(); |
186 | PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes); |
187 | context.emplace(analyzer_, known); |
188 | } |
189 | |
190 | if (analyzer_->CanProve(predicate)) { |
191 | return SubstituteParamValues(axis_vars, op->indices, known.value).value(); |
192 | } |
193 | } |
194 | |
195 | return GetRef<PrimExpr>(op); |
196 | } |
197 | |
198 | private: |
199 | const Map<Buffer, Array<Var>>& axis_var_lookup_; |
200 | const std::vector<BufferTouch>& knowns_; |
201 | }; |
202 | |
203 | /*! \brief Extract the control-flow graph |
204 | * |
205 | * Walk through a statement, populating the control-flow graph. |
206 | */ |
207 | class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { |
208 | public: |
209 | static void Build(ControlFlowGraph* out, const Stmt& stmt) { |
210 | ControlFlowGraphBuilder (out); |
211 | extractor.AppendControlBlock(); |
212 | extractor(stmt); |
213 | } |
214 | |
215 | private: |
216 | ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {} |
217 | |
218 | using Parent = IRVisitorWithAnalyzer; |
219 | using Parent::VisitExpr_; |
220 | using Parent::VisitStmt_; |
221 | |
222 | void VisitStmt(const Stmt& stmt) override { |
223 | // Update the lookup table to determine which control-flow block |
224 | // contains the start of the specified statement. This is used |
225 | // later to determine which set of known values should be used to |
226 | // simplify a statement. |
227 | out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock(); |
228 | Stmt prev_stmt = current_stmt_; |
229 | current_stmt_ = stmt; |
230 | Parent::VisitStmt(stmt); |
231 | current_stmt_ = prev_stmt; |
232 | } |
233 | |
234 | void VisitStmt_(const EvaluateNode* op) override { |
235 | if (auto* call = op->value.as<CallNode>()) { |
236 | if (call->op.same_as(builtin::assume())) { |
237 | Assume(call->args[0], true); |
238 | return; |
239 | } |
240 | } |
241 | |
242 | Parent::VisitStmt_(op); |
243 | } |
244 | |
245 | void Assume(PrimExpr assumption, bool from_assume_statement) { |
246 | for (const auto& expr : ExtractConstraints(assumption, false)) { |
247 | AssumeConstraintComponent(expr, from_assume_statement); |
248 | } |
249 | } |
250 | |
251 | void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) { |
252 | PrimExpr additional_predicate = Bool(true); |
253 | |
254 | std::vector<PrimExpr> buffer_exprs; |
255 | for (const auto& expr : ExtractComponents(assumption)) { |
256 | auto side_effect = tir::SideEffect(expr); |
257 | if (side_effect <= tir::CallEffectKind::kPure) { |
258 | // Pulling out portions of the assumption that do not depend |
259 | // on a buffer value allows the following two forms to be |
260 | // treated identically. |
261 | // |
262 | // Option 1: if i < 3: T.assume(buf[i] == value) |
263 | // Option 2: T.assume(i>=3 or buf[i] == value) |
264 | additional_predicate = additional_predicate && logical_not(expr); |
265 | } else if (side_effect == tir::CallEffectKind::kReadState) { |
266 | buffer_exprs.push_back(expr); |
267 | } else { |
268 | LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr |
269 | << " with side-effect \'" << side_effect << "\'" ; |
270 | } |
271 | } |
272 | |
273 | if (buffer_exprs.empty()) { |
274 | out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption); |
275 | return; |
276 | } |
277 | |
278 | CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression" ; |
279 | |
280 | auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>(); |
281 | CHECK(as_equal_node || !from_assume_statement) |
282 | << "T.assume buffer constraint must be of the form 'buffer[indices] == " |
283 | "value', but received " |
284 | << assumption; |
285 | if (!as_equal_node) { |
286 | // This assumption is an inequality on a data-dependent |
287 | // conditional. Not an error for this to occur, but also not |
288 | // something that is currently supported. |
289 | return; |
290 | } |
291 | |
292 | tir::BufferLoad load; |
293 | PrimExpr value; |
294 | if (auto* as_load = as_equal_node->a.as<tir::BufferLoadNode>()) { |
295 | load = GetRef<tir::BufferLoad>(as_load); |
296 | value = as_equal_node->b; |
297 | } else if (auto* as_load = as_equal_node->b.as<tir::BufferLoadNode>()) { |
298 | load = GetRef<tir::BufferLoad>(as_load); |
299 | value = as_equal_node->a; |
300 | } else if (!from_assume_statement) { |
301 | return; |
302 | } else { |
303 | LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'" ; |
304 | } |
305 | |
306 | auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; |
307 | CHECK(!has_side_effect || !from_assume_statement) |
308 | << "Buffer value in constraint must be pure expression, but was " << value; |
309 | if (has_side_effect) { |
310 | return; |
311 | } |
312 | |
313 | { |
314 | InternalConstraintContext context(this, additional_predicate); |
315 | VisitAccess(load, BufferTouch::AccessType::Assume, value); |
316 | } |
317 | // Appending a control block ensures that all control blocks have |
318 | // at most one statement that changes the known buffer contents. |
319 | auto prev_block = CurrentControlBlock(); |
320 | auto new_block = AppendControlBlock(); |
321 | MarkControlFlow(prev_block, new_block); |
322 | } |
323 | |
324 | void VisitExpr_(const LetNode* op) override { |
325 | std::optional<BindLetVar> binding; |
326 | if (UsesLoopVar(op->value)) { |
327 | binding.emplace(this, op->var, op->value); |
328 | } |
329 | Parent::VisitExpr_(op); |
330 | } |
331 | |
332 | void VisitStmt_(const LetStmtNode* op) override { |
333 | std::optional<BindLetVar> binding; |
334 | if (UsesLoopVar(op->value)) { |
335 | binding.emplace(this, op->var, op->value); |
336 | } |
337 | Parent::VisitStmt_(op); |
338 | } |
339 | |
340 | void VisitExpr_(const BufferLoadNode* op) override { |
341 | Parent::VisitExpr_(op); |
342 | BufferLoad load = GetRef<BufferLoad>(op); |
343 | VisitAccess(load, BufferTouch::AccessType::Read, load); |
344 | } |
345 | |
346 | void VisitStmt_(const BufferStoreNode* op) override { |
347 | Parent::VisitStmt_(op); |
348 | VisitAccess(GetRef<BufferStore>(op), BufferTouch::AccessType::Write, op->value); |
349 | // Appending a control block ensures that all control blocks have |
350 | // at most one statement that changes the buffer contents. |
351 | auto prev_block = CurrentControlBlock(); |
352 | auto new_block = AppendControlBlock(); |
353 | MarkControlFlow(prev_block, new_block); |
354 | } |
355 | |
356 | void VisitStmt_(const ForNode* op) override { |
357 | out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
358 | |
359 | auto before_loop = CurrentControlBlock(); |
360 | size_t loop_start = -1; |
361 | |
362 | { |
363 | BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent); |
364 | loop_start = AppendControlBlock(); |
365 | Parent::VisitStmt_(op); |
366 | } |
367 | |
368 | auto loop_end = CurrentControlBlock(); |
369 | auto after_loop = AppendControlBlock(); |
370 | PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1); |
371 | { |
372 | auto [forward, backward] = MarkControlFlow(before_loop, loop_start); |
373 | backward.post_condition = (op->loop_var == op->min); |
374 | forward.var_remap = {{op->loop_var, op->min}}; |
375 | } |
376 | { |
377 | auto [forward, backward] = MarkControlFlow(loop_end, after_loop); |
378 | backward.var_remap = {{op->loop_var, max_iterator_value}}; |
379 | forward.post_condition = (op->loop_var == max_iterator_value); |
380 | } |
381 | { |
382 | auto [forward, backward] = MarkControlFlow(loop_end, loop_start); |
383 | backward.var_remap = {{op->loop_var, op->loop_var - 1}}; |
384 | forward.var_remap = {{op->loop_var, op->loop_var + 1}}; |
385 | backward.post_condition = (op->loop_var > op->min); |
386 | forward.post_condition = (op->loop_var < max_iterator_value); |
387 | } |
388 | } |
389 | |
390 | void VisitStmt_(const IfThenElseNode* op) override { |
391 | this->VisitExpr(op->condition); |
392 | |
393 | PrimExpr real_condition = ExtractRealCondition(op->condition); |
394 | |
395 | auto before_branching = CurrentControlBlock(); |
396 | |
397 | auto branch_start = AppendControlBlock(); |
398 | MarkControlFlow(before_branching, branch_start); |
399 | |
400 | { |
401 | InternalConstraintContext context(this, real_condition); |
402 | auto then_start = AppendControlBlock(); |
403 | if (context.assume.defined()) { |
404 | Assume(context.assume.value(), false); |
405 | } |
406 | auto [forward, backward] = MarkControlFlow(branch_start, then_start); |
407 | backward.post_condition = real_condition; |
408 | forward.post_condition = real_condition; |
409 | this->VisitStmt(op->then_case); |
410 | } |
411 | auto then_end = CurrentControlBlock(); |
412 | |
413 | auto negation = analyzer_.rewrite_simplify(!real_condition); |
414 | { |
415 | InternalConstraintContext context(this, negation); |
416 | auto else_start = AppendControlBlock(); |
417 | if (context.assume.defined()) { |
418 | Assume(context.assume.value(), false); |
419 | } |
420 | auto [forward, backward] = MarkControlFlow(branch_start, else_start); |
421 | backward.post_condition = negation; |
422 | forward.post_condition = negation; |
423 | |
424 | if (op->else_case.defined()) { |
425 | this->VisitStmt(op->else_case.value()); |
426 | } |
427 | } |
428 | |
429 | auto else_end = CurrentControlBlock(); |
430 | auto after_branching = AppendControlBlock(); |
431 | |
432 | if (HasBufferLoad(real_condition)) { |
433 | // The buffer value may have changed during the body of the |
434 | // condition, so we can't provide it as a post-condition. |
435 | MarkControlFlow(then_end, after_branching); |
436 | MarkControlFlow(else_end, after_branching); |
437 | } else { |
438 | { |
439 | auto [forward, backward] = MarkControlFlow(then_end, after_branching); |
440 | backward.post_condition = real_condition; |
441 | forward.post_condition = real_condition; |
442 | } |
443 | { |
444 | auto [forward, backward] = MarkControlFlow(else_end, after_branching); |
445 | backward.post_condition = negation; |
446 | forward.post_condition = negation; |
447 | } |
448 | } |
449 | } |
450 | |
451 | /*! \brief Internal utility, returns true if the expression depends |
452 | * on a loop iterator |
453 | */ |
454 | bool UsesLoopVar(const PrimExpr& expr) { |
455 | return UsesVar(expr, [&](const VarNode* expr_var) { |
456 | return loop_dependent_vars_.find(expr_var) != loop_dependent_vars_.end(); |
457 | }); |
458 | } |
459 | |
460 | /*! \brief Record the interaction with the buffer. |
461 | * |
462 | * \param node The TIR node that accesses the buffer. Should be |
463 | * either a BufferLoad or BufferStore node. |
464 | * |
465 | * \param touch_type The type of buffer access being performed. A |
466 | * BufferStore should always use AccessType::Write. A BufferLoad |
467 | * may use either AccessType::Read or AccessType::Assume, depending |
468 | * on whether the BufferLoad occurs within `builtin::assume`. |
469 | * |
470 | * \param known_value_expr The value in the buffer following the access. |
471 | */ |
472 | template <typename BufferAccess> |
473 | void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type, |
474 | PrimExpr known_value_expr) { |
475 | auto& current_block = out_->control_flow_.back(); |
476 | BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices, |
477 | touch_type, known_value_expr); |
478 | current_block.touch_points.push_back(buffer_touch); |
479 | } |
480 | |
481 | /*! \brief Return a predicate for having reached the current |
482 | * control-flow block |
483 | * |
484 | * For example, while inside an IfThenElse, will return the |
485 | * IfThenElse's condition. |
486 | */ |
487 | PrimExpr CurrentScopePredicate() const { |
488 | PrimExpr predicate = Bool(true); |
489 | for (const auto& condition : conditions_) { |
490 | predicate = predicate && condition; |
491 | } |
492 | return predicate; |
493 | } |
494 | |
495 | /* \brief Add a new control block, returning its index */ |
496 | size_t AppendControlBlock() { |
497 | size_t index = out_->control_flow_.size(); |
498 | auto& block = out_->control_flow_.emplace_back(); |
499 | block.active_loop_iterators = active_loop_iterators_; |
500 | block.let_bindings_using_loop = let_bindings_using_loop_; |
501 | block.scope_predicate = CurrentScopePredicate(); |
502 | return index; |
503 | } |
504 | |
505 | /* \brief The index of the current control block */ |
506 | size_t CurrentControlBlock() { return out_->control_flow_.size() - 1; } |
507 | |
508 | /* \brief Mark a possible control from one block to another |
509 | * |
510 | * \param from_block The block from which control leaves |
511 | * |
512 | * \param to_block The block to which control enters |
513 | * |
514 | * \param var_remap Variable replacements that should be made in |
515 | * known expression while traversing this edge. For example, |
516 | * replacing `i` with `i-1` when entering the next loop iteration, |
517 | * or replacing `i` with `n-1` when concluding a loop. |
518 | */ |
519 | std::pair<ControlFlowGraph::ControlFlowEdge&, ControlFlowGraph::ControlFlowEdge&> MarkControlFlow( |
520 | size_t from_block, size_t to_block) { |
521 | ICHECK_LE(from_block, out_->control_flow_.size()); |
522 | ICHECK_LE(to_block, out_->control_flow_.size()); |
523 | |
524 | auto& forward = out_->control_flow_[from_block].successors.emplace_back( |
525 | ControlFlowGraph::ControlFlowEdge{to_block, {}, NullOpt}); |
526 | auto& backward = out_->control_flow_[to_block].predecessors.emplace_back( |
527 | ControlFlowGraph::ControlFlowEdge{from_block, {}, NullOpt}); |
528 | return {forward, backward}; |
529 | } |
530 | |
531 | // Internal utility, context manager for entering/leaving a scoped constraint |
532 | struct InternalConstraintContext { |
533 | InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint) |
534 | : self(self), analyzer_context(&self->analyzer_, constraint) { |
535 | old_num_constraints = self->conditions_.size(); |
536 | |
537 | auto side_effect = tir::SideEffect(constraint); |
538 | if (side_effect <= tir::CallEffectKind::kPure) { |
539 | self->conditions_.push_back(constraint); |
540 | } else if (side_effect <= tir::CallEffectKind::kReadState) { |
541 | assume = constraint; |
542 | } |
543 | |
544 | new_num_constraints = self->conditions_.size(); |
545 | } |
546 | ~InternalConstraintContext() { |
547 | ICHECK_EQ(self->conditions_.size(), new_num_constraints) |
548 | << "Internal error: Each condition should only be popped once." ; |
549 | self->conditions_.erase(self->conditions_.begin() + old_num_constraints, |
550 | self->conditions_.end()); |
551 | } |
552 | |
553 | ControlFlowGraphBuilder* self{nullptr}; |
554 | With<ConstraintContext> analyzer_context; |
555 | size_t old_num_constraints{0}; |
556 | size_t new_num_constraints{0}; |
557 | Optional<PrimExpr> assume{NullOpt}; |
558 | |
559 | // Disable default-generated copy/move assignment and constructors |
560 | InternalConstraintContext(const InternalConstraintContext&) = delete; |
561 | InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; |
562 | InternalConstraintContext(InternalConstraintContext&&) = delete; |
563 | InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; |
564 | }; |
565 | |
566 | // Internal utility, context manager for tracking a loop |
567 | struct BindActiveLoopVar { |
568 | BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min, |
569 | PrimExpr loop_extent) |
570 | : self(self), var(var) { |
571 | PrimExpr loop_max = loop_min + (loop_extent - 1); |
572 | auto loop_range = Range::FromMinExtent(loop_min, loop_extent); |
573 | self->active_loop_iterators_.push_back({var, loop_min, loop_max, loop_range}); |
574 | self->loop_dependent_vars_.insert(var.get()); |
575 | } |
576 | ~BindActiveLoopVar() { self->active_loop_iterators_.pop_back(); } |
577 | |
578 | ControlFlowGraphBuilder* self; |
579 | Var var; |
580 | |
581 | // Disable default-generated copy/move assignment and constructors |
582 | BindActiveLoopVar(const BindActiveLoopVar&) = delete; |
583 | BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete; |
584 | BindActiveLoopVar(BindActiveLoopVar&&) = delete; |
585 | BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; |
586 | }; |
587 | |
588 | // Internal utility, context manager for tracking a variable binding |
589 | struct BindLetVar { |
590 | BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) { |
591 | self->let_bindings_using_loop_.Set(var, value); |
592 | self->loop_dependent_vars_.insert(var.get()); |
593 | } |
594 | ~BindLetVar() { |
595 | self->loop_dependent_vars_.erase(var.get()); |
596 | self->let_bindings_using_loop_.erase(var); |
597 | } |
598 | ControlFlowGraphBuilder* self; |
599 | Var var; |
600 | |
601 | // Disable default-generated copy/move assignment and constructors |
602 | BindLetVar(const BindLetVar&) = delete; |
603 | BindLetVar& operator=(const BindLetVar&) = delete; |
604 | BindLetVar(BindLetVar&&) = delete; |
605 | BindLetVar& operator=(BindLetVar&&) = delete; |
606 | }; |
607 | |
// Description of a single loop.
//
// NOTE(review): nothing in the visible part of this class refers to
// this struct; `active_loop_iterators_` stores
// ControlFlowGraph::ControlFlowBlock::LoopEntry instead.  Confirm
// whether this local definition is still needed.
struct LoopEntry {
  // The loop iterator variable.
  Var loop_var;
  // Presumably the first value taken by `loop_var` -- TODO confirm.
  PrimExpr loop_min;
  // Presumably the last value taken by `loop_var` -- TODO confirm.
  PrimExpr loop_max;
  // Presumably the range covered by `loop_var` -- TODO confirm.
  Range loop_range;
};
614 | |
// The loops that enclose the statement currently being visited, in
// the order they were entered.  Tracked in order to know which Vars
// to write in terms of the buffer indices and substitute out of the
// predicate.  Copied into each control block when it is appended.
std::vector<ControlFlowGraph::ControlFlowBlock::LoopEntry> active_loop_iterators_;

// Track all loop iterators, along with values derived from loop iterators.
// Let-bound variables are removed when their binding goes out of
// scope; loop iterators are not removed when their loop ends.
std::unordered_set<const VarNode*> loop_dependent_vars_;

// Any let binding that depends, directly or indirectly, on a loop
// binding. When making a predicate in terms of the buffer indices,
// these need to be substituted out.
// std::unordered_map<const VarNode*, PrimExpr> let_bindings_using_loop_;
Map<Var, PrimExpr> let_bindings_using_loop_;

// Track in order to know what conditions limit the buffer access.
// Entries are pushed and removed by InternalConstraintContext.
std::vector<PrimExpr> conditions_;

// Track in order to know what statement initiated the buffer access.
// Updated by VisitStmt for the duration of each statement's visit.
Stmt current_stmt_;

// Output data structure.  Not owned by the builder.
ControlFlowGraph* out_;
636 | }; |
637 | |
638 | std::pair<BufferTouch, Map<Var, Range>> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( |
639 | const tir::Buffer& buf, Array<Var> index_variables, Array<PrimExpr> indices, |
640 | BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { |
641 | const auto& current_block = *this; |
642 | |
643 | Analyzer local_analyzer; |
644 | |
645 | Optional<Var> lane_var = NullOpt; |
646 | IntImm num_lanes; |
647 | |
648 | Array<PrimExpr> index_expressions = indices.Map([&](const auto& index) { |
649 | if (index.dtype().lanes() == 1) { |
650 | return index; |
651 | } else { |
652 | ICHECK(!lane_var) << "Multiple indices found with non-scalar values" ; |
653 | lane_var = Var("lane" , index.dtype().element_of()); |
654 | num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); |
655 | return UnwrapVectorExpr(index, lane_var.value()); |
656 | } |
657 | }); |
658 | |
659 | Array<Var> loop_vars; |
660 | |
661 | Map<Var, Range> loop_ranges; |
662 | for (const auto& loop_entry : current_block.active_loop_iterators) { |
663 | loop_vars.push_back(loop_entry.loop_var); |
664 | loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); |
665 | } |
666 | |
667 | // If the indices contain multiple lanes, treat the lane variable |
668 | // as an additional loop iterator to be solved for and substituted |
669 | // out. |
670 | if (lane_var) { |
671 | loop_vars.push_back(lane_var.value()); |
672 | loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes)); |
673 | } |
674 | |
675 | IntConstraintsTransform transform = [&]() { |
676 | ICHECK_EQ(index_variables.size(), index_expressions.size()); |
677 | |
678 | Array<PrimExpr> relations; |
679 | |
680 | for (size_t i = 0; i < index_expressions.size(); i++) { |
681 | PrimExpr expr = index_expressions[i]; |
682 | Var var = index_variables[i]; |
683 | |
684 | expr = Substitute(expr, current_block.let_bindings_using_loop); |
685 | relations.push_back(var == expr); |
686 | } |
687 | |
688 | IntConstraints system(loop_vars, loop_ranges, relations); |
689 | return arith::SolveLinearEquations(system); |
690 | }(); |
691 | |
692 | Map<Var, PrimExpr> loop_var_to_axis_var = transform->src_to_dst; |
693 | Map<Var, Range> free_params = transform->dst->ranges; |
694 | PrimExpr transform_predicate = |
695 | std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), |
696 | PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); |
697 | |
698 | transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); |
699 | |
700 | auto find_removable_params = [&]() -> Map<Var, PrimExpr> { |
701 | Map<Var, PrimExpr> removable_params; |
702 | |
703 | // The arith::SolveLinearEquations is more general than the |
704 | // utilities in iter_affine_map.h, but can introduce free |
705 | // parameters that could later be determined with the known |
706 | // constraints. This step removes all such free parameters. |
707 | for (const auto& expr : ExtractConstraints(transform_predicate)) { |
708 | if (auto* as_equal = expr.as<EQNode>()) { |
709 | auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) { |
710 | auto* var_ptr = a.as<VarNode>(); |
711 | if (!var_ptr) { |
712 | return; |
713 | } |
714 | |
715 | Var var = GetRef<Var>(var_ptr); |
716 | if (free_params.count(var) == 0) { |
717 | return; |
718 | } |
719 | |
720 | bool uses_free_param = |
721 | UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef<Var>(v)) > 0; }); |
722 | if (uses_free_param) { |
723 | return; |
724 | } |
725 | removable_params.Set(var, b); |
726 | }; |
727 | check_expr(as_equal->a, as_equal->b); |
728 | check_expr(as_equal->b, as_equal->a); |
729 | } |
730 | } |
731 | |
732 | // In addition, the arith::SolveLinearEquation can introduce |
733 | // free parameters with an extent of one. Filtering them out here |
734 | // avoids needing to track them through later simplifications. |
735 | for (const auto [var, range] : free_params) { |
736 | if (is_one(range->extent)) { |
737 | removable_params.Set(var, range->min); |
738 | } |
739 | } |
740 | |
741 | return removable_params; |
742 | }; |
743 | for (auto removable_params = find_removable_params(); removable_params.size() > 0; |
744 | removable_params = find_removable_params()) { |
745 | auto update = [&](const PrimExpr& expr) { |
746 | return local_analyzer.Simplify(Substitute(expr, removable_params)); |
747 | }; |
748 | |
749 | Map<Var, PrimExpr> new_map; |
750 | for (const auto [loop_var, expr] : loop_var_to_axis_var) { |
751 | static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
752 | new_map.Set(loop_var, update(expr)); |
753 | } |
754 | loop_var_to_axis_var = new_map; |
755 | |
756 | transform_predicate = update(transform_predicate); |
757 | |
758 | for (const auto [var, expr] : removable_params) { |
759 | static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
760 | free_params.erase(var); |
761 | } |
762 | } |
763 | |
764 | // Normalization function, applied to both the predicate and the |
765 | // known value. Converts from an expression in terms of loop |
766 | // iterators to an expression in terms of buffer indices. |
767 | auto normalize_expr = [&](PrimExpr expr) -> PrimExpr { |
768 | expr = Substitute(expr, current_block.let_bindings_using_loop); |
769 | |
770 | if (lane_var) { |
771 | expr = UnwrapVectorExpr(expr, lane_var.value()); |
772 | } |
773 | expr = Substitute(expr, loop_var_to_axis_var); |
774 | |
775 | return expr; |
776 | }; |
777 | |
778 | // Collect the current loop variables, along with an expression for |
779 | // the loop variables in terms of the buffer axis variables. This |
780 | // is used during forward/backward propagation to generate predicate |
781 | // tracking whether a loop iteration has been reached. |
782 | std::vector<std::pair<Var, PrimExpr>> loop_var_expressions; |
783 | for (const auto& entry : current_block.active_loop_iterators) { |
784 | auto expr_it = loop_var_to_axis_var.find(entry.loop_var); |
785 | ICHECK(expr_it != loop_var_to_axis_var.end()); |
786 | loop_var_expressions.push_back({entry.loop_var, (*expr_it).second}); |
787 | } |
788 | |
789 | // The full predicate is composed of the values required to reach |
790 | // the scope of the BufferStore or builtin::assume(), any bounds |
791 | // implied by solving for the axis variables, and any additional |
792 | // statements resulting from unpacking the expression contained in |
793 | // builtin::assume(). |
794 | PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate); |
795 | transform_predicate = normalize_expr(transform_predicate); |
796 | |
797 | known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr)); |
798 | |
799 | // Deliberately use an analyzer without scope-based information, |
800 | // to avoid simplifying `scope_predicate` to True. |
801 | PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate); |
802 | |
803 | BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions, |
804 | touch_type}; |
805 | |
806 | return {buffer_touch, free_params}; |
807 | } |
808 | |
809 | BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, |
810 | const tir::Buffer& buf, |
811 | const Array<PrimExpr>& indices, |
812 | BufferTouch::AccessType touch_type, |
813 | PrimExpr known_value_expr) const { |
814 | ICHECK(graph); |
815 | auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), |
816 | indices, touch_type, known_value_expr); |
817 | for (const auto& pair : free_params) { |
818 | graph->free_predicate_parameters_.Set(pair.first, pair.second); |
819 | } |
820 | return buffer_touch; |
821 | } |
822 | |
823 | ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) |
824 | : max_revisits_(max_revisits) { |
825 | ControlFlowGraphBuilder::Build(this, stmt); |
826 | ForwardPropagateKnownValues(); |
827 | BackwardPropagateUnusedValues(); |
828 | } |
829 | |
830 | void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) { |
831 | size_t context_index = [&]() { |
832 | auto it = control_flow_lookup_.find(store.get()); |
833 | ICHECK(it != control_flow_lookup_.end()) |
834 | << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor" ; |
835 | return it->second; |
836 | }(); |
837 | |
838 | auto& touch_points = control_flow_[context_index].touch_points; |
839 | |
840 | touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(), |
841 | [](const BufferTouch& touch) { |
842 | return touch.touch_type == BufferTouch::AccessType::Write; |
843 | }), |
844 | touch_points.end()); |
845 | ForwardPropagateKnownValues(context_index); |
846 | BackwardPropagateUnusedValues(context_index); |
847 | } |
848 | |
849 | std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) { |
850 | os << edge.index; |
851 | if (edge.var_remap.size()) { |
852 | os << " with remap " << edge.var_remap; |
853 | } |
854 | if (edge.post_condition) { |
855 | os << " with postcondition " << edge.post_condition; |
856 | } |
857 | |
858 | return os; |
859 | } |
860 | |
861 | std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) { |
862 | os << "Predecessors: [" ; |
863 | for (size_t i = 0; i < block.predecessors.size(); i++) { |
864 | if (i) { |
865 | os << ", " ; |
866 | } |
867 | os << block.predecessors[i]; |
868 | } |
869 | os << "]\n" ; |
870 | |
871 | os << "Active loop iterators: [" ; |
872 | for (size_t i = 0; i < block.active_loop_iterators.size(); i++) { |
873 | if (i) { |
874 | os << ", " ; |
875 | } |
876 | os << block.active_loop_iterators[i].loop_var; |
877 | } |
878 | os << "]\n" ; |
879 | |
880 | os << "Before block knowns: " << block.known_at_block_start << "\n" ; |
881 | |
882 | os << "Before block unused: " << block.unused_at_block_start << "\n" ; |
883 | |
884 | for (size_t i = 0; i < block.touch_points.size(); i++) { |
885 | os << "Touch[" << i << "] = " << block.touch_points[i] << "\n" ; |
886 | } |
887 | os << "After block: " << block.known_at_block_end << "\n" ; |
888 | |
889 | os << "After block unused: " << block.unused_at_block_end << "\n" ; |
890 | |
891 | os << "Successors: [" ; |
892 | for (size_t i = 0; i < block.successors.size(); i++) { |
893 | if (i) { |
894 | os << ", " ; |
895 | } |
896 | os << block.successors[i]; |
897 | } |
898 | os << "]" ; |
899 | return os; |
900 | } |
901 | |
902 | std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) { |
903 | os << "Touch pattern contains " << pattern.control_flow_.size() << " control blocks." |
904 | << (pattern.control_flow_.size() ? "\n" : "" ); |
905 | for (size_t i = 0; i < pattern.control_flow_.size(); i++) { |
906 | os << "\t" |
907 | << "ControlBlock[" << i << "] = " << pattern.control_flow_[i] << "\n" ; |
908 | } |
909 | |
910 | return os; |
911 | } |
912 | |
913 | bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const { |
914 | // Constraints must apply to the same buffer to be equivalent |
915 | if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) { |
916 | return false; |
917 | } |
918 | |
919 | ExprDeepEqual deep_equal; |
920 | |
921 | auto implies = [&](const PrimExpr& a, const PrimExpr& b) -> bool { |
922 | With<ConstraintContext> context(analyzer, a); |
923 | return analyzer->CanProve(b); |
924 | }; |
925 | |
926 | // Predicates must be equivalent expressions, or must both be undefined |
927 | bool equivalent_predicates = |
928 | deep_equal(predicate, other.predicate) || |
929 | (implies(predicate, other.predicate) && implies(other.predicate, predicate)); |
930 | if (!equivalent_predicates) { |
931 | return false; |
932 | } |
933 | |
934 | // The known value must be equal |
935 | if (!deep_equal(value, other.value) && !analyzer->CanProveEqual(value, other.value)) { |
936 | return false; |
937 | } |
938 | |
939 | return true; |
940 | } |
941 | |
942 | std::ostream& operator<<(std::ostream& os, const BufferState& state) { |
943 | for (size_t i = 0; i < state.constraints_.size(); i++) { |
944 | os << "constraints[" << i << "] = " << state.constraints_[i] |
945 | << (i + 1 == state.constraints_.size() ? "" : "\n" ); |
946 | } |
947 | return os; |
948 | } |
949 | |
950 | PrimExpr BufferState::SubstituteKnownBufferValues( |
951 | PrimExpr expr, const Map<tir::Buffer, Array<tir::Var>>& axis_var_lookup, |
952 | Analyzer* analyzer) const { |
953 | BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); |
954 | return mutator(std::move(expr)); |
955 | } |
956 | |
957 | void BufferState::AddCondition(const PrimExpr& condition) { |
958 | for (auto& constraint : constraints_) { |
959 | constraint.predicate = constraint.predicate && condition; |
960 | } |
961 | } |
962 | |
963 | void BufferState::Substitute(const Map<Var, PrimExpr>& var_remap, Analyzer* analyzer) { |
964 | if (var_remap.size()) { |
965 | for (auto& prior : constraints_) { |
966 | PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); |
967 | if (!updated.same_as(prior.predicate)) { |
968 | prior.predicate = SimplifyAsAndOfOrs(updated, analyzer); |
969 | } |
970 | } |
971 | } |
972 | } |
973 | |
974 | void BufferState::Simplify(Analyzer* analyzer) { |
975 | for (auto& constraint : constraints_) { |
976 | constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate, analyzer); |
977 | } |
978 | } |
979 | |
980 | void BufferState::Union(const BufferState& b, Analyzer* analyzer) { |
981 | for (const auto& b_constraint : b.constraints_) { |
982 | bool used = false; |
983 | for (auto& a_constraint : constraints_) { |
984 | if (a_constraint.buffer.same_as(b_constraint.buffer) && |
985 | analyzer->CanProveEqual(a_constraint.value, b_constraint.value)) { |
986 | a_constraint.predicate = |
987 | SimplifyAsAndOfOrs(a_constraint.predicate || b_constraint.predicate, analyzer); |
988 | used = true; |
989 | break; |
990 | } |
991 | } |
992 | if (!used) { |
993 | constraints_.push_back(b_constraint); |
994 | } |
995 | } |
996 | } |
997 | |
998 | void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) { |
999 | // For a constraint to be in the output, it must be present in both |
1000 | // inputs. |
1001 | |
1002 | std::vector<BufferTouch> new_constraints; |
1003 | for (const auto& ai : constraints_) { |
1004 | for (const auto& bi : b.constraints_) { |
1005 | if (ai.buffer.same_as(bi.buffer)) { |
1006 | PrimExpr predicate = SimplifyAsAndOfOrs(ai.predicate && bi.predicate, analyzer); |
1007 | if (!is_zero(predicate)) { |
1008 | With<ConstraintContext> context(analyzer, predicate); |
1009 | PrimExpr known_value_a = ai.value; |
1010 | PrimExpr known_value_b = bi.value; |
1011 | |
1012 | bool is_consistent = analyzer->CanProveEqual(known_value_a, known_value_b); |
1013 | if (is_consistent) { |
1014 | new_constraints.push_back({ai.buffer, predicate, known_value_a}); |
1015 | } |
1016 | } |
1017 | } |
1018 | } |
1019 | } |
1020 | |
1021 | constraints_ = std::move(new_constraints); |
1022 | } |
1023 | |
1024 | class BufferRegionCollector : public ExprVisitor { |
1025 | public: |
1026 | struct Region { |
1027 | PrimExpr region_predicate; |
1028 | std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>> known_values; |
1029 | }; |
1030 | |
1031 | static std::vector<Region> Collect(const Map<Buffer, Array<Var>>& axis_var_lookup, |
1032 | const std::vector<BufferTouch>& knowns, |
1033 | const std::vector<Optional<PrimExpr>>& exprs, |
1034 | Analyzer* analyzer) { |
1035 | BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); |
1036 | for (const auto& expr : exprs) { |
1037 | if (expr) { |
1038 | collector(expr.value()); |
1039 | } |
1040 | } |
1041 | |
1042 | return collector.regions_; |
1043 | } |
1044 | |
1045 | private: |
1046 | using Parent = ExprVisitor; |
1047 | |
1048 | BufferRegionCollector(const Map<Buffer, Array<Var>>& axis_var_lookup, |
1049 | const std::vector<BufferTouch>& knowns, Analyzer* analyzer) |
1050 | : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { |
1051 | regions_.push_back(Region{Bool(true), {}}); |
1052 | } |
1053 | |
1054 | using Parent::VisitExpr_; |
1055 | |
1056 | void VisitExpr_(const BufferLoadNode* op) override { |
1057 | // Helper struct for the known values of this BufferLoad |
1058 | struct Known { |
1059 | PrimExpr predicate; |
1060 | Optional<PrimExpr> value; |
1061 | }; |
1062 | |
1063 | std::vector<Known> new_regions; |
1064 | |
1065 | PrimExpr unknown_region = Bool(true); |
1066 | |
1067 | for (const BufferTouch& constraint : knowns_) { |
1068 | if (!op->buffer.same_as(constraint.buffer)) { |
1069 | // This is a different buffer, so continue searching. |
1070 | continue; |
1071 | } |
1072 | |
1073 | auto axis_vars = axis_var_lookup_.at(op->buffer); |
1074 | PrimExpr touch_predicate = |
1075 | SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value(); |
1076 | touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); |
1077 | |
1078 | if (!is_zero(touch_predicate)) { |
1079 | Optional<PrimExpr> known_value = |
1080 | SubstituteParamValues(axis_vars, op->indices, constraint.value); |
1081 | new_regions.push_back(Known{touch_predicate, known_value}); |
1082 | |
1083 | unknown_region = unknown_region && !touch_predicate; |
1084 | unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_); |
1085 | } |
1086 | } |
1087 | |
1088 | if (new_regions.size()) { |
1089 | Analyzer local_analyzer; |
1090 | |
1091 | if (!is_zero(unknown_region)) { |
1092 | new_regions.insert(new_regions.begin(), Known{unknown_region, NullOpt}); |
1093 | } |
1094 | |
1095 | std::vector<Region> updated_regions; |
1096 | for (const auto& prev_region : regions_) { |
1097 | for (const auto& new_region : new_regions) { |
1098 | PrimExpr intersection = |
1099 | SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_); |
1100 | |
1101 | if (!is_zero(intersection)) { |
1102 | Region merged{intersection, prev_region.known_values}; |
1103 | merged.known_values[op] = new_region.value; |
1104 | updated_regions.push_back(std::move(merged)); |
1105 | } |
1106 | } |
1107 | } |
1108 | regions_ = updated_regions; |
1109 | } |
1110 | } |
1111 | |
1112 | Analyzer* analyzer_; |
1113 | std::vector<Region> regions_; |
1114 | const Map<Buffer, Array<Var>>& axis_var_lookup_; |
1115 | const std::vector<BufferTouch>& knowns_; |
1116 | }; |
1117 | |
1118 | class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { |
1119 | public: |
1120 | static PrimExpr Apply( |
1121 | const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values, |
1122 | PrimExpr expr, Analyzer* analyzer) { |
1123 | BufferRegionValueReplacer mutator(known_values, analyzer); |
1124 | PrimExpr result = mutator(expr); |
1125 | // Simplification must occur after the substitution, as known |
1126 | // values may provide enable simplifications. Also, cannot track |
1127 | // whether a BufferLoad was |
1128 | result = analyzer->Simplify(result); |
1129 | return result; |
1130 | } |
1131 | |
1132 | private: |
1133 | using Parent = IRMutatorWithAnalyzer; |
1134 | |
1135 | BufferRegionValueReplacer( |
1136 | const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values, |
1137 | Analyzer* analyzer) |
1138 | : Parent(analyzer), known_values_(known_values) {} |
1139 | |
1140 | using Parent::VisitExpr_; |
1141 | |
1142 | PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
1143 | auto it = known_values_.find(op); |
1144 | if (it != known_values_.end() && it->second) { |
1145 | return it->second.value(); |
1146 | } else { |
1147 | return GetRef<PrimExpr>(op); |
1148 | } |
1149 | } |
1150 | |
1151 | const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values_; |
1152 | }; |
1153 | |
1154 | void BufferState::ApplyTouches(const Map<Buffer, Array<Var>>& axis_var_lookup, |
1155 | const std::vector<BufferTouch>& touch_points, Analyzer* analyzer) { |
1156 | std::vector<BufferTouch> new_knowns; |
1157 | Map<Buffer, PrimExpr> keep_prior_known_at; |
1158 | |
1159 | for (auto& touch : touch_points) { |
1160 | if (touch.touch_type == BufferTouch::AccessType::Read) { |
1161 | continue; |
1162 | } |
1163 | |
1164 | PrimExpr known_value = touch.value; |
1165 | |
1166 | PrimExpr predicate = touch.predicate && touch.AfterLoopIteration(); |
1167 | auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_, |
1168 | {predicate, touch.value}, analyzer); |
1169 | |
1170 | for (const auto& region : regions) { |
1171 | PrimExpr updated_predicate = BufferRegionValueReplacer::Apply( |
1172 | region.known_values, region.region_predicate && predicate, analyzer); |
1173 | |
1174 | updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer); |
1175 | PrimExpr updated_value = |
1176 | BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer); |
1177 | |
1178 | if (!is_zero(updated_predicate)) { |
1179 | if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) { |
1180 | keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate); |
1181 | } else { |
1182 | keep_prior_known_at.Set(touch.buffer, !updated_predicate); |
1183 | } |
1184 | |
1185 | if (!HasBufferLoad(updated_value)) { |
1186 | BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value}; |
1187 | new_knowns.push_back(new_constraint); |
1188 | } |
1189 | } |
1190 | } |
1191 | } |
1192 | |
1193 | if (keep_prior_known_at.size()) { |
1194 | for (auto& constraint : constraints_) { |
1195 | if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) { |
1196 | constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer); |
1197 | } |
1198 | } |
1199 | } |
1200 | |
1201 | if (new_knowns.size()) { |
1202 | std::vector<bool> used(new_knowns.size(), false); |
1203 | |
1204 | for (auto& constraint : constraints_) { |
1205 | PrimExpr expand_known_at = Bool(false); |
1206 | |
1207 | PrimExpr prev_value = constraint.value; |
1208 | |
1209 | for (size_t i = 0; i < new_knowns.size(); i++) { |
1210 | if (new_knowns[i].buffer.same_as(constraint.buffer)) { |
1211 | Optional<PrimExpr> overwritten_with = new_knowns[i].value; |
1212 | if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { |
1213 | expand_known_at = |
1214 | SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); |
1215 | used[i] = true; |
1216 | } |
1217 | } |
1218 | } |
1219 | |
1220 | if (!is_zero(expand_known_at)) { |
1221 | constraint.predicate = |
1222 | SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer); |
1223 | } |
1224 | } |
1225 | |
1226 | for (size_t i = 0; i < new_knowns.size(); i++) { |
1227 | if (!used[i]) { |
1228 | constraints_.push_back(new_knowns[i]); |
1229 | } |
1230 | } |
1231 | } |
1232 | |
1233 | constraints_.erase( |
1234 | std::remove_if(constraints_.begin(), constraints_.end(), |
1235 | [&](const auto& constraint) { return is_zero(constraint.predicate); }), |
1236 | constraints_.end()); |
1237 | } |
1238 | |
1239 | void BufferState::BackpropUnusedIndices(const Map<Buffer, Array<Var>>& axis_var_lookup, |
1240 | const std::vector<BufferTouch>& touch_points, |
1241 | Analyzer* analyzer) { |
1242 | std::vector<BufferTouch> new_knowns; |
1243 | Map<Buffer, PrimExpr> keep_prior_known_at; |
1244 | |
1245 | Map<Buffer, PrimExpr> regions_written; |
1246 | Map<Buffer, PrimExpr> regions_read; |
1247 | for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { |
1248 | const auto& touch = *it; |
1249 | |
1250 | Map<Buffer, PrimExpr>* to_update{nullptr}; |
1251 | if (touch.touch_type == BufferTouch::AccessType::Write) { |
1252 | to_update = ®ions_written; |
1253 | |
1254 | } else if (touch.touch_type == BufferTouch::AccessType::Read) { |
1255 | to_update = ®ions_read; |
1256 | } else { |
1257 | continue; |
1258 | } |
1259 | |
1260 | PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false)); |
1261 | PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration(); |
1262 | to_update->Set(touch.buffer, prev || new_predicate); |
1263 | } |
1264 | |
1265 | auto update_map = [&](auto& map) { |
1266 | Map<Buffer, PrimExpr> new_map; |
1267 | for (auto [buffer, predicate] : map) { |
1268 | new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); |
1269 | } |
1270 | map = std::move(new_map); |
1271 | }; |
1272 | update_map(regions_written); |
1273 | update_map(regions_read); |
1274 | |
1275 | // If buffer is already in used, widen the predicate |
1276 | for (auto& prev_unused : constraints_) { |
1277 | if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) { |
1278 | PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value(); |
1279 | prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer); |
1280 | regions_written.erase(prev_unused.buffer); |
1281 | } |
1282 | } |
1283 | |
1284 | // Otherwise, add new "touch" to represent the unused values |
1285 | for (auto [buffer, predicate] : regions_written) { |
1286 | constraints_.push_back( |
1287 | BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})}); |
1288 | } |
1289 | |
1290 | // If buffer is read out, narrow the predicate |
1291 | for (auto& prev_unused : constraints_) { |
1292 | if (auto opt_pred = regions_read.Get(prev_unused.buffer)) { |
1293 | PrimExpr predicate = opt_pred.value(); |
1294 | prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer); |
1295 | } |
1296 | } |
1297 | |
1298 | // Clean-up and remove any empty constraints |
1299 | constraints_.erase( |
1300 | std::remove_if(constraints_.begin(), constraints_.end(), |
1301 | [](const auto& constraint) { return is_zero(constraint.predicate); }), |
1302 | constraints_.end()); |
1303 | } |
1304 | |
1305 | void BufferState::RemoveFreeParameters(const Map<Var, Range>& free_predicate_parameters, |
1306 | Analyzer* analyzer) { |
1307 | for (auto& known : constraints_) { |
1308 | known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); |
1309 | known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer); |
1310 | } |
1311 | } |
1312 | |
1313 | bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const { |
1314 | if (constraints_.size() != other.constraints_.size()) { |
1315 | return false; |
1316 | } |
1317 | |
1318 | for (size_t i = 0; i < constraints_.size(); i++) { |
1319 | if (!constraints_[i].IsEquivalentTo(other.constraints_[i], analyzer)) { |
1320 | return false; |
1321 | } |
1322 | } |
1323 | |
1324 | return true; |
1325 | } |
1326 | |
1327 | Optional<Array<Var>> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { |
1328 | if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { |
1329 | return (*it).second; |
1330 | } else { |
1331 | return NullOpt; |
1332 | } |
1333 | } |
1334 | |
1335 | Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<PrimExpr>& indices) { |
1336 | if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { |
1337 | return (*it).second; |
1338 | } |
1339 | |
1340 | Array<Var> vars; |
1341 | for (size_t i = 0; i < indices.size(); i++) { |
1342 | std::stringstream ss; |
1343 | ss << buf->name << "_axis_" << i; |
1344 | vars.push_back(Var(ss.str(), indices[i].dtype().element_of())); |
1345 | } |
1346 | |
1347 | axis_var_lookup_.Set(buf, vars); |
1348 | return vars; |
1349 | } |
1350 | |
1351 | void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) { |
1352 | // Values to visit when searching. Using a std::set to |
1353 | // preferentially visit nodes near the start of the control flow. |
1354 | std::set<size_t> to_visit; |
1355 | |
1356 | if (flow_from.has_value()) { |
1357 | to_visit.insert(flow_from.value()); |
1358 | } else { |
1359 | // Initiatize the locations to search from, propagating values |
1360 | // forward from all locations that have a known value. |
1361 | for (size_t i = 0; i < control_flow_.size(); i++) { |
1362 | bool has_known_value = false; |
1363 | for (const auto& touch : control_flow_[i].touch_points) { |
1364 | if (!HasBufferLoad(touch.value)) { |
1365 | has_known_value = true; |
1366 | break; |
1367 | } |
1368 | } |
1369 | |
1370 | if (has_known_value) { |
1371 | to_visit.insert(i); |
1372 | } |
1373 | } |
1374 | } |
1375 | |
1376 | // Map from a block's index |
1377 | std::unordered_map<size_t, size_t> visit_count_lookup; |
1378 | |
1379 | Analyzer analyzer; |
1380 | analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
1381 | arith::RewriteSimplifier::kTransitivelyProveInequalities | |
1382 | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
1383 | arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
1384 | |
1385 | analyzer.Bind(iterator_ranges_); |
1386 | analyzer.Bind(free_predicate_parameters_); |
1387 | |
1388 | while (to_visit.size()) { |
1389 | size_t visiting = *to_visit.begin(); |
1390 | to_visit.erase(visiting); |
1391 | |
1392 | size_t num_previous_visits = visit_count_lookup[visiting]++; |
1393 | |
1394 | ControlFlowBlock& block = control_flow_[visiting]; |
1395 | |
1396 | // Step 1: Collect known values provided from each predecessor |
1397 | block.known_at_block_start = [&]() -> BufferState { |
1398 | if (num_previous_visits >= max_revisits_) { |
1399 | return BufferState(); |
1400 | } |
1401 | |
1402 | // Validate internal constraint. This should be true by |
1403 | // construction, as ControlFlowGraphBuilder only builds graphs |
1404 | // that have two or fewer predecessors. |
1405 | ICHECK_LE(block.predecessors.size(), 2) |
1406 | << "InternalError: Each block should have at most two predecessors. " |
1407 | << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint." ; |
1408 | |
1409 | std::vector<BufferState> states; |
1410 | for (const auto& pred : block.predecessors) { |
1411 | const auto& pred_block = control_flow_[pred.index]; |
1412 | BufferState state = pred_block.known_at_block_end; |
1413 | state.Substitute(pred.var_remap, &analyzer); |
1414 | states.push_back(state); |
1415 | } |
1416 | |
1417 | if (std::all_of(block.predecessors.begin(), block.predecessors.end(), |
1418 | [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) { |
1419 | // Predecessors, if any, are unvisited. |
1420 | return {}; |
1421 | } else if (block.predecessors.size() == 1) { |
1422 | // Block has only a single predecessor |
1423 | return states[0]; |
1424 | } |
1425 | |
1426 | const auto& pred_a = block.predecessors[0]; |
1427 | const auto& pred_b = block.predecessors[1]; |
1428 | |
1429 | auto& priors_a = states[0]; |
1430 | auto& priors_b = states[1]; |
1431 | |
1432 | // During the first visit of a block, predecessor blocks may be |
1433 | // unvisited, even though we preferentially visit earlier blocks |
1434 | // first. (e.g. During the first visit of the start of a For |
1435 | // loop, the end of the For loop has not yet been visited.) If |
1436 | // this is the case, assume the best-case scenario that all |
1437 | // knowns are consistent, and rely on a later visit to |
1438 | // resolve/remove any conflicts. |
1439 | if (visit_count_lookup[pred_a.index] == 0) { |
1440 | return priors_b; |
1441 | } else if (visit_count_lookup[pred_b.index] == 0) { |
1442 | return priors_a; |
1443 | } |
1444 | |
1445 | if (pred_a.post_condition && pred_b.post_condition) { |
1446 | // The predicate can identify which predecessor block applies |
1447 | // (e.g. i==0 for the first loop iteration, i>0 for remaining |
1448 | // loop iterations). Therefore, we can use all buffer |
1449 | // constraints, conditional on having come from the |
1450 | // predecessor that provides it. |
1451 | priors_a.AddCondition(pred_a.post_condition.value()); |
1452 | priors_b.AddCondition(pred_b.post_condition.value()); |
1453 | priors_a.Union(priors_b, &analyzer); |
1454 | return priors_a; |
1455 | } else { |
1456 | // We don't know which predecessor applies. Therefore, the |
1457 | // only buffer constraints that can be used are those that |
1458 | // appear in both predecessors. |
1459 | priors_a.Intersection(priors_b, &analyzer); |
1460 | return priors_a; |
1461 | } |
1462 | }(); |
1463 | |
1464 | // Step 2: Collect knowns provided as a result of executing this block |
1465 | auto post_state = [&]() { |
1466 | if (num_previous_visits >= max_revisits_) { |
1467 | return BufferState(); |
1468 | } |
1469 | auto post_state = block.known_at_block_start; |
1470 | post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer); |
1471 | post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); |
1472 | return post_state; |
1473 | }(); |
1474 | |
1475 | // Step 3: If any changes are made to the post knowns since the |
1476 | // previous time we visited this block, mark the successor block |
1477 | // as needing to be visited. |
1478 | if (num_previous_visits == 0 || |
1479 | !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) { |
1480 | block.known_at_block_end = std::move(post_state); |
1481 | for (const auto& successor : block.successors) { |
1482 | to_visit.insert(successor.index); |
1483 | } |
1484 | } |
1485 | } |
1486 | } |
1487 | |
1488 | void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) { |
1489 | // Values to visit when searching. Using a std::set to |
1490 | // preferentially visit nodes near the end of the control flow. |
1491 | std::set<size_t> to_visit; |
1492 | |
1493 | if (flow_from.has_value()) { |
1494 | to_visit.insert(flow_from.value()); |
1495 | } else { |
1496 | // Initiatize the locations to search from, propagating values |
1497 | // backward from anywhere that performs a write. |
1498 | for (size_t i = 0; i < control_flow_.size(); i++) { |
1499 | const auto& touch_points = control_flow_[i].touch_points; |
1500 | bool performs_write = std::any_of( |
1501 | touch_points.begin(), touch_points.end(), |
1502 | [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); |
1503 | if (performs_write) { |
1504 | to_visit.insert(i); |
1505 | } |
1506 | } |
1507 | } |
1508 | |
1509 | // Map from a block's index |
1510 | std::unordered_map<size_t, size_t> visit_count_lookup; |
1511 | |
1512 | Analyzer analyzer; |
1513 | analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
1514 | arith::RewriteSimplifier::kTransitivelyProveInequalities | |
1515 | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
1516 | arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
1517 | |
1518 | analyzer.Bind(iterator_ranges_); |
1519 | analyzer.Bind(free_predicate_parameters_); |
1520 | |
1521 | while (to_visit.size()) { |
1522 | size_t visiting = *to_visit.rbegin(); |
1523 | to_visit.erase(visiting); |
1524 | |
1525 | size_t num_previous_visits = visit_count_lookup[visiting]++; |
1526 | |
1527 | ControlFlowBlock& block = control_flow_[visiting]; |
1528 | |
1529 | // Step 1: Collect known unused indices provided by each successor |
1530 | block.unused_at_block_end = [&]() -> BufferState { |
1531 | if (num_previous_visits >= max_revisits_) { |
1532 | return BufferState(); |
1533 | } |
1534 | ICHECK_LE(block.successors.size(), 2) |
1535 | << "Each block should have at most two successors, but block " << visiting |
1536 | << " breaks this requirement" ; |
1537 | |
1538 | std::vector<BufferState> states; |
1539 | for (const auto& successor : block.successors) { |
1540 | const auto& successor_block = control_flow_[successor.index]; |
1541 | BufferState state = successor_block.unused_at_block_start; |
1542 | state.Substitute(successor.var_remap, &analyzer); |
1543 | states.push_back(state); |
1544 | } |
1545 | |
1546 | if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) { |
1547 | return visit_count_lookup[successor.index] == 0; |
1548 | })) { |
1549 | // Successors, if any, are unvisited. |
1550 | return {}; |
1551 | } else if (block.successors.size() == 1) { |
1552 | // Block has only a single successor |
1553 | return states[0]; |
1554 | } |
1555 | |
1556 | const auto& successor_a = block.successors[0]; |
1557 | const auto& successor_b = block.successors[1]; |
1558 | |
1559 | auto& post_a = states[0]; |
1560 | auto& post_b = states[1]; |
1561 | |
1562 | // During the first visit of a block, successor blocks may be |
1563 | // unvisited, even though we preferentially visit later blocks |
1564 | // first. (e.g. During the first visit of the end of a For |
1565 | // loop, the start of the For loop has not yet been visited.) |
1566 | // If this is the case, assume the best-case scenario that all |
1567 | // knowns are consistent, and rely on a later visit to |
1568 | // resolve/remove any conflicts. |
1569 | if (visit_count_lookup[successor_a.index] == 0) { |
1570 | return post_b; |
1571 | } else if (visit_count_lookup[successor_b.index] == 0) { |
1572 | return post_a; |
1573 | } |
1574 | |
1575 | if (successor_a.post_condition && successor_b.post_condition) { |
1576 | // The predicate can identify which successor block applies |
1577 | // (e.g. i==n-1 for the last loop iteration, i<n-1 for earlier |
1578 | // loop iterations). Therefore, we can use all buffer |
1579 | // constraints, conditional on having come from the |
1580 | // successor that provides it. |
1581 | post_a.AddCondition(successor_a.post_condition.value()); |
1582 | post_b.AddCondition(successor_b.post_condition.value()); |
1583 | post_a.Union(post_b, &analyzer); |
1584 | return post_a; |
1585 | } else { |
1586 | // We don't know which successor applies. Therefore, the |
1587 | // only buffer constraints that can be used are those that |
1588 | // appear in both successors. |
1589 | post_a.Intersection(post_b, &analyzer); |
1590 | return post_a; |
1591 | } |
1592 | }(); |
1593 | |
1594 | // Step 2: Collect knowns provided as a result of executing this block |
1595 | auto unused_at_block_start = [&]() { |
1596 | if (num_previous_visits >= max_revisits_) { |
1597 | return BufferState(); |
1598 | } |
1599 | auto prior_state = block.unused_at_block_end; |
1600 | prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer); |
1601 | prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); |
1602 | return prior_state; |
1603 | }(); |
1604 | |
1605 | // Step 3: If any changes are made to the post knowns since the |
1606 | // previous time we visited this block, mark the successor block |
1607 | // as needing to be visited. |
1608 | if (num_previous_visits == 0 || |
1609 | !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) { |
1610 | block.unused_at_block_start = std::move(unused_at_block_start); |
1611 | for (const auto& pred : block.predecessors) { |
1612 | to_visit.insert(pred.index); |
1613 | } |
1614 | } |
1615 | } |
1616 | } |
1617 | |
1618 | bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, |
1619 | const Stmt& context) const { |
1620 | Optional<Array<Var>> index_variables = GetIndexVariables(store->buffer); |
1621 | if (!index_variables) { |
1622 | return false; |
1623 | } |
1624 | |
1625 | auto it = control_flow_lookup_.find(context.get()); |
1626 | ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n" |
1627 | << context; |
1628 | const auto& context_block = control_flow_[it->second]; |
1629 | |
1630 | auto [store_touch, free_params] = context_block.MakeBufferTouch( |
1631 | store->buffer, index_variables.value(), store->indices, BufferTouch::AccessType::Write, |
1632 | BufferLoad(store->buffer, store->indices)); |
1633 | |
1634 | Analyzer local_analyzer; |
1635 | local_analyzer.Bind(free_predicate_parameters_); |
1636 | local_analyzer.Bind(iterator_ranges_); |
1637 | local_analyzer.Bind(free_params); |
1638 | local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
1639 | arith::RewriteSimplifier::kTransitivelyProveInequalities | |
1640 | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
1641 | arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
1642 | |
1643 | PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration(); |
1644 | |
1645 | predicate = SimplifyAsAndOfOrs(predicate, &local_analyzer); |
1646 | |
1647 | for (const auto& unused : context_block.unused_at_block_end.constraints_) { |
1648 | if (store_touch.buffer.same_as(unused.buffer)) { |
1649 | PrimExpr difference = SimplifyAsAndOfOrs(predicate && !unused.predicate, &local_analyzer); |
1650 | if (is_zero(difference)) { |
1651 | return true; |
1652 | } |
1653 | } |
1654 | } |
1655 | return false; |
1656 | } |
1657 | |
1658 | PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context, |
1659 | Analyzer* analyzer) const { |
1660 | size_t context_index = [&]() { |
1661 | auto it = control_flow_lookup_.find(context.get()); |
1662 | ICHECK(it != control_flow_lookup_.end()) |
1663 | << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor" ; |
1664 | return it->second; |
1665 | }(); |
1666 | |
1667 | const auto& control_flow_block = control_flow_[context_index]; |
1668 | |
1669 | PrimExpr constraint = Bool(true); |
1670 | for (const auto& known : non_buffer_assumptions_) { |
1671 | constraint = constraint && known; |
1672 | } |
1673 | With<ConstraintContext> constraint_context(analyzer, constraint); |
1674 | With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate); |
1675 | |
1676 | expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues( |
1677 | std::move(expr), axis_var_lookup_, analyzer); |
1678 | |
1679 | expr = analyzer->Simplify(std::move(expr)); |
1680 | return expr; |
1681 | } |
1682 | |
1683 | } // namespace tir |
1684 | } // namespace tvm |
1685 | |