1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file reduce_branching_through_overcompute.cc |
22 | * |
23 | * \brief Attempt to remove conditional statements by introducing |
24 | * extra computations that do not impact the final results. |
25 | */ |
26 | |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include <optional> |
31 | |
32 | #include "../../arith/ir_mutator_with_analyzer.h" |
33 | #include "../analysis/control_flow_graph.h" |
34 | #include "remove_no_op.h" |
35 | #include "simplify.h" |
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | |
40 | struct ReduceBranchingThroughOvercomputeConfigNode |
41 | : public tvm::AttrsNode<ReduceBranchingThroughOvercomputeConfigNode> { |
42 | bool use_dataflow_analysis; |
43 | |
44 | TVM_DECLARE_ATTRS(ReduceBranchingThroughOvercomputeConfigNode, |
45 | "tir.transform.ReduceBranchingThroughOvercomputeConfig" ) { |
46 | TVM_ATTR_FIELD(use_dataflow_analysis) |
47 | .describe( |
48 | "If true, known buffer values are propagated and used " |
49 | "to statically prove that overcompute is valid." ) |
50 | .set_default(false); |
51 | } |
52 | }; |
53 | |
54 | class ReduceBranchingThroughOvercomputeConfig : public Attrs { |
55 | public: |
56 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReduceBranchingThroughOvercomputeConfig, Attrs, |
57 | ReduceBranchingThroughOvercomputeConfigNode); |
58 | }; |
59 | |
60 | TVM_REGISTER_NODE_TYPE(ReduceBranchingThroughOvercomputeConfigNode); |
61 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute" , |
62 | ReduceBranchingThroughOvercomputeConfig); |
63 | |
64 | struct ElseBranchFiller : StmtExprMutator { |
65 | Stmt VisitStmt_(const IfThenElseNode* op) override { |
66 | IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op)); |
67 | if (ret->else_case.defined()) { |
68 | return std::move(ret); |
69 | } else { |
70 | auto new_else_clause = Evaluate(0); |
71 | new_else_clauses.insert(new_else_clause); |
72 | return IfThenElse(ret->condition, ret->then_case, new_else_clause); |
73 | } |
74 | } |
75 | |
76 | std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual> new_else_clauses; |
77 | }; |
78 | |
79 | class ElseBranchStripper : public StmtExprMutator { |
80 | public: |
81 | ElseBranchStripper( |
82 | const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses) |
83 | : new_else_clauses_(new_else_clauses) {} |
84 | |
85 | private: |
86 | Stmt VisitStmt_(const IfThenElseNode* op) override { |
87 | IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op)); |
88 | auto as_eval = ret->else_case.as<EvaluateNode>(); |
89 | if (as_eval && new_else_clauses_.count(GetRef<Evaluate>(as_eval))) { |
90 | return IfThenElse(ret->condition, ret->then_case); |
91 | } else { |
92 | return std::move(ret); |
93 | } |
94 | } |
95 | |
96 | const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses_; |
97 | }; |
98 | |
99 | class BranchReducer : public arith::IRMutatorWithAnalyzer { |
100 | public: |
101 | static Stmt Apply(Stmt stmt, const std::optional<ControlFlowGraph>& touch_pattern) { |
102 | arith::Analyzer analyzer; |
103 | BranchReducer visitor(&analyzer, touch_pattern); |
104 | return visitor(std::move(stmt)); |
105 | } |
106 | |
107 | private: |
108 | using Parent = IRMutatorWithAnalyzer; |
109 | using Parent::VisitStmt; |
110 | using Parent::VisitStmt_; |
111 | |
112 | BranchReducer(arith::Analyzer* analyzer, const std::optional<ControlFlowGraph>& touch_pattern) |
113 | : Parent(analyzer), touch_pattern_(touch_pattern) {} |
114 | |
115 | Stmt VisitStmt_(const IfThenElseNode* op) final { |
116 | IfThenElse cond = Downcast<IfThenElse>(Parent::VisitStmt_(op)); |
117 | |
118 | auto is_special_case = [&](PrimExpr condition, Stmt general_case, Stmt special_case) -> bool { |
119 | condition = analyzer_->rewrite_simplify(condition); |
120 | With<arith::ConstraintContext> constraint(analyzer_, condition); |
121 | Stmt stmt = RemoveNoOp(general_case, analyzer_, touch_pattern_, special_case.get()); |
122 | return StructuralEqual()(stmt, special_case); |
123 | }; |
124 | |
125 | ICHECK(cond->else_case.defined() || !touch_pattern_.has_value()) |
126 | << "Temp assert, should be true whenever touch pattern is available" ; |
127 | Stmt else_case = cond->else_case.value_or(Evaluate(0)); |
128 | |
129 | if (is_special_case(cond->condition, else_case, cond->then_case)) { |
130 | return else_case; |
131 | } else if (is_special_case(!cond->condition, cond->then_case, else_case)) { |
132 | return cond->then_case; |
133 | } else { |
134 | return std::move(cond); |
135 | } |
136 | } |
137 | |
138 | private: |
139 | const std::optional<ControlFlowGraph>& touch_pattern_; |
140 | }; |
141 | |
142 | namespace transform { |
143 | |
144 | Pass ReduceBranchingThroughOvercompute() { |
145 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
146 | arith::Analyzer analyzer; |
147 | |
148 | ReduceBranchingThroughOvercomputeConfig config = |
149 | ctx->GetConfig<ReduceBranchingThroughOvercomputeConfig>( |
150 | "tir.ReduceBranchingThroughOvercompute" ) |
151 | .value_or(AttrsWithDefaultValues<ReduceBranchingThroughOvercomputeConfig>()); |
152 | |
153 | auto* n = f.CopyOnWrite(); |
154 | |
155 | std::optional<ControlFlowGraph> touch_pattern = std::nullopt; |
156 | ElseBranchFiller else_branch_filler; |
157 | if (config->use_dataflow_analysis) { |
158 | n->body = else_branch_filler(std::move(n->body)); |
159 | touch_pattern.emplace(n->body); |
160 | } |
161 | |
162 | n->body = BranchReducer::Apply(std::move(n->body), touch_pattern); |
163 | |
164 | if (config->use_dataflow_analysis) { |
165 | n->body = ElseBranchStripper(else_branch_filler.new_else_clauses)(std::move(n->body)); |
166 | } |
167 | return f; |
168 | }; |
169 | return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute" , {}); |
170 | } |
171 | |
172 | TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute" ) |
173 | .set_body_typed(ReduceBranchingThroughOvercompute); |
174 | |
175 | } // namespace transform |
176 | |
177 | } // namespace tir |
178 | } // namespace tvm |
179 | |