1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file control_flow_graph.cc
 * \brief Build a control-flow graph of buffer accesses in a TIR statement, and propagate known/unused buffer values through it
23 */
24
25#include "control_flow_graph.h"
26
27#include <tvm/runtime/registry.h>
28#include <tvm/tir/analysis.h>
29#include <tvm/tir/builtin.h>
30#include <tvm/tir/expr.h>
31#include <tvm/tir/op.h>
32#include <tvm/tir/stmt_functor.h>
33
34#include <algorithm>
35#include <numeric>
36#include <optional>
37#include <queue>
38#include <set>
39#include <sstream>
40#include <unordered_set>
41
42#include "../../arith/conjunctive_normal_form.h"
43#include "../../arith/constraint_extract.h"
44#include "../../arith/ir_mutator_with_analyzer.h"
45#include "../../arith/ir_visitor_with_analyzer.h"
46#include "../../arith/narrow_predicate_expression.h"
47#include "../../arith/unwrap_vector_expr.h"
48
49namespace tvm {
50namespace tir {
51
52using namespace arith;
53
54namespace {
55bool HasBufferLoad(PrimExpr expr) {
56 struct Visitor : public ExprVisitor {
57 void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; }
58 bool found_buffer_load{false};
59 };
60
61 Visitor visitor;
62 visitor(expr);
63 return visitor.found_buffer_load;
64}
65
66Optional<PrimExpr> SubstituteParamValues(const Array<Var>& param_vars,
67 const Array<PrimExpr>& param_values,
68 const PrimExpr& expr) {
69 ICHECK_EQ(param_vars.size(), param_values.size())
70 << "Expression was defined as having " << param_vars.size() << " parameters, but received "
71 << param_values.size() << " arguments.";
72
73 Map<tir::Var, PrimExpr> var_map;
74 for (size_t i = 0; i < param_values.size(); i++) {
75 var_map.Set(param_vars[i], param_values[i]);
76 }
77
78 return Substitute(expr, var_map);
79}
80} // namespace
81
82PrimExpr BufferTouch::BeforeLoopIteration() const {
83 PrimExpr loop_predicate = Bool(true);
84 for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
85 const Var& loop_var = it->first;
86 const PrimExpr& loop_expr = it->second;
87 loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate);
88 }
89 return loop_predicate;
90}
91
92PrimExpr BufferTouch::AtLoopIteration() const {
93 PrimExpr loop_predicate = Bool(true);
94 for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
95 const Var& loop_var = it->first;
96 const PrimExpr& loop_expr = it->second;
97 loop_predicate = (loop_var == loop_expr) && loop_predicate;
98 }
99 return loop_predicate;
100}
101
102PrimExpr BufferTouch::AfterLoopIteration() const {
103 PrimExpr loop_predicate = Bool(true);
104 for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
105 const Var& loop_var = it->first;
106 const PrimExpr& loop_expr = it->second;
107 loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate);
108 }
109 return loop_predicate;
110}
111
112bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const {
113 if (this->buffer.same_as(other.buffer)) {
114 With<ConstraintContext> constraint(analyzer, predicate);
115
116 return analyzer->CanProve(other.predicate);
117 } else {
118 return false;
119 }
120}
121
122bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const {
123 if (this->buffer.same_as(other.buffer)) {
124 With<ConstraintContext> constraint(analyzer, predicate);
125
126 return analyzer->CanProve(!other.predicate);
127 } else {
128 return true;
129 }
130}
131
132std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) {
133 auto touch_type = [&]() {
134 if (tp.touch_type == BufferTouch::AccessType::Read) {
135 return "read";
136 } else if (tp.touch_type == BufferTouch::AccessType::Write) {
137 return "write";
138 } else if (tp.touch_type == BufferTouch::AccessType::Assume) {
139 return "assume";
140 } else {
141 return "???";
142 }
143 }();
144
145 os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate
146 << ", value = " << tp.value << ")";
147 return os;
148}
149
150class BufferConstraintApply : public IRMutatorWithAnalyzer {
151 public:
152 using Parent = IRMutatorWithAnalyzer;
153
154 BufferConstraintApply(const Map<Buffer, Array<Var>>& axis_var_lookup,
155 const std::vector<BufferTouch>& knowns, Analyzer* analyzer)
156 : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {}
157
158 using Parent::VisitExpr_;
159
160 PrimExpr VisitExpr_(const BufferLoadNode* op) override {
161 for (const auto& known : knowns_) {
162 if (!op->buffer.same_as(known.buffer)) {
163 continue;
164 }
165
166 Optional<Var> lane_var = NullOpt;
167 IntImm num_lanes;
168
169 Array<PrimExpr> indices = op->indices.Map([&](const auto& index) {
170 if (index.dtype().lanes() == 1) {
171 return index;
172 } else {
173 ICHECK(!lane_var) << "Multiple indices found with non-scalar values";
174 lane_var = Var("lane", index.dtype().element_of());
175 num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes());
176 return UnwrapVectorExpr(index, lane_var.value());
177 }
178 });
179
180 auto axis_vars = axis_var_lookup_.at(op->buffer);
181 PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value();
182
183 std::optional<With<ConstraintContext>> context;
184 if (lane_var.defined()) {
185 Var lanes = lane_var.value();
186 PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes);
187 context.emplace(analyzer_, known);
188 }
189
190 if (analyzer_->CanProve(predicate)) {
191 return SubstituteParamValues(axis_vars, op->indices, known.value).value();
192 }
193 }
194
195 return GetRef<PrimExpr>(op);
196 }
197
198 private:
199 const Map<Buffer, Array<Var>>& axis_var_lookup_;
200 const std::vector<BufferTouch>& knowns_;
201};
202
203/*! \brief Extract the control-flow graph
204 *
205 * Walk through a statement, populating the control-flow graph.
206 */
207class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer {
208 public:
209 static void Build(ControlFlowGraph* out, const Stmt& stmt) {
210 ControlFlowGraphBuilder extractor(out);
211 extractor.AppendControlBlock();
212 extractor(stmt);
213 }
214
215 private:
216 ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {}
217
218 using Parent = IRVisitorWithAnalyzer;
219 using Parent::VisitExpr_;
220 using Parent::VisitStmt_;
221
222 void VisitStmt(const Stmt& stmt) override {
223 // Update the lookup table to determine which control-flow block
224 // contains the start of the specified statement. This is used
225 // later to determine which set of known values should be used to
226 // simplify a statement.
227 out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock();
228 Stmt prev_stmt = current_stmt_;
229 current_stmt_ = stmt;
230 Parent::VisitStmt(stmt);
231 current_stmt_ = prev_stmt;
232 }
233
234 void VisitStmt_(const EvaluateNode* op) override {
235 if (auto* call = op->value.as<CallNode>()) {
236 if (call->op.same_as(builtin::assume())) {
237 Assume(call->args[0], true);
238 return;
239 }
240 }
241
242 Parent::VisitStmt_(op);
243 }
244
245 void Assume(PrimExpr assumption, bool from_assume_statement) {
246 for (const auto& expr : ExtractConstraints(assumption, false)) {
247 AssumeConstraintComponent(expr, from_assume_statement);
248 }
249 }
250
251 void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) {
252 PrimExpr additional_predicate = Bool(true);
253
254 std::vector<PrimExpr> buffer_exprs;
255 for (const auto& expr : ExtractComponents(assumption)) {
256 auto side_effect = tir::SideEffect(expr);
257 if (side_effect <= tir::CallEffectKind::kPure) {
258 // Pulling out portions of the assumption that do not depend
259 // on a buffer value allows the following two forms to be
260 // treated identically.
261 //
262 // Option 1: if i < 3: T.assume(buf[i] == value)
263 // Option 2: T.assume(i>=3 or buf[i] == value)
264 additional_predicate = additional_predicate && logical_not(expr);
265 } else if (side_effect == tir::CallEffectKind::kReadState) {
266 buffer_exprs.push_back(expr);
267 } else {
268 LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr
269 << " with side-effect \'" << side_effect << "\'";
270 }
271 }
272
273 if (buffer_exprs.empty()) {
274 out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption);
275 return;
276 }
277
278 CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression";
279
280 auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>();
281 CHECK(as_equal_node || !from_assume_statement)
282 << "T.assume buffer constraint must be of the form 'buffer[indices] == "
283 "value', but received "
284 << assumption;
285 if (!as_equal_node) {
286 // This assumption is an inequality on a data-dependent
287 // conditional. Not an error for this to occur, but also not
288 // something that is currently supported.
289 return;
290 }
291
292 tir::BufferLoad load;
293 PrimExpr value;
294 if (auto* as_load = as_equal_node->a.as<tir::BufferLoadNode>()) {
295 load = GetRef<tir::BufferLoad>(as_load);
296 value = as_equal_node->b;
297 } else if (auto* as_load = as_equal_node->b.as<tir::BufferLoadNode>()) {
298 load = GetRef<tir::BufferLoad>(as_load);
299 value = as_equal_node->a;
300 } else if (!from_assume_statement) {
301 return;
302 } else {
303 LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'";
304 }
305
306 auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure;
307 CHECK(!has_side_effect || !from_assume_statement)
308 << "Buffer value in constraint must be pure expression, but was " << value;
309 if (has_side_effect) {
310 return;
311 }
312
313 {
314 InternalConstraintContext context(this, additional_predicate);
315 VisitAccess(load, BufferTouch::AccessType::Assume, value);
316 }
317 // Appending a control block ensures that all control blocks have
318 // at most one statement that changes the known buffer contents.
319 auto prev_block = CurrentControlBlock();
320 auto new_block = AppendControlBlock();
321 MarkControlFlow(prev_block, new_block);
322 }
323
324 void VisitExpr_(const LetNode* op) override {
325 std::optional<BindLetVar> binding;
326 if (UsesLoopVar(op->value)) {
327 binding.emplace(this, op->var, op->value);
328 }
329 Parent::VisitExpr_(op);
330 }
331
332 void VisitStmt_(const LetStmtNode* op) override {
333 std::optional<BindLetVar> binding;
334 if (UsesLoopVar(op->value)) {
335 binding.emplace(this, op->var, op->value);
336 }
337 Parent::VisitStmt_(op);
338 }
339
340 void VisitExpr_(const BufferLoadNode* op) override {
341 Parent::VisitExpr_(op);
342 BufferLoad load = GetRef<BufferLoad>(op);
343 VisitAccess(load, BufferTouch::AccessType::Read, load);
344 }
345
346 void VisitStmt_(const BufferStoreNode* op) override {
347 Parent::VisitStmt_(op);
348 VisitAccess(GetRef<BufferStore>(op), BufferTouch::AccessType::Write, op->value);
349 // Appending a control block ensures that all control blocks have
350 // at most one statement that changes the buffer contents.
351 auto prev_block = CurrentControlBlock();
352 auto new_block = AppendControlBlock();
353 MarkControlFlow(prev_block, new_block);
354 }
355
356 void VisitStmt_(const ForNode* op) override {
357 out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent));
358
359 auto before_loop = CurrentControlBlock();
360 size_t loop_start = -1;
361
362 {
363 BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent);
364 loop_start = AppendControlBlock();
365 Parent::VisitStmt_(op);
366 }
367
368 auto loop_end = CurrentControlBlock();
369 auto after_loop = AppendControlBlock();
370 PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1);
371 {
372 auto [forward, backward] = MarkControlFlow(before_loop, loop_start);
373 backward.post_condition = (op->loop_var == op->min);
374 forward.var_remap = {{op->loop_var, op->min}};
375 }
376 {
377 auto [forward, backward] = MarkControlFlow(loop_end, after_loop);
378 backward.var_remap = {{op->loop_var, max_iterator_value}};
379 forward.post_condition = (op->loop_var == max_iterator_value);
380 }
381 {
382 auto [forward, backward] = MarkControlFlow(loop_end, loop_start);
383 backward.var_remap = {{op->loop_var, op->loop_var - 1}};
384 forward.var_remap = {{op->loop_var, op->loop_var + 1}};
385 backward.post_condition = (op->loop_var > op->min);
386 forward.post_condition = (op->loop_var < max_iterator_value);
387 }
388 }
389
390 void VisitStmt_(const IfThenElseNode* op) override {
391 this->VisitExpr(op->condition);
392
393 PrimExpr real_condition = ExtractRealCondition(op->condition);
394
395 auto before_branching = CurrentControlBlock();
396
397 auto branch_start = AppendControlBlock();
398 MarkControlFlow(before_branching, branch_start);
399
400 {
401 InternalConstraintContext context(this, real_condition);
402 auto then_start = AppendControlBlock();
403 if (context.assume.defined()) {
404 Assume(context.assume.value(), false);
405 }
406 auto [forward, backward] = MarkControlFlow(branch_start, then_start);
407 backward.post_condition = real_condition;
408 forward.post_condition = real_condition;
409 this->VisitStmt(op->then_case);
410 }
411 auto then_end = CurrentControlBlock();
412
413 auto negation = analyzer_.rewrite_simplify(!real_condition);
414 {
415 InternalConstraintContext context(this, negation);
416 auto else_start = AppendControlBlock();
417 if (context.assume.defined()) {
418 Assume(context.assume.value(), false);
419 }
420 auto [forward, backward] = MarkControlFlow(branch_start, else_start);
421 backward.post_condition = negation;
422 forward.post_condition = negation;
423
424 if (op->else_case.defined()) {
425 this->VisitStmt(op->else_case.value());
426 }
427 }
428
429 auto else_end = CurrentControlBlock();
430 auto after_branching = AppendControlBlock();
431
432 if (HasBufferLoad(real_condition)) {
433 // The buffer value may have changed during the body of the
434 // condition, so we can't provide it as a post-condition.
435 MarkControlFlow(then_end, after_branching);
436 MarkControlFlow(else_end, after_branching);
437 } else {
438 {
439 auto [forward, backward] = MarkControlFlow(then_end, after_branching);
440 backward.post_condition = real_condition;
441 forward.post_condition = real_condition;
442 }
443 {
444 auto [forward, backward] = MarkControlFlow(else_end, after_branching);
445 backward.post_condition = negation;
446 forward.post_condition = negation;
447 }
448 }
449 }
450
451 /*! \brief Internal utility, returns true if the expression depends
452 * on a loop iterator
453 */
454 bool UsesLoopVar(const PrimExpr& expr) {
455 return UsesVar(expr, [&](const VarNode* expr_var) {
456 return loop_dependent_vars_.find(expr_var) != loop_dependent_vars_.end();
457 });
458 }
459
460 /*! \brief Record the interaction with the buffer.
461 *
462 * \param node The TIR node that accesses the buffer. Should be
463 * either a BufferLoad or BufferStore node.
464 *
465 * \param touch_type The type of buffer access being performed. A
466 * BufferStore should always use AccessType::Write. A BufferLoad
467 * may use either AccessType::Read or AccessType::Assume, depending
468 * on whether the BufferLoad occurs within `builtin::assume`.
469 *
470 * \param known_value_expr The value in the buffer following the access.
471 */
472 template <typename BufferAccess>
473 void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type,
474 PrimExpr known_value_expr) {
475 auto& current_block = out_->control_flow_.back();
476 BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices,
477 touch_type, known_value_expr);
478 current_block.touch_points.push_back(buffer_touch);
479 }
480
481 /*! \brief Return a predicate for having reached the current
482 * control-flow block
483 *
484 * For example, while inside an IfThenElse, will return the
485 * IfThenElse's condition.
486 */
487 PrimExpr CurrentScopePredicate() const {
488 PrimExpr predicate = Bool(true);
489 for (const auto& condition : conditions_) {
490 predicate = predicate && condition;
491 }
492 return predicate;
493 }
494
495 /* \brief Add a new control block, returning its index */
496 size_t AppendControlBlock() {
497 size_t index = out_->control_flow_.size();
498 auto& block = out_->control_flow_.emplace_back();
499 block.active_loop_iterators = active_loop_iterators_;
500 block.let_bindings_using_loop = let_bindings_using_loop_;
501 block.scope_predicate = CurrentScopePredicate();
502 return index;
503 }
504
505 /* \brief The index of the current control block */
506 size_t CurrentControlBlock() { return out_->control_flow_.size() - 1; }
507
508 /* \brief Mark a possible control from one block to another
509 *
510 * \param from_block The block from which control leaves
511 *
512 * \param to_block The block to which control enters
513 *
514 * \param var_remap Variable replacements that should be made in
515 * known expression while traversing this edge. For example,
516 * replacing `i` with `i-1` when entering the next loop iteration,
517 * or replacing `i` with `n-1` when concluding a loop.
518 */
519 std::pair<ControlFlowGraph::ControlFlowEdge&, ControlFlowGraph::ControlFlowEdge&> MarkControlFlow(
520 size_t from_block, size_t to_block) {
521 ICHECK_LE(from_block, out_->control_flow_.size());
522 ICHECK_LE(to_block, out_->control_flow_.size());
523
524 auto& forward = out_->control_flow_[from_block].successors.emplace_back(
525 ControlFlowGraph::ControlFlowEdge{to_block, {}, NullOpt});
526 auto& backward = out_->control_flow_[to_block].predecessors.emplace_back(
527 ControlFlowGraph::ControlFlowEdge{from_block, {}, NullOpt});
528 return {forward, backward};
529 }
530
531 // Internal utility, context manager for entering/leaving a scoped constraint
532 struct InternalConstraintContext {
533 InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint)
534 : self(self), analyzer_context(&self->analyzer_, constraint) {
535 old_num_constraints = self->conditions_.size();
536
537 auto side_effect = tir::SideEffect(constraint);
538 if (side_effect <= tir::CallEffectKind::kPure) {
539 self->conditions_.push_back(constraint);
540 } else if (side_effect <= tir::CallEffectKind::kReadState) {
541 assume = constraint;
542 }
543
544 new_num_constraints = self->conditions_.size();
545 }
546 ~InternalConstraintContext() {
547 ICHECK_EQ(self->conditions_.size(), new_num_constraints)
548 << "Internal error: Each condition should only be popped once.";
549 self->conditions_.erase(self->conditions_.begin() + old_num_constraints,
550 self->conditions_.end());
551 }
552
553 ControlFlowGraphBuilder* self{nullptr};
554 With<ConstraintContext> analyzer_context;
555 size_t old_num_constraints{0};
556 size_t new_num_constraints{0};
557 Optional<PrimExpr> assume{NullOpt};
558
559 // Disable default-generated copy/move assignment and constructors
560 InternalConstraintContext(const InternalConstraintContext&) = delete;
561 InternalConstraintContext& operator=(const InternalConstraintContext&) = delete;
562 InternalConstraintContext(InternalConstraintContext&&) = delete;
563 InternalConstraintContext& operator=(InternalConstraintContext&&) = delete;
564 };
565
566 // Internal utility, context manager for tracking a loop
567 struct BindActiveLoopVar {
568 BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min,
569 PrimExpr loop_extent)
570 : self(self), var(var) {
571 PrimExpr loop_max = loop_min + (loop_extent - 1);
572 auto loop_range = Range::FromMinExtent(loop_min, loop_extent);
573 self->active_loop_iterators_.push_back({var, loop_min, loop_max, loop_range});
574 self->loop_dependent_vars_.insert(var.get());
575 }
576 ~BindActiveLoopVar() { self->active_loop_iterators_.pop_back(); }
577
578 ControlFlowGraphBuilder* self;
579 Var var;
580
581 // Disable default-generated copy/move assignment and constructors
582 BindActiveLoopVar(const BindActiveLoopVar&) = delete;
583 BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete;
584 BindActiveLoopVar(BindActiveLoopVar&&) = delete;
585 BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete;
586 };
587
588 // Internal utility, context manager for tracking a variable binding
589 struct BindLetVar {
590 BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) {
591 self->let_bindings_using_loop_.Set(var, value);
592 self->loop_dependent_vars_.insert(var.get());
593 }
594 ~BindLetVar() {
595 self->loop_dependent_vars_.erase(var.get());
596 self->let_bindings_using_loop_.erase(var);
597 }
598 ControlFlowGraphBuilder* self;
599 Var var;
600
601 // Disable default-generated copy/move assignment and constructors
602 BindLetVar(const BindLetVar&) = delete;
603 BindLetVar& operator=(const BindLetVar&) = delete;
604 BindLetVar(BindLetVar&&) = delete;
605 BindLetVar& operator=(BindLetVar&&) = delete;
606 };
607
608 struct LoopEntry {
609 Var loop_var;
610 PrimExpr loop_min;
611 PrimExpr loop_max;
612 Range loop_range;
613 };
614
615 // Track in order to know which Vars to write in terms of the buffer
616 // indices and substitute out of the predicate.
617 std::vector<ControlFlowGraph::ControlFlowBlock::LoopEntry> active_loop_iterators_;
618
619 // Track all loop iterators, along with values derived from loop iterators.
620 std::unordered_set<const VarNode*> loop_dependent_vars_;
621
622 // Any let binding that depends, directly or indirectly, on a loop
623 // binding. When making a predicate in terms of the buffer indices,
624 // these need to be substituted out.
625 // std::unordered_map<const VarNode*, PrimExpr> let_bindings_using_loop_;
626 Map<Var, PrimExpr> let_bindings_using_loop_;
627
628 // Track in order to know what conditions limit the buffer access
629 std::vector<PrimExpr> conditions_;
630
631 // Track in order to know what statement initiated the buffer access
632 Stmt current_stmt_;
633
634 // Output data structure
635 ControlFlowGraph* out_;
636};
637
/* \brief Express a buffer access in terms of the buffer's axis variables
 *
 * Solves `index_variables[i] == indices[i]` (after substituting
 * loop-dependent let bindings) for the block's active loop iterators,
 * using arith::SolveLinearEquations. The resulting mapping from loop
 * iterator to axis-variable expression is used to rewrite the scope
 * predicate and `known_value_expr`, so that they no longer refer to
 * loop iterators.
 *
 * \param buf The buffer being accessed.
 *
 * \param index_variables One variable per buffer axis, in terms of which
 *     the returned predicate and value are written.
 *
 * \param indices The access indices. At most one index may be a vector;
 *     its lanes are modeled by an extra "lane" variable that is solved
 *     for like a loop iterator.
 *
 * \param touch_type The kind of access being recorded.
 *
 * \param known_value_expr The value in the buffer following the access.
 *
 * \return The BufferTouch, together with the ranges of any free
 *     parameters introduced by the solver that could not be eliminated.
 */
std::pair<BufferTouch, Map<Var, Range>> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(
    const tir::Buffer& buf, Array<Var> index_variables, Array<PrimExpr> indices,
    BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const {
  const auto& current_block = *this;

  Analyzer local_analyzer;

  Optional<Var> lane_var = NullOpt;
  IntImm num_lanes;

  // Replace a vector-typed index (at most one is allowed) with a scalar
  // expression of a fresh lane variable.
  Array<PrimExpr> index_expressions = indices.Map([&](const auto& index) {
    if (index.dtype().lanes() == 1) {
      return index;
    } else {
      ICHECK(!lane_var) << "Multiple indices found with non-scalar values";
      lane_var = Var("lane", index.dtype().element_of());
      num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes());
      return UnwrapVectorExpr(index, lane_var.value());
    }
  });

  Array<Var> loop_vars;

  Map<Var, Range> loop_ranges;
  for (const auto& loop_entry : current_block.active_loop_iterators) {
    loop_vars.push_back(loop_entry.loop_var);
    loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range);
  }

  // If the indices contain multiple lanes, treat the lane variable
  // as an additional loop iterator to be solved for and substituted
  // out.
  if (lane_var) {
    loop_vars.push_back(lane_var.value());
    loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes));
  }

  // Solve the system `axis_var[i] == index_expr[i]` for the loop
  // iterators (and lane variable, if any).
  IntConstraintsTransform transform = [&]() {
    ICHECK_EQ(index_variables.size(), index_expressions.size());

    Array<PrimExpr> relations;

    for (size_t i = 0; i < index_expressions.size(); i++) {
      PrimExpr expr = index_expressions[i];
      Var var = index_variables[i];

      expr = Substitute(expr, current_block.let_bindings_using_loop);
      relations.push_back(var == expr);
    }

    IntConstraints system(loop_vars, loop_ranges, relations);
    return arith::SolveLinearEquations(system);
  }();

  Map<Var, PrimExpr> loop_var_to_axis_var = transform->src_to_dst;
  Map<Var, Range> free_params = transform->dst->ranges;
  // Conjunction of all relations the solver left on the destination side.
  PrimExpr transform_predicate =
      std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(),
                      PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; });

  transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer);

  // Returns a substitution for every free parameter whose value is
  // determined, either by an equality in `transform_predicate` or by
  // having a range of extent one.
  auto find_removable_params = [&]() -> Map<Var, PrimExpr> {
    Map<Var, PrimExpr> removable_params;

    // The arith::SolveLinearEquations is more general than the
    // utilities in iter_affine_map.h, but can introduce free
    // parameters that could later be determined with the known
    // constraints. This step removes all such free parameters.
    for (const auto& expr : ExtractConstraints(transform_predicate)) {
      if (auto* as_equal = expr.as<EQNode>()) {
        // `a == b` determines `a` when `a` is a free parameter and `b`
        // uses no free parameters.
        auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) {
          auto* var_ptr = a.as<VarNode>();
          if (!var_ptr) {
            return;
          }

          Var var = GetRef<Var>(var_ptr);
          if (free_params.count(var) == 0) {
            return;
          }

          bool uses_free_param =
              UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef<Var>(v)) > 0; });
          if (uses_free_param) {
            return;
          }
          removable_params.Set(var, b);
        };
        check_expr(as_equal->a, as_equal->b);
        check_expr(as_equal->b, as_equal->a);
      }
    }

    // In addition, the arith::SolveLinearEquation can introduce
    // free parameters with an extent of one. Filtering them out here
    // avoids needing to track them through later simplifications.
    for (const auto [var, range] : free_params) {
      if (is_one(range->extent)) {
        removable_params.Set(var, range->min);
      }
    }

    return removable_params;
  };
  // Repeatedly substitute out determined free parameters. Each pass erases
  // every removed parameter from `free_params`, so the loop ends once no
  // further parameter can be determined.
  for (auto removable_params = find_removable_params(); removable_params.size() > 0;
       removable_params = find_removable_params()) {
    auto update = [&](const PrimExpr& expr) {
      return local_analyzer.Simplify(Substitute(expr, removable_params));
    };

    Map<Var, PrimExpr> new_map;
    for (const auto [loop_var, expr] : loop_var_to_axis_var) {
      static_cast<void>(expr);  // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      new_map.Set(loop_var, update(expr));
    }
    loop_var_to_axis_var = new_map;

    transform_predicate = update(transform_predicate);

    for (const auto [var, expr] : removable_params) {
      static_cast<void>(expr);  // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      free_params.erase(var);
    }
  }

  // Normalization function, applied to both the predicate and the
  // known value. Converts from an expression in terms of loop
  // iterators to an expression in terms of buffer indices.
  auto normalize_expr = [&](PrimExpr expr) -> PrimExpr {
    expr = Substitute(expr, current_block.let_bindings_using_loop);

    if (lane_var) {
      expr = UnwrapVectorExpr(expr, lane_var.value());
    }
    expr = Substitute(expr, loop_var_to_axis_var);

    return expr;
  };

  // Collect the current loop variables, along with an expression for
  // the loop variables in terms of the buffer axis variables. This
  // is used during forward/backward propagation to generate predicate
  // tracking whether a loop iteration has been reached.
  std::vector<std::pair<Var, PrimExpr>> loop_var_expressions;
  for (const auto& entry : current_block.active_loop_iterators) {
    auto expr_it = loop_var_to_axis_var.find(entry.loop_var);
    ICHECK(expr_it != loop_var_to_axis_var.end());
    loop_var_expressions.push_back({entry.loop_var, (*expr_it).second});
  }

  // The full predicate is composed of the values required to reach
  // the scope of the BufferStore or builtin::assume(), any bounds
  // implied by solving for the axis variables, and any additional
  // statements resulting from unpacking the expression contained in
  // builtin::assume().
  PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate);
  transform_predicate = normalize_expr(transform_predicate);

  known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr));

  // Deliberately use an analyzer without scope-based information,
  // to avoid simplifying `scope_predicate` to True.
  PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate);

  BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions,
                              touch_type};

  return {buffer_touch, free_params};
}
808
809BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph,
810 const tir::Buffer& buf,
811 const Array<PrimExpr>& indices,
812 BufferTouch::AccessType touch_type,
813 PrimExpr known_value_expr) const {
814 ICHECK(graph);
815 auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices),
816 indices, touch_type, known_value_expr);
817 for (const auto& pair : free_params) {
818 graph->free_predicate_parameters_.Set(pair.first, pair.second);
819 }
820 return buffer_touch;
821}
822
823ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
824 : max_revisits_(max_revisits) {
825 ControlFlowGraphBuilder::Build(this, stmt);
826 ForwardPropagateKnownValues();
827 BackwardPropagateUnusedValues();
828}
829
830void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) {
831 size_t context_index = [&]() {
832 auto it = control_flow_lookup_.find(store.get());
833 ICHECK(it != control_flow_lookup_.end())
834 << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor";
835 return it->second;
836 }();
837
838 auto& touch_points = control_flow_[context_index].touch_points;
839
840 touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(),
841 [](const BufferTouch& touch) {
842 return touch.touch_type == BufferTouch::AccessType::Write;
843 }),
844 touch_points.end());
845 ForwardPropagateKnownValues(context_index);
846 BackwardPropagateUnusedValues(context_index);
847}
848
849std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
850 os << edge.index;
851 if (edge.var_remap.size()) {
852 os << " with remap " << edge.var_remap;
853 }
854 if (edge.post_condition) {
855 os << " with postcondition " << edge.post_condition;
856 }
857
858 return os;
859}
860
861std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) {
862 os << "Predecessors: [";
863 for (size_t i = 0; i < block.predecessors.size(); i++) {
864 if (i) {
865 os << ", ";
866 }
867 os << block.predecessors[i];
868 }
869 os << "]\n";
870
871 os << "Active loop iterators: [";
872 for (size_t i = 0; i < block.active_loop_iterators.size(); i++) {
873 if (i) {
874 os << ", ";
875 }
876 os << block.active_loop_iterators[i].loop_var;
877 }
878 os << "]\n";
879
880 os << "Before block knowns: " << block.known_at_block_start << "\n";
881
882 os << "Before block unused: " << block.unused_at_block_start << "\n";
883
884 for (size_t i = 0; i < block.touch_points.size(); i++) {
885 os << "Touch[" << i << "] = " << block.touch_points[i] << "\n";
886 }
887 os << "After block: " << block.known_at_block_end << "\n";
888
889 os << "After block unused: " << block.unused_at_block_end << "\n";
890
891 os << "Successors: [";
892 for (size_t i = 0; i < block.successors.size(); i++) {
893 if (i) {
894 os << ", ";
895 }
896 os << block.successors[i];
897 }
898 os << "]";
899 return os;
900}
901
902std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) {
903 os << "Touch pattern contains " << pattern.control_flow_.size() << " control blocks."
904 << (pattern.control_flow_.size() ? "\n" : "");
905 for (size_t i = 0; i < pattern.control_flow_.size(); i++) {
906 os << "\t"
907 << "ControlBlock[" << i << "] = " << pattern.control_flow_[i] << "\n";
908 }
909
910 return os;
911}
912
913bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const {
914 // Constraints must apply to the same buffer to be equivalent
915 if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) {
916 return false;
917 }
918
919 ExprDeepEqual deep_equal;
920
921 auto implies = [&](const PrimExpr& a, const PrimExpr& b) -> bool {
922 With<ConstraintContext> context(analyzer, a);
923 return analyzer->CanProve(b);
924 };
925
926 // Predicates must be equivalent expressions, or must both be undefined
927 bool equivalent_predicates =
928 deep_equal(predicate, other.predicate) ||
929 (implies(predicate, other.predicate) && implies(other.predicate, predicate));
930 if (!equivalent_predicates) {
931 return false;
932 }
933
934 // The known value must be equal
935 if (!deep_equal(value, other.value) && !analyzer->CanProveEqual(value, other.value)) {
936 return false;
937 }
938
939 return true;
940}
941
942std::ostream& operator<<(std::ostream& os, const BufferState& state) {
943 for (size_t i = 0; i < state.constraints_.size(); i++) {
944 os << "constraints[" << i << "] = " << state.constraints_[i]
945 << (i + 1 == state.constraints_.size() ? "" : "\n");
946 }
947 return os;
948}
949
950PrimExpr BufferState::SubstituteKnownBufferValues(
951 PrimExpr expr, const Map<tir::Buffer, Array<tir::Var>>& axis_var_lookup,
952 Analyzer* analyzer) const {
953 BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer);
954 return mutator(std::move(expr));
955}
956
957void BufferState::AddCondition(const PrimExpr& condition) {
958 for (auto& constraint : constraints_) {
959 constraint.predicate = constraint.predicate && condition;
960 }
961}
962
963void BufferState::Substitute(const Map<Var, PrimExpr>& var_remap, Analyzer* analyzer) {
964 if (var_remap.size()) {
965 for (auto& prior : constraints_) {
966 PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap);
967 if (!updated.same_as(prior.predicate)) {
968 prior.predicate = SimplifyAsAndOfOrs(updated, analyzer);
969 }
970 }
971 }
972}
973
974void BufferState::Simplify(Analyzer* analyzer) {
975 for (auto& constraint : constraints_) {
976 constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate, analyzer);
977 }
978}
979
980void BufferState::Union(const BufferState& b, Analyzer* analyzer) {
981 for (const auto& b_constraint : b.constraints_) {
982 bool used = false;
983 for (auto& a_constraint : constraints_) {
984 if (a_constraint.buffer.same_as(b_constraint.buffer) &&
985 analyzer->CanProveEqual(a_constraint.value, b_constraint.value)) {
986 a_constraint.predicate =
987 SimplifyAsAndOfOrs(a_constraint.predicate || b_constraint.predicate, analyzer);
988 used = true;
989 break;
990 }
991 }
992 if (!used) {
993 constraints_.push_back(b_constraint);
994 }
995 }
996}
997
998void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) {
999 // For a constraint to be in the output, it must be present in both
1000 // inputs.
1001
1002 std::vector<BufferTouch> new_constraints;
1003 for (const auto& ai : constraints_) {
1004 for (const auto& bi : b.constraints_) {
1005 if (ai.buffer.same_as(bi.buffer)) {
1006 PrimExpr predicate = SimplifyAsAndOfOrs(ai.predicate && bi.predicate, analyzer);
1007 if (!is_zero(predicate)) {
1008 With<ConstraintContext> context(analyzer, predicate);
1009 PrimExpr known_value_a = ai.value;
1010 PrimExpr known_value_b = bi.value;
1011
1012 bool is_consistent = analyzer->CanProveEqual(known_value_a, known_value_b);
1013 if (is_consistent) {
1014 new_constraints.push_back({ai.buffer, predicate, known_value_a});
1015 }
1016 }
1017 }
1018 }
1019 }
1020
1021 constraints_ = std::move(new_constraints);
1022}
1023
1024class BufferRegionCollector : public ExprVisitor {
1025 public:
1026 struct Region {
1027 PrimExpr region_predicate;
1028 std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>> known_values;
1029 };
1030
1031 static std::vector<Region> Collect(const Map<Buffer, Array<Var>>& axis_var_lookup,
1032 const std::vector<BufferTouch>& knowns,
1033 const std::vector<Optional<PrimExpr>>& exprs,
1034 Analyzer* analyzer) {
1035 BufferRegionCollector collector(axis_var_lookup, knowns, analyzer);
1036 for (const auto& expr : exprs) {
1037 if (expr) {
1038 collector(expr.value());
1039 }
1040 }
1041
1042 return collector.regions_;
1043 }
1044
1045 private:
1046 using Parent = ExprVisitor;
1047
1048 BufferRegionCollector(const Map<Buffer, Array<Var>>& axis_var_lookup,
1049 const std::vector<BufferTouch>& knowns, Analyzer* analyzer)
1050 : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {
1051 regions_.push_back(Region{Bool(true), {}});
1052 }
1053
1054 using Parent::VisitExpr_;
1055
1056 void VisitExpr_(const BufferLoadNode* op) override {
1057 // Helper struct for the known values of this BufferLoad
1058 struct Known {
1059 PrimExpr predicate;
1060 Optional<PrimExpr> value;
1061 };
1062
1063 std::vector<Known> new_regions;
1064
1065 PrimExpr unknown_region = Bool(true);
1066
1067 for (const BufferTouch& constraint : knowns_) {
1068 if (!op->buffer.same_as(constraint.buffer)) {
1069 // This is a different buffer, so continue searching.
1070 continue;
1071 }
1072
1073 auto axis_vars = axis_var_lookup_.at(op->buffer);
1074 PrimExpr touch_predicate =
1075 SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value();
1076 touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_);
1077
1078 if (!is_zero(touch_predicate)) {
1079 Optional<PrimExpr> known_value =
1080 SubstituteParamValues(axis_vars, op->indices, constraint.value);
1081 new_regions.push_back(Known{touch_predicate, known_value});
1082
1083 unknown_region = unknown_region && !touch_predicate;
1084 unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_);
1085 }
1086 }
1087
1088 if (new_regions.size()) {
1089 Analyzer local_analyzer;
1090
1091 if (!is_zero(unknown_region)) {
1092 new_regions.insert(new_regions.begin(), Known{unknown_region, NullOpt});
1093 }
1094
1095 std::vector<Region> updated_regions;
1096 for (const auto& prev_region : regions_) {
1097 for (const auto& new_region : new_regions) {
1098 PrimExpr intersection =
1099 SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_);
1100
1101 if (!is_zero(intersection)) {
1102 Region merged{intersection, prev_region.known_values};
1103 merged.known_values[op] = new_region.value;
1104 updated_regions.push_back(std::move(merged));
1105 }
1106 }
1107 }
1108 regions_ = updated_regions;
1109 }
1110 }
1111
1112 Analyzer* analyzer_;
1113 std::vector<Region> regions_;
1114 const Map<Buffer, Array<Var>>& axis_var_lookup_;
1115 const std::vector<BufferTouch>& knowns_;
1116};
1117
1118class BufferRegionValueReplacer : public IRMutatorWithAnalyzer {
1119 public:
1120 static PrimExpr Apply(
1121 const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values,
1122 PrimExpr expr, Analyzer* analyzer) {
1123 BufferRegionValueReplacer mutator(known_values, analyzer);
1124 PrimExpr result = mutator(expr);
1125 // Simplification must occur after the substitution, as known
1126 // values may provide enable simplifications. Also, cannot track
1127 // whether a BufferLoad was
1128 result = analyzer->Simplify(result);
1129 return result;
1130 }
1131
1132 private:
1133 using Parent = IRMutatorWithAnalyzer;
1134
1135 BufferRegionValueReplacer(
1136 const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values,
1137 Analyzer* analyzer)
1138 : Parent(analyzer), known_values_(known_values) {}
1139
1140 using Parent::VisitExpr_;
1141
1142 PrimExpr VisitExpr_(const BufferLoadNode* op) override {
1143 auto it = known_values_.find(op);
1144 if (it != known_values_.end() && it->second) {
1145 return it->second.value();
1146 } else {
1147 return GetRef<PrimExpr>(op);
1148 }
1149 }
1150
1151 const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values_;
1152};
1153
/*! \brief Update this state to reflect the touches of one block
 *
 * Read touches are skipped.  Every other touch is processed as follows:
 * - Its predicate (restricted by `AfterLoopIteration()`) and its value
 *   are rewritten using the values already known in this state.
 *   BufferRegionCollector splits them into regions for this.
 * - Prior constraints on the touched buffer are narrowed to exclude
 *   every non-empty region of the touch.
 * - A new constraint is recorded for each such region whose rewritten
 *   value no longer contains a BufferLoad.
 *
 * A new constraint whose value is provably equal to an existing
 * constraint's value widens that constraint's predicate instead of
 * being appended.  Constraints whose predicate is zero are removed.
 *
 * \param axis_var_lookup Index variables in which each buffer's constraints are expressed
 * \param touch_points The touches to apply, in block order
 * \param analyzer Analyzer used for simplification and proofs
 */
void BufferState::ApplyTouches(const Map<Buffer, Array<Var>>& axis_var_lookup,
                               const std::vector<BufferTouch>& touch_points, Analyzer* analyzer) {
  // Constraints produced by the touches of this block.
  std::vector<BufferTouch> new_knowns;
  // Per buffer, the region in which prior constraints are still valid,
  // i.e. not overwritten by any touch of this block.
  Map<Buffer, PrimExpr> keep_prior_known_at;

  for (auto& touch : touch_points) {
    if (touch.touch_type == BufferTouch::AccessType::Read) {
      continue;
    }

    PrimExpr known_value = touch.value;

    PrimExpr predicate = touch.predicate && touch.AfterLoopIteration();
    // Split on the currently-known values of any BufferLoad appearing in
    // the predicate or value of this touch.
    auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_,
                                                  {predicate, touch.value}, analyzer);

    for (const auto& region : regions) {
      PrimExpr updated_predicate = BufferRegionValueReplacer::Apply(
          region.known_values, region.region_predicate && predicate, analyzer);

      updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer);
      PrimExpr updated_value =
          BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer);

      if (!is_zero(updated_predicate)) {
        // This region is overwritten, so prior knowns no longer hold there.
        if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) {
          keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate);
        } else {
          keep_prior_known_at.Set(touch.buffer, !updated_predicate);
        }

        // Only record the written value if it no longer depends on buffer contents.
        if (!HasBufferLoad(updated_value)) {
          BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value};
          new_knowns.push_back(new_constraint);
        }
      }
    }
  }

  // Narrow prior constraints to the regions not overwritten in this block.
  if (keep_prior_known_at.size()) {
    for (auto& constraint : constraints_) {
      if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) {
        constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer);
      }
    }
  }

  if (new_knowns.size()) {
    // used[i] is set when new_knowns[i] was merged into an existing constraint.
    std::vector<bool> used(new_knowns.size(), false);

    for (auto& constraint : constraints_) {
      PrimExpr expand_known_at = Bool(false);

      PrimExpr prev_value = constraint.value;

      for (size_t i = 0; i < new_knowns.size(); i++) {
        if (new_knowns[i].buffer.same_as(constraint.buffer)) {
          Optional<PrimExpr> overwritten_with = new_knowns[i].value;
          // Writing the same value that is already known widens the
          // region where that value is known.
          if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) {
            expand_known_at =
                SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer);
            used[i] = true;
          }
        }
      }

      if (!is_zero(expand_known_at)) {
        constraint.predicate =
            SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer);
      }
    }

    // Any new known not merged above becomes a separate constraint.
    for (size_t i = 0; i < new_knowns.size(); i++) {
      if (!used[i]) {
        constraints_.push_back(new_knowns[i]);
      }
    }
  }

  // Drop constraints whose predicate has become zero.
  constraints_.erase(
      std::remove_if(constraints_.begin(), constraints_.end(),
                     [&](const auto& constraint) { return is_zero(constraint.predicate); }),
      constraints_.end());
}
1238
1239void BufferState::BackpropUnusedIndices(const Map<Buffer, Array<Var>>& axis_var_lookup,
1240 const std::vector<BufferTouch>& touch_points,
1241 Analyzer* analyzer) {
1242 std::vector<BufferTouch> new_knowns;
1243 Map<Buffer, PrimExpr> keep_prior_known_at;
1244
1245 Map<Buffer, PrimExpr> regions_written;
1246 Map<Buffer, PrimExpr> regions_read;
1247 for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) {
1248 const auto& touch = *it;
1249
1250 Map<Buffer, PrimExpr>* to_update{nullptr};
1251 if (touch.touch_type == BufferTouch::AccessType::Write) {
1252 to_update = &regions_written;
1253
1254 } else if (touch.touch_type == BufferTouch::AccessType::Read) {
1255 to_update = &regions_read;
1256 } else {
1257 continue;
1258 }
1259
1260 PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false));
1261 PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration();
1262 to_update->Set(touch.buffer, prev || new_predicate);
1263 }
1264
1265 auto update_map = [&](auto& map) {
1266 Map<Buffer, PrimExpr> new_map;
1267 for (auto [buffer, predicate] : map) {
1268 new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer));
1269 }
1270 map = std::move(new_map);
1271 };
1272 update_map(regions_written);
1273 update_map(regions_read);
1274
1275 // If buffer is already in used, widen the predicate
1276 for (auto& prev_unused : constraints_) {
1277 if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) {
1278 PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value();
1279 prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer);
1280 regions_written.erase(prev_unused.buffer);
1281 }
1282 }
1283
1284 // Otherwise, add new "touch" to represent the unused values
1285 for (auto [buffer, predicate] : regions_written) {
1286 constraints_.push_back(
1287 BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})});
1288 }
1289
1290 // If buffer is read out, narrow the predicate
1291 for (auto& prev_unused : constraints_) {
1292 if (auto opt_pred = regions_read.Get(prev_unused.buffer)) {
1293 PrimExpr predicate = opt_pred.value();
1294 prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer);
1295 }
1296 }
1297
1298 // Clean-up and remove any empty constraints
1299 constraints_.erase(
1300 std::remove_if(constraints_.begin(), constraints_.end(),
1301 [](const auto& constraint) { return is_zero(constraint.predicate); }),
1302 constraints_.end());
1303}
1304
1305void BufferState::RemoveFreeParameters(const Map<Var, Range>& free_predicate_parameters,
1306 Analyzer* analyzer) {
1307 for (auto& known : constraints_) {
1308 known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters);
1309 known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer);
1310 }
1311}
1312
1313bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const {
1314 if (constraints_.size() != other.constraints_.size()) {
1315 return false;
1316 }
1317
1318 for (size_t i = 0; i < constraints_.size(); i++) {
1319 if (!constraints_[i].IsEquivalentTo(other.constraints_[i], analyzer)) {
1320 return false;
1321 }
1322 }
1323
1324 return true;
1325}
1326
1327Optional<Array<Var>> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const {
1328 if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) {
1329 return (*it).second;
1330 } else {
1331 return NullOpt;
1332 }
1333}
1334
1335Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<PrimExpr>& indices) {
1336 if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) {
1337 return (*it).second;
1338 }
1339
1340 Array<Var> vars;
1341 for (size_t i = 0; i < indices.size(); i++) {
1342 std::stringstream ss;
1343 ss << buf->name << "_axis_" << i;
1344 vars.push_back(Var(ss.str(), indices[i].dtype().element_of()));
1345 }
1346
1347 axis_var_lookup_.Set(buf, vars);
1348 return vars;
1349}
1350
/*! \brief Propagate known buffer values forward through the control-flow graph
 *
 * Worklist iteration.  Each visited block performs three steps:
 * - Step 1: recompute the knowns at its start from its predecessors'
 *   knowns at their end.
 * - Step 2: apply its own touches to obtain the knowns at its end.
 * - Step 3: re-queue its successors if those end-of-block knowns changed.
 *
 * Once a block has been visited `max_revisits_` times, both of its
 * states are replaced by an empty BufferState on every later visit.
 * From then on its state no longer changes from one visit to the next.
 *
 * \param flow_from If provided, start only from this block.  Otherwise
 *     start from every block with a touch whose value contains no
 *     BufferLoad.
 */
void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) {
  // Blocks to visit. A std::set is used so that the lowest index is
  // always visited next, preferring nodes near the start of the
  // control flow.
  std::set<size_t> to_visit;

  if (flow_from.has_value()) {
    to_visit.insert(flow_from.value());
  } else {
    // Initialize the locations to search from, propagating values
    // forward from all locations that have a known value.
    for (size_t i = 0; i < control_flow_.size(); i++) {
      bool has_known_value = false;
      for (const auto& touch : control_flow_[i].touch_points) {
        if (!HasBufferLoad(touch.value)) {
          has_known_value = true;
          break;
        }
      }

      if (has_known_value) {
        to_visit.insert(i);
      }
    }
  }

  // Map from a block's index to the number of times it has been visited
  std::unordered_map<size_t, size_t> visit_count_lookup;

  Analyzer analyzer;
  analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
      arith::RewriteSimplifier::kTransitivelyProveInequalities |
      arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
      arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

  analyzer.Bind(iterator_ranges_);
  analyzer.Bind(free_predicate_parameters_);

  while (to_visit.size()) {
    size_t visiting = *to_visit.begin();
    to_visit.erase(visiting);

    size_t num_previous_visits = visit_count_lookup[visiting]++;

    ControlFlowBlock& block = control_flow_[visiting];

    // Step 1: Collect known values provided from each predecessor
    block.known_at_block_start = [&]() -> BufferState {
      if (num_previous_visits >= max_revisits_) {
        return BufferState();
      }

      // Validate internal constraint. This should be true by
      // construction, as ControlFlowGraphBuilder only builds graphs
      // that have two or fewer predecessors.
      ICHECK_LE(block.predecessors.size(), 2)
          << "InternalError: Each block should have at most two predecessors. "
          << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint.";

      // Each predecessor's end-of-block knowns, rewritten through the
      // edge's variable remapping.
      std::vector<BufferState> states;
      for (const auto& pred : block.predecessors) {
        const auto& pred_block = control_flow_[pred.index];
        BufferState state = pred_block.known_at_block_end;
        state.Substitute(pred.var_remap, &analyzer);
        states.push_back(state);
      }

      if (std::all_of(block.predecessors.begin(), block.predecessors.end(),
                      [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) {
        // Predecessors, if any, are unvisited.
        return {};
      } else if (block.predecessors.size() == 1) {
        // Block has only a single predecessor
        return states[0];
      }

      const auto& pred_a = block.predecessors[0];
      const auto& pred_b = block.predecessors[1];

      auto& priors_a = states[0];
      auto& priors_b = states[1];

      // During the first visit of a block, predecessor blocks may be
      // unvisited, even though we preferentially visit earlier blocks
      // first. (e.g. During the first visit of the start of a For
      // loop, the end of the For loop has not yet been visited.) If
      // this is the case, assume the best-case scenario that all
      // knowns are consistent, and rely on a later visit to
      // resolve/remove any conflicts.
      if (visit_count_lookup[pred_a.index] == 0) {
        return priors_b;
      } else if (visit_count_lookup[pred_b.index] == 0) {
        return priors_a;
      }

      if (pred_a.post_condition && pred_b.post_condition) {
        // The predicate can identify which predecessor block applies
        // (e.g. i==0 for the first loop iteration, i>0 for remaining
        // loop iterations). Therefore, we can use all buffer
        // constraints, conditional on having come from the
        // predecessor that provides it.
        priors_a.AddCondition(pred_a.post_condition.value());
        priors_b.AddCondition(pred_b.post_condition.value());
        priors_a.Union(priors_b, &analyzer);
        return priors_a;
      } else {
        // We don't know which predecessor applies. Therefore, the
        // only buffer constraints that can be used are those that
        // appear in both predecessors.
        priors_a.Intersection(priors_b, &analyzer);
        return priors_a;
      }
    }();

    // Step 2: Collect knowns provided as a result of executing this block
    auto post_state = [&]() {
      if (num_previous_visits >= max_revisits_) {
        return BufferState();
      }
      auto post_state = block.known_at_block_start;
      post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer);
      post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer);
      return post_state;
    }();

    // Step 3: If the end-of-block knowns changed since the previous
    // visit to this block (or this is the first visit), mark the
    // successor blocks as needing to be visited.
    if (num_previous_visits == 0 ||
        !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) {
      block.known_at_block_end = std::move(post_state);
      for (const auto& successor : block.successors) {
        to_visit.insert(successor.index);
      }
    }
  }
}
1487
/*! \brief Propagate unused buffer indices backward through the control-flow graph
 *
 * Worklist iteration.  Each visited block performs three steps:
 * - Step 1: recompute the unused indices at its end from its
 *   successors' unused indices at their start.
 * - Step 2: back-propagate its own touches to obtain the unused
 *   indices at its start.
 * - Step 3: re-queue its predecessors if that start-of-block state
 *   changed.
 *
 * Once a block has been visited `max_revisits_` times, both of its
 * states are replaced by an empty BufferState on every later visit.
 *
 * \param flow_from If provided, start only from this block.  Otherwise
 *     start from every block that contains a Write touch.
 */
void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) {
  // Blocks to visit. A std::set is used so that the highest index is
  // always visited next, preferring nodes near the end of the control
  // flow.
  std::set<size_t> to_visit;

  if (flow_from.has_value()) {
    to_visit.insert(flow_from.value());
  } else {
    // Initialize the locations to search from, propagating values
    // backward from anywhere that performs a write.
    for (size_t i = 0; i < control_flow_.size(); i++) {
      const auto& touch_points = control_flow_[i].touch_points;
      bool performs_write = std::any_of(
          touch_points.begin(), touch_points.end(),
          [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
      if (performs_write) {
        to_visit.insert(i);
      }
    }
  }

  // Map from a block's index to the number of times it has been visited
  std::unordered_map<size_t, size_t> visit_count_lookup;

  Analyzer analyzer;
  analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
      arith::RewriteSimplifier::kTransitivelyProveInequalities |
      arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
      arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

  analyzer.Bind(iterator_ranges_);
  analyzer.Bind(free_predicate_parameters_);

  while (to_visit.size()) {
    size_t visiting = *to_visit.rbegin();
    to_visit.erase(visiting);

    size_t num_previous_visits = visit_count_lookup[visiting]++;

    ControlFlowBlock& block = control_flow_[visiting];

    // Step 1: Collect known unused indices provided by each successor
    block.unused_at_block_end = [&]() -> BufferState {
      if (num_previous_visits >= max_revisits_) {
        return BufferState();
      }
      ICHECK_LE(block.successors.size(), 2)
          << "Each block should have at most two successors, but block " << visiting
          << " breaks this requirement";

      // Each successor's start-of-block unused indices, rewritten
      // through the edge's variable remapping.
      std::vector<BufferState> states;
      for (const auto& successor : block.successors) {
        const auto& successor_block = control_flow_[successor.index];
        BufferState state = successor_block.unused_at_block_start;
        state.Substitute(successor.var_remap, &analyzer);
        states.push_back(state);
      }

      if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) {
            return visit_count_lookup[successor.index] == 0;
          })) {
        // Successors, if any, are unvisited.
        return {};
      } else if (block.successors.size() == 1) {
        // Block has only a single successor
        return states[0];
      }

      const auto& successor_a = block.successors[0];
      const auto& successor_b = block.successors[1];

      auto& post_a = states[0];
      auto& post_b = states[1];

      // During the first visit of a block, successor blocks may be
      // unvisited, even though we preferentially visit later blocks
      // first. (e.g. During the first visit of the end of a For
      // loop, the start of the For loop has not yet been visited.)
      // If this is the case, assume the best-case scenario that all
      // knowns are consistent, and rely on a later visit to
      // resolve/remove any conflicts.
      if (visit_count_lookup[successor_a.index] == 0) {
        return post_b;
      } else if (visit_count_lookup[successor_b.index] == 0) {
        return post_a;
      }

      if (successor_a.post_condition && successor_b.post_condition) {
        // The predicate can identify which successor block applies
        // (e.g. i==n-1 for the last loop iteration, i<n-1 for earlier
        // loop iterations). Therefore, we can use all buffer
        // constraints, conditional on having come from the
        // successor that provides it.
        post_a.AddCondition(successor_a.post_condition.value());
        post_b.AddCondition(successor_b.post_condition.value());
        post_a.Union(post_b, &analyzer);
        return post_a;
      } else {
        // We don't know which successor applies. Therefore, the
        // only buffer constraints that can be used are those that
        // appear in both successors.
        post_a.Intersection(post_b, &analyzer);
        return post_a;
      }
    }();

    // Step 2: Compute the unused indices at the start of this block by
    // back-propagating the touches of this block.
    auto unused_at_block_start = [&]() {
      if (num_previous_visits >= max_revisits_) {
        return BufferState();
      }
      auto prior_state = block.unused_at_block_end;
      prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer);
      prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer);
      return prior_state;
    }();

    // Step 3: If the start-of-block unused indices changed since the
    // previous visit to this block (or this is the first visit), mark
    // the predecessor blocks as needing to be visited.
    if (num_previous_visits == 0 ||
        !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) {
      block.unused_at_block_start = std::move(unused_at_block_start);
      for (const auto& pred : block.predecessors) {
        to_visit.insert(pred.index);
      }
    }
  }
}
1617
1618bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
1619 const Stmt& context) const {
1620 Optional<Array<Var>> index_variables = GetIndexVariables(store->buffer);
1621 if (!index_variables) {
1622 return false;
1623 }
1624
1625 auto it = control_flow_lookup_.find(context.get());
1626 ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n"
1627 << context;
1628 const auto& context_block = control_flow_[it->second];
1629
1630 auto [store_touch, free_params] = context_block.MakeBufferTouch(
1631 store->buffer, index_variables.value(), store->indices, BufferTouch::AccessType::Write,
1632 BufferLoad(store->buffer, store->indices));
1633
1634 Analyzer local_analyzer;
1635 local_analyzer.Bind(free_predicate_parameters_);
1636 local_analyzer.Bind(iterator_ranges_);
1637 local_analyzer.Bind(free_params);
1638 local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
1639 arith::RewriteSimplifier::kTransitivelyProveInequalities |
1640 arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
1641 arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
1642
1643 PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration();
1644
1645 predicate = SimplifyAsAndOfOrs(predicate, &local_analyzer);
1646
1647 for (const auto& unused : context_block.unused_at_block_end.constraints_) {
1648 if (store_touch.buffer.same_as(unused.buffer)) {
1649 PrimExpr difference = SimplifyAsAndOfOrs(predicate && !unused.predicate, &local_analyzer);
1650 if (is_zero(difference)) {
1651 return true;
1652 }
1653 }
1654 }
1655 return false;
1656}
1657
1658PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context,
1659 Analyzer* analyzer) const {
1660 size_t context_index = [&]() {
1661 auto it = control_flow_lookup_.find(context.get());
1662 ICHECK(it != control_flow_lookup_.end())
1663 << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor";
1664 return it->second;
1665 }();
1666
1667 const auto& control_flow_block = control_flow_[context_index];
1668
1669 PrimExpr constraint = Bool(true);
1670 for (const auto& known : non_buffer_assumptions_) {
1671 constraint = constraint && known;
1672 }
1673 With<ConstraintContext> constraint_context(analyzer, constraint);
1674 With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate);
1675
1676 expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues(
1677 std::move(expr), axis_var_lookup_, analyzer);
1678
1679 expr = analyzer->Simplify(std::move(expr));
1680 return expr;
1681}
1682
1683} // namespace tir
1684} // namespace tvm
1685