1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file control_flow_graph.h
22 * \brief Utility for extracting and interacting with buffer touch points
23 */
24
25#include <tvm/arith/analyzer.h>
26#include <tvm/arith/int_solver.h>
27#include <tvm/runtime/container/array.h>
28#include <tvm/tir/buffer.h>
29#include <tvm/tir/stmt.h>
30#include <tvm/tir/var.h>
31
32#include <optional>
33#include <unordered_map>
34#include <utility>
35#include <vector>
36
37#ifndef TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
38#define TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
39
40namespace tvm {
41namespace tir {
42
43/*! \brief Represents an interaction with a buffer */
44struct BufferTouch {
45 enum class AccessType {
46 /*! \brief Buffer access occurs in BufferLoad */
47 Read,
48
49 /*! \brief Buffer access occurs in BufferStore */
50 Write,
51
52 /*! \brief Buffer access occurs in tir::builtin::assume() */
53 Assume,
54 };
55
56 BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value)
57 : buffer(buffer),
58 predicate(predicate),
59 value(value),
60 loop_var_expressions({}),
61 touch_type(AccessType::Assume) {}
62
63 BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value,
64 std::vector<std::pair<Var, PrimExpr>> loop_var_expressions, AccessType touch_type)
65 : buffer(buffer),
66 predicate(predicate),
67 value(value),
68 loop_var_expressions(loop_var_expressions),
69 touch_type(touch_type) {}
70
71 /*! \brief The buffer being touched */
72 Buffer buffer;
73
74 /*! \brief A predicate that is true when this touch applies
75 *
76 * May be in terms of axis variables to indicate touches that impact
77 * only a portion of a buffer.
78 */
79 PrimExpr predicate;
80
81 /*! \brief The value in this buffer after the touch
82 *
83 * May be in terms of axis variables to indicate a known
84 * non-constant value. May be in terms of a BufferLoad to indicate
85 * an unknown value.
86 */
87 PrimExpr value;
88
89 /*! \brief Active loops during the buffer touch
90 *
91 * The vector contains one entry for each loop that contains the
92 * buffer touch. The `Var` item in each entry is the loop variable
93 * itself. The `PrimExpr` item is an expression for the loop
94 * variable in terms of the buffer axis variables in
95 * `ControlFlowGraph::axis_var_lookup_`.
96 *
97 * Used to construct boolean expressions indicating whether the loop
98 * iteration that performs this touch has been reached.
99 */
100 std::vector<std::pair<Var, PrimExpr>> loop_var_expressions;
101
102 /*! \brief How the buffer was interacted with
103 *
104 * When used as a constraint (e.g. in BufferState), should use
105 * Assume.
106 */
107 AccessType touch_type{AccessType::Assume};
108
109 /*! \brief Generate a boolean expression that is true for indices
110 * accessed by this touch during this iteration or a previous
111 * loop iteration.
112 *
113 * Used during forward propagation, to track known values that were
114 * written in the current loop iteration, or in a preceding loop
115 * iteration.
116 */
117 PrimExpr BeforeLoopIteration() const;
118
119 /*! \brief Generate a boolean expression that is true for indices
120 * accessed by this touch during this loop iteration.
121 *
122 * Used during speculative no-op insertion checks, to specify which
123 * indices must be later overwritten for a store to have no impact
124 * on final results.
125 */
126 PrimExpr AtLoopIteration() const;
127
128 /*! \brief Generate a boolean expression that is true for indices
129 * accessed by this touch during this loop iteration or a
130 * subsequent loop iteration.
131 *
132 * Used during backward propagation, to track indices that that are
133 * overwritten in the current loop iteration or in a later loop
134 * iteration.
135 */
136 PrimExpr AfterLoopIteration() const;
137
138 /* \brief Checks if this touch affects a subset of indices of another
139 *
140 * Returns true if the indices accessed by this touch are a subset
141 * of predicate is true can be proven to be a subset of the other
142 * subset. Returns false if it cannot be proven to be a subset of
143 * ther other subset.
144 */
145 bool IsSubsetOf(const BufferTouch& other, arith::Analyzer* analyzer) const;
146
147 /* \brief Checks if this touch affects distinct indices from another
148 *
149 * Returns true if it can be proven that the two predicates cannot
150 * be simultaneously true. Returns false if it cannot be proven
151 * that the two predicates are distinct.
152 */
153 bool IsDistinctFrom(const BufferTouch& other, arith::Analyzer* analyzer) const;
154
155 /* \brief Checks if this touch affects distinct indices from another
156 *
157 * Returns true if it can be proven that the two predicates cannot
158 * be simultaneously true. Returns false if it cannot be proven
159 * that the two predicates are distinct.
160 */
161 bool IsEquivalentTo(const BufferTouch& other, arith::Analyzer* analyzer) const;
162
163 friend std::ostream& operator<<(std::ostream& os, const BufferTouch& expr);
164};
165
166/*! \brief Represents the known state of buffers at a specific point */
167class BufferState {
168 public:
169 /*! Default constructor
170 *
171 * Initialize the buffer state with no known information.
172 */
173 BufferState() {}
174
175 /*! \brief Replace BufferLoad instances with known values
176 *
177 * \param expr The expression to be updated.
178 *
179 * \param axis_var_lookup A map from buffer to the variables
180 * representing positions along the buffer's axes.
181 *
182 * \param analyzer The analyzer to use when validating a
183 * constraint's predicate.
184 *
185 * \returns The modified expression. If no substitutions are made,
186 * the original expression is returned.
187 */
188 PrimExpr SubstituteKnownBufferValues(PrimExpr expr,
189 const Map<Buffer, Array<Var>>& axis_var_lookup,
190 arith::Analyzer* analyzer) const;
191
192 /*! \brief Apply a condition to all known constraints
193 *
194 * For example, when propagating pre-loop constraints into the body
195 * of a loop, add a condition that the loop iterator is zero.
196 *
197 * \param condition The condition to apply
198 */
199 void AddCondition(const PrimExpr& condition);
200
201 /*! \brief Perform a variable substitution for all constraints
202 *
203 * For example, when propagating constraints from the end of a loop
204 * to the beginning, replace `i` with `i-1`.
205 *
206 * \param var_remap The variable remapping to apply.
207 */
208 void Substitute(const Map<Var, PrimExpr>& var_remap, arith::Analyzer* analyzer);
209
210 /*! \brief Simplify the predicate of all constraints
211 *
212 * \param analyzer The analyzer with which to simplify
213 */
214 void Simplify(arith::Analyzer* analyzer);
215
216 /*! \brief Update the known buffer values based on buffer touches
217 *
218 * For any Write or Assume touches, update the known values. For
219 * any Read touches, ignore. Used to determine known values at the
220 * end of a control flow block, given the known values at the start.
221 *
222 * \param axis_var_lookup A map from buffer to the variables
223 * representing positions along the buffer's axes.
224 *
225 * \param touch_points The buffer touch points to apply
226 *
227 * \param analyzer The analyzer to use for simplifications
228 */
229 void ApplyTouches(const Map<Buffer, Array<Var>>& axis_var_lookup,
230 const std::vector<BufferTouch>& touch_points, arith::Analyzer* analyzer);
231
232 /*! \brief Update unused buffer locations based on buffer touches
233 *
234 * For any Write, mark the written-to indices as unused. (That is,
235 * immediately prior to assigning `buf[i] = expr`, the value stored
236 * at `buf[i]` is irrelevant.) For any Read, mark the read-from
237 * indices as used. This method is used to determine unused buffer
238 * indices at the start of a control flow block, given the unused
239 * buffer indices values at the end.
240 *
241 * \param axis_var_lookup A map from buffer to the variables
242 * representing positions along the buffer's axes.
243 *
244 * \param touch_points The buffer touch points to apply
245 *
246 * \param analyzer The analyzer to use for simplifications
247 */
248 void BackpropUnusedIndices(const Map<Buffer, Array<Var>>& axis_var_lookup,
249 const std::vector<BufferTouch>& touch_points,
250 arith::Analyzer* analyzer);
251
252 /*! \brief Remove free parameters from the constraints
253 *
254 * \param free_predicate_parameters
255 *
256 * \param analyzer The analyzer with which to simplify after removal
257 */
258 void RemoveFreeParameters(const Map<Var, Range>& free_predicate_parameters,
259 arith::Analyzer* analyzer);
260
261 /*! \brief Check if two buffer states are equivalent
262 *
263 * \param other
264 *
265 * \param analyzer The analyzer used to check equality of PrimExpr
266 *
267 * \return True if the two states are provably equivalent, false otherwise.
268 */
269 bool IsEquivalentTo(const BufferState& other, arith::Analyzer* analyzer) const;
270
271 /* \brief Add known values provided by another state
272 *
273 * \param other The state with which to merge constraints
274 *
275 * \param analyzer The analyzer with which to simplify the result
276 */
277 void Union(const BufferState& other, arith::Analyzer* analyzer);
278
279 /* \brief Remove all known values not consistent with another state
280 *
281 * \param other The state with which to merge constraints
282 *
283 * \param analyzer The analyzer with which to simplify the result
284 */
285 void Intersection(const BufferState& other, arith::Analyzer* analyzer);
286
287 friend std::ostream& operator<<(std::ostream& os, const BufferState&);
288
289 private:
290 friend class ControlFlowGraph;
291 /*! \brief The known constraints */
292 std::vector<BufferTouch> constraints_;
293};
294
295/*!
296 * \brief Represents the flow of control through a `tir::Stmt`
297 *
298 * This class contains an internal representation of the possible
299 * control flow that may occur during execution of a `tir::Stmt`. It
300 * consists of a collection of ControlFlowBlock objects, each of which
301 * represents a subset of operations performed during execution, along
302 * with edges that represent allowed transitions between
303 * `ControlFlowBlock`.
304 *
305 * In addition, the following restrictions are used.
306 *
307 * 1. Each block may have at most two predecessors, and at most two
308 * successors.
309 *
310 * 2. Within each block, values stored in a buffer do not change.
311 * That is, encountering a `BufferStore` node requires creating a
312 * new block.
313 *
314 * For example, consider the following PrimFunc
315 *
316 * \code{.py}
317 * @T.prim_func
318 * def func(T.Buffer(16, "float32")):
319 * for i in T.serial(16):
320 * if i < 8:
321 * B[i] = i
322 * else:
323 * B[i] = i-8
324 * \endcode
325 *
326 * The control flow graph would have eight control blocks.
327 *
328 * 1. function_entry, from the start of the function through the
329 * evaluation of the loop's extent.
330 *
331 * Predecessors: n/a
332 * Successors: loop_start
333 *
334 * 2. loop_start, after entering the body of the loop, through the
335 * evaluation of the conditional `i < 8`
336 *
337 * Predecessors: function_entry, after_conditional
338 * Successors: then_clause_start, else_clause_start
339 *
340 * 3. then_clause_start, after entering the then_clause of `i < 8`,
341 * through evaluation of the value `i`.
342 *
343 * Predecessors: loop_start
344 * Successors: then_clause_end
345 *
346 * 4. then_clause_end, after storing to `B[i]` prior to exiting the
347 * then_clause.
348 *
349 * Predecessors: then_clause_start
350 * Successors: after_conditional
351 *
352 * 5. else_clause_start, after entering the else_clause of `i < 8`,
353 * through evaluation of the value `i-8`.
354 *
355 * Predecessors: loop_start
356 * Successors: else_clause_end
357 *
358 * 6. else_clause_end, after storing to `B[i]` prior to exiting the
359 * else_clause.
360 *
361 * Predecessors: else_clause_start
362 * Successors: after_conditional
363 *
364 * 7. after_conditional, after the end of the if/then/else, before the
365 * end of the loop body
366 *
367 * Predecessors: then_clause_end, else_clause_end
368 * Successors: loop_start, after_loop
369 *
370 * 8. after_loop, after the loop
371 *
372 * Predecessors: after_conditional
373 * Successors: n/a
374 *
375 *
376 * By identifying `BufferStore` nodes whose value does not depend on
377 * values stored in input buffers (e.g. initializing `buf[i] = 0.0`),
378 * or whose values are provided using `builtin::assume()`
379 * (e.g. `T.assume(buf[i] == 0.0)`), the value stored in a buffer at
380 * those indices may be known for a given control block. These known
381 * values can then be propagated forward to successor blocks, to be
382 * used in context-dependent simplifications.
383 *
384 * In addition to the allowed transitions between control-flow
385 * blocks, each block also tracks the buffer touch points; which
386 * indices are read from a buffer, which values are written to which
387 * indices of a buffer, and assumptions are provided using
388 * `builtin::assume()`; that occur during the control-flow block.
389 *
390 * Note: The current implementation only tracks the values of
391 * buffers that are constrained to a specific value, and does not
392 * track inequalities that may partially constrain buffer values.
393 * That is, entering a scoped context with a data-dependent equality
394 * condition (e.g. `if buf[i] == value`) is tracked, but entering a
395 * scoped context with a data-dependent inequality condition
396 * (e.g. `if buf[i] > value`) is not tracked.
397 */
398class ControlFlowGraph {
399 public:
400 /* \brief Extract the touch pattern from a TIR statement
401 */
402 explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
403
404 /* \brief Check if a write is overwritten without impacting final results
405 *
406 * \param store The store to be examined
407 *
408 * \param context The context in which the buffer store occurs, used
409 * to identify the control-flow block in which the store occurs. In
410 * most cases, this will be the same object as the `store` itself.
411 *
412 * \param analyzer The analyzer to be used for simplifications
413 *
414 * \return True if the specified store can be proven to be
415 * overwritten without contributing to any later statements.
416 * Returns false otherwise.
417 */
418 bool IsOverwrittenWithoutEffect(const BufferStore& store, const Stmt& context) const;
419
420 /* \brief Simplify the expression, assuming it occurs within the given context
421 *
422 * \param expr The expression to be simplified. Does not need to
423 * have occurred within the statement used to construct this
424 * BufferTouchPattern.
425 *
426 * \param context The statement where this expression occurred, or
427 * is to be inserted. Must occur within the statement used to
428 * construct this BufferTouchPattern.
429 *
430 * \param analyzer The analyzer to be used for simplifications
431 *
432 * \returns The simplified statement
433 */
434 PrimExpr SimplifyInContext(PrimExpr expr, const Stmt& context, arith::Analyzer* analyzer) const;
435
436 /*! \brief Remove the specified BufferStore from the control-flow
437 * graph
438 *
439 * Removing the specified store, which may reflow known values.
440 * This is necessary when simplifying sequential stores of the same
441 * value. Otherwise, the first could be removed as a no-op because
442 * it is overwritten by the second, and the second could be removed
443 * as a no-op because it is the same value as the first.
444 *
445 * \param store The store to remove
446 */
447 void RemoveStore(const tir::BufferStore& store);
448
449 friend std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern);
450
451 private:
452 /*! \brief Return index variables representing locations within a
453 * buffer.
454 *
455 * For a given buffer, will always return the same set of variables.
456 *
457 * \param buf The buffer being accessed
458 *
459 * \param indices The indices at which the buffer is being accessed.
460 * These are used to set the dtype of the buffer axis variables.
461 *
462 * \returns Variables representing a position along the buffer's axis.
463 */
464 Array<Var> GetIndexVariables(const Buffer& buf, const Array<PrimExpr>& indices);
465
466 /*! \brief Return index variables representing locations within a
467 * buffer, if they have been generated before.
468 *
469 * For a given buffer, will always return the same set of variables.
470 *
471 * \param buf The buffer being accessed
472 *
473 * \returns Variables representing a position along the buffer's axis.
474 */
475 Optional<Array<Var>> GetIndexVariables(const Buffer& buf) const;
476
477 /*! \brief Propagate known values from known BufferStore/assume
478 * subsequent control flow blocks
479 *
480 * \param flow_from If specified, re-flow only from that block.
481 */
482 void ForwardPropagateKnownValues(std::optional<size_t> flow_from = std::nullopt);
483
484 /*! \brief Propagate overwritten/unused indices to preceding control
485 * flow blocks
486 *
487 * \param flow_from If specified, re-flow only from that block.
488 */
489 void BackwardPropagateUnusedValues(std::optional<size_t> flow_from = std::nullopt);
490
491 struct ControlFlowEdge {
492 /* \brief The source block of the control flow edge
493 *
494 * Lookup index into `control_flow_`
495 */
496 size_t index;
497
498 /*! \brief Variable remaps
499 *
500 * e.g. Replacing loop iterator `i` with `i-1` when following an
501 * edge from the end of a loop to the beginning of the loop.
502 */
503 Map<Var, PrimExpr> var_remap;
504
505 /*! \brief Condition that must to true after following this edge
506 *
507 * This is applied after variable remapping. For example, `i >
508 * loop_min` when following the an edge from the end of a loop to
509 * the beginning of the loop.
510 */
511 Optional<PrimExpr> post_condition;
512 };
513 friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge);
514
515 struct ControlFlowBlock {
516 struct LoopEntry {
517 Var loop_var;
518 PrimExpr loop_min;
519 PrimExpr loop_max;
520 Range loop_range;
521 };
522
523 /*! \brief Loop iterators that are active during this block */
524 std::vector<LoopEntry> active_loop_iterators;
525
526 /*! \brief Loop-dependent Let bindings that may appear within the block */
527 Map<Var, PrimExpr> let_bindings_using_loop;
528
529 /*! \brief Predicate that must be true to have reached this block */
530 PrimExpr scope_predicate{Bool(true)};
531
532 /*! \brief All known values prior to executing the block */
533 BufferState known_at_block_start;
534
535 /*! \brief All known values after executing the block */
536 BufferState known_at_block_end;
537
538 /*! \brief Indices whose value at the start of the block is known to be unused */
539 BufferState unused_at_block_start;
540
541 /*! \brief Indices whose value at the end of the block is known to be unused */
542 BufferState unused_at_block_end;
543
544 /* \brief Buffer touches that occur within the block
545 *
546 * All buffer touches within a block can be treated as occurring
547 * simultaneously.
548 */
549 std::vector<BufferTouch> touch_points;
550
551 /* \brief The blocks that occur after this block
552 *
553 * Lookup index into `control_flow_`
554 */
555 std::vector<ControlFlowEdge> successors;
556
557 /* \brief The blocks that occur before this block */
558 std::vector<ControlFlowEdge> predecessors;
559
560 /* \brief Construct a BufferTouch instance within this
561 * ControlFlowBlock
562 *
563 * \param graph The mutable ControlFlowGraph that owns the buffer
564 * touch. Any free parameters used in the BufferTouch's predicate
565 * will be tracked by the ControlFlowGraph.
566 *
567 * \param buf The Buffer being accessed
568 *
569 * \param indices The indices at which the buffer is accessed, in
570 * terms of the loop variables.
571 *
572 * \param touch_type The type of touch being generated
573 *
574 * \param known_expr_value The value being written to the buffer
575 *
576 * \returns The newly generated BufferTouch
577 */
578 BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf,
579 const Array<PrimExpr>& indices, BufferTouch::AccessType touch_type,
580 PrimExpr known_value_expr) const;
581
582 /* \brief Construct a BufferTouch instance as if it occurred in
583 * this ControlFlowBlock
584 *
585 * Used when speculative checking if a BufferStore could be
586 * inserted.
587 *
588 * \param buf The Buffer being accessed
589 *
590 * \param index_variables The variables representing location
591 * within a buffer, with one variable for each axis of the buffer.
592 *
593 * \param indices The indices at which the buffer is accessed, in
594 * terms of the loop variables.
595 *
596 * \param touch_type The type of touch being generated
597 *
598 * \param known_expr_value The value being written to the buffer
599 *
600 * \returns The newly generated BufferTouch, and a map specifying
601 * all free parameters that may occur in the BufferTouch's
602 * predicate.
603 */
604 std::pair<BufferTouch, Map<Var, Range>> MakeBufferTouch(const Buffer& buf,
605 Array<Var> index_variables,
606 Array<PrimExpr> indices,
607 BufferTouch::AccessType touch_type,
608 PrimExpr known_value_expr) const;
609 };
610 friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern);
611
612 /* \brief The control flow that occurs within the analyzed statement */
613 std::vector<ControlFlowBlock> control_flow_;
614
615 /* \brief A lookup into control_flow_
616 *
617 * A map to look up the control flow block that contains the
618 * statement.
619 */
620 std::unordered_map<const StmtNode*, size_t> control_flow_lookup_;
621
622 /*! \brief A map from free parameters to their range
623 *
624 * A BufferStore/BufferLoad has indices in terms of loop iterators,
625 * while the internal BufferTouch must have predicate in terms of
626 * the buffer's axes. While converting to the internal BufferTouch,
627 * reduction axes show up as free parameters. Tracking the range of
628 * the free parameters allows them to be removed later, by requiring
629 * a predicate to be true for all values of the free parameters.
630 */
631 Map<Var, Range> free_predicate_parameters_;
632
633 /*! \brief Ranges of iterators found in the analyzed statement */
634 Map<Var, Range> iterator_ranges_;
635
636 /* \brief A map from buffer to the variables representing positions
637 * along the buffer's axes.
638 *
639 * This is stored here, rather than as part of the BufferState or
640 * BufferTouch, to ensure that all access of a buffer use the same
641 * variables to represent the buffer's axes, reducing the amount of
642 * variable substitution required.
643 */
644 Map<Buffer, Array<Var>> axis_var_lookup_;
645
646 /* \brief Assumptions that do not depend on buffer values
647 *
648 * These may be collected as part of the handling of `builtin::assume()`, and do not depend on any
649 * buffer. Since TIR only allows mutable values as part of buffers, these assumptions may be used
650 * anywhere the
651 */
652 std::vector<PrimExpr> non_buffer_assumptions_;
653
654 friend class ControlFlowGraphBuilder;
655
656 /*! \brief The maximum number of revisits while flowing constraints */
657 size_t max_revisits_;
658};
659
660} // namespace tir
661} // namespace tvm
662#endif // TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
663