1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file reduce_branching_through_overcompute.cc
22 *
23 * \brief Attempt to remove conditional statements by introducing
24 * extra computations that do not impact the final results.
25 */
26
27#include <tvm/tir/op.h>
28#include <tvm/tir/transform.h>
29
30#include <optional>
31
32#include "../../arith/ir_mutator_with_analyzer.h"
33#include "../analysis/control_flow_graph.h"
34#include "remove_no_op.h"
35#include "simplify.h"
36
37namespace tvm {
38namespace tir {
39
40struct ReduceBranchingThroughOvercomputeConfigNode
41 : public tvm::AttrsNode<ReduceBranchingThroughOvercomputeConfigNode> {
42 bool use_dataflow_analysis;
43
44 TVM_DECLARE_ATTRS(ReduceBranchingThroughOvercomputeConfigNode,
45 "tir.transform.ReduceBranchingThroughOvercomputeConfig") {
46 TVM_ATTR_FIELD(use_dataflow_analysis)
47 .describe(
48 "If true, known buffer values are propagated and used "
49 "to statically prove that overcompute is valid.")
50 .set_default(false);
51 }
52};
53
54class ReduceBranchingThroughOvercomputeConfig : public Attrs {
55 public:
56 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReduceBranchingThroughOvercomputeConfig, Attrs,
57 ReduceBranchingThroughOvercomputeConfigNode);
58};
59
60TVM_REGISTER_NODE_TYPE(ReduceBranchingThroughOvercomputeConfigNode);
61TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute",
62 ReduceBranchingThroughOvercomputeConfig);
63
64struct ElseBranchFiller : StmtExprMutator {
65 Stmt VisitStmt_(const IfThenElseNode* op) override {
66 IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
67 if (ret->else_case.defined()) {
68 return std::move(ret);
69 } else {
70 auto new_else_clause = Evaluate(0);
71 new_else_clauses.insert(new_else_clause);
72 return IfThenElse(ret->condition, ret->then_case, new_else_clause);
73 }
74 }
75
76 std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual> new_else_clauses;
77};
78
79class ElseBranchStripper : public StmtExprMutator {
80 public:
81 ElseBranchStripper(
82 const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses)
83 : new_else_clauses_(new_else_clauses) {}
84
85 private:
86 Stmt VisitStmt_(const IfThenElseNode* op) override {
87 IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
88 auto as_eval = ret->else_case.as<EvaluateNode>();
89 if (as_eval && new_else_clauses_.count(GetRef<Evaluate>(as_eval))) {
90 return IfThenElse(ret->condition, ret->then_case);
91 } else {
92 return std::move(ret);
93 }
94 }
95
96 const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses_;
97};
98
99class BranchReducer : public arith::IRMutatorWithAnalyzer {
100 public:
101 static Stmt Apply(Stmt stmt, const std::optional<ControlFlowGraph>& touch_pattern) {
102 arith::Analyzer analyzer;
103 BranchReducer visitor(&analyzer, touch_pattern);
104 return visitor(std::move(stmt));
105 }
106
107 private:
108 using Parent = IRMutatorWithAnalyzer;
109 using Parent::VisitStmt;
110 using Parent::VisitStmt_;
111
112 BranchReducer(arith::Analyzer* analyzer, const std::optional<ControlFlowGraph>& touch_pattern)
113 : Parent(analyzer), touch_pattern_(touch_pattern) {}
114
115 Stmt VisitStmt_(const IfThenElseNode* op) final {
116 IfThenElse cond = Downcast<IfThenElse>(Parent::VisitStmt_(op));
117
118 auto is_special_case = [&](PrimExpr condition, Stmt general_case, Stmt special_case) -> bool {
119 condition = analyzer_->rewrite_simplify(condition);
120 With<arith::ConstraintContext> constraint(analyzer_, condition);
121 Stmt stmt = RemoveNoOp(general_case, analyzer_, touch_pattern_, special_case.get());
122 return StructuralEqual()(stmt, special_case);
123 };
124
125 ICHECK(cond->else_case.defined() || !touch_pattern_.has_value())
126 << "Temp assert, should be true whenever touch pattern is available";
127 Stmt else_case = cond->else_case.value_or(Evaluate(0));
128
129 if (is_special_case(cond->condition, else_case, cond->then_case)) {
130 return else_case;
131 } else if (is_special_case(!cond->condition, cond->then_case, else_case)) {
132 return cond->then_case;
133 } else {
134 return std::move(cond);
135 }
136 }
137
138 private:
139 const std::optional<ControlFlowGraph>& touch_pattern_;
140};
141
142namespace transform {
143
144Pass ReduceBranchingThroughOvercompute() {
145 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
146 arith::Analyzer analyzer;
147
148 ReduceBranchingThroughOvercomputeConfig config =
149 ctx->GetConfig<ReduceBranchingThroughOvercomputeConfig>(
150 "tir.ReduceBranchingThroughOvercompute")
151 .value_or(AttrsWithDefaultValues<ReduceBranchingThroughOvercomputeConfig>());
152
153 auto* n = f.CopyOnWrite();
154
155 std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
156 ElseBranchFiller else_branch_filler;
157 if (config->use_dataflow_analysis) {
158 n->body = else_branch_filler(std::move(n->body));
159 touch_pattern.emplace(n->body);
160 }
161
162 n->body = BranchReducer::Apply(std::move(n->body), touch_pattern);
163
164 if (config->use_dataflow_analysis) {
165 n->body = ElseBranchStripper(else_branch_filler.new_else_clauses)(std::move(n->body));
166 }
167 return f;
168 };
169 return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {});
170}
171
172TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute")
173 .set_body_typed(ReduceBranchingThroughOvercompute);
174
175} // namespace transform
176
177} // namespace tir
178} // namespace tvm
179