1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file renormalize_split_pattern.cc |
22 | * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/op.h> |
27 | #include <tvm/tir/stmt.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include "../../arith/ir_mutator_with_analyzer.h" |
32 | #include "../../arith/pattern_match.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | using namespace arith; |
38 | |
// Rewrite helpers shared by the visitors below. Both macros test the pattern
// against a local named `ret` that must be in scope at the expansion site, and
// `return` from the enclosing function when they fire. Each expansion is wrapped
// in `do { ... } while (0)` so that `MACRO(...);` is exactly one statement and is
// safe inside an unbraced `if` / `else`.

// Simple rewrite: if `ret` matches SrcExpr, return ResExpr evaluated with the
// pattern variables bound by that match.
#define TRY_REWRITE(SrcExpr, ResExpr) \
  do {                                \
    if ((SrcExpr).Match(ret)) {       \
      return (ResExpr).Eval();        \
    }                                 \
  } while (0)

// Rewrite + recursive rewrite, only if CondExpr is true after the match.
// CondExpr is evaluated only when the match succeeded (short-circuit `&&`), so it
// may read pattern variables bound by SrcExpr. Requires a `RecursiveRewrite`
// callable in scope at the expansion site.
#define TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
  do {                                                        \
    if ((SrcExpr).Match(ret) && (CondExpr)) {                 \
      return RecursiveRewrite((ResExpr).Eval());              \
    }                                                         \
  } while (0)
50 | |
51 | class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { |
52 | public: |
53 | explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} |
54 | |
55 | using IRMutatorWithAnalyzer::VisitExpr_; |
56 | |
57 | PrimExpr VisitExpr_(const FloorDivNode* op) final { |
58 | PrimExpr a = VisitExpr(op->a); |
59 | PrimExpr b = VisitExpr(op->b); |
60 | PrimExpr ret = floordiv(a, b); |
61 | // Pattern var to match any expression |
62 | PVar<PrimExpr> x, y, z; |
63 | // Pattern var match IntImm |
64 | PVar<IntImm> c1, c2, c3; |
65 | // Pattern var for lanes in broadcast and ramp |
66 | PVar<int> lanes; |
67 | |
68 | // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) |
69 | TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), |
70 | floormod(floordiv(x, c2), floordiv(c3, c2)), |
71 | c3.Eval()->value % c2.Eval()->value == 0); |
72 | TRY_RECURSIVE_REWRITE_IF( |
73 | floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)), |
74 | floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)), |
75 | c3.Eval()->value % c2.Eval()->value == 0); |
76 | |
77 | // floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2) |
78 | if ((floordiv(x * c1 + y, c2)).Match(ret)) { |
79 | int64_t c1_val = c1.Eval()->value; |
80 | int64_t c2_val = c2.Eval()->value; |
81 | if (c1_val > 0 && c2_val > 0) { |
82 | int64_t c3 = ZeroAwareGCD(c1_val, c2_val); |
83 | if (c3 > 1) { |
84 | IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); |
85 | IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); |
86 | return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div)); |
87 | } |
88 | } |
89 | } |
90 | if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) { |
91 | int64_t c1_val = c1.Eval()->value; |
92 | int64_t c2_val = c2.Eval()->value; |
93 | if (c1_val > 0 && c2_val > 0) { |
94 | int64_t c3 = ZeroAwareGCD(c1_val, c2_val); |
95 | if (c3 > 1) { |
96 | IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); |
97 | IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); |
98 | return RecursiveRewrite(floordiv( |
99 | x.Eval() * Broadcast(c1_div, lanes.Eval()) + |
100 | floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), |
101 | Broadcast(c2_div, lanes.Eval()))); |
102 | } |
103 | } |
104 | } |
105 | |
106 | // floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2) |
107 | if ((floordiv(x * c1 + y + z, c2)).Match(ret)) { |
108 | int64_t c1_val = c1.Eval()->value; |
109 | int64_t c2_val = c2.Eval()->value; |
110 | if (c1_val > 0 && c2_val > 0) { |
111 | int64_t c3 = ZeroAwareGCD(c1_val, c2_val); |
112 | if (c3 > 1) { |
113 | IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); |
114 | IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); |
115 | return RecursiveRewrite( |
116 | floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div)); |
117 | } |
118 | } |
119 | } |
120 | if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) { |
121 | int64_t c1_val = c1.Eval()->value; |
122 | int64_t c2_val = c2.Eval()->value; |
123 | if (c1_val > 0 && c2_val > 0) { |
124 | int64_t c3 = ZeroAwareGCD(c1_val, c2_val); |
125 | if (c3 > 1) { |
126 | IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); |
127 | IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); |
128 | return RecursiveRewrite( |
129 | floordiv(x.Eval() * Broadcast(c1_div, lanes.Eval()) + |
130 | floordiv(y.Eval() + z.Eval(), |
131 | Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), |
132 | Broadcast(c2_div, lanes.Eval()))); |
133 | } |
134 | } |
135 | } |
136 | |
137 | return ret; |
138 | } |
139 | |
140 | PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); } |
141 | |
142 | PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } |
143 | |
144 | PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); } |
145 | |
146 | PrimExpr VisitExpr_(const LTNode* op) { |
147 | PrimExpr a = VisitExpr(op->a); |
148 | PrimExpr b = VisitExpr(op->b); |
149 | PrimExpr ret = tir::LT(a, b); |
150 | // Pattern var to match any expression |
151 | PVar<PrimExpr> x; |
152 | // Pattern var match IntImm |
153 | PVar<IntImm> c1, c2; |
154 | // x < c2 <=> x/c2 < 1 <=> floor(x / c2) < 1 |
155 | TRY_RECURSIVE_REWRITE_IF(x<c2, floordiv(x, c2) < 1, c2.Eval()->value> 0); // NOLINT |
156 | return ret; |
157 | } |
158 | |
159 | PrimExpr VisitExpr_(const NotNode* op) { |
160 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
161 | // Pattern var to match any expression |
162 | PVar<PrimExpr> x, y; |
163 | TRY_REWRITE(!(!x), x); |
164 | TRY_REWRITE(!(x <= y), y < x); |
165 | TRY_REWRITE(!(x >= y), x < y); |
166 | TRY_REWRITE(!(x < y), y <= x); |
167 | TRY_REWRITE(!(x > y), x <= y); |
168 | return ret; |
169 | } |
170 | |
171 | Stmt VisitStmt_(const ForNode* op) final { |
172 | analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
173 | With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min); |
174 | With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent); |
175 | return IRMutatorWithAnalyzer::VisitStmt_(op); |
176 | } |
177 | |
178 | // Recursive rewrite x |
179 | // we limit maximum depth of recursive rewrite allowed to |
180 | // avoid infinite loop |
181 | PrimExpr RecursiveRewrite(const PrimExpr& x) { |
182 | if (recur_depth_ >= kMaxRecurDepth) return x; |
183 | ++recur_depth_; |
184 | PrimExpr res = this->VisitExpr(x); |
185 | --recur_depth_; |
186 | return res; |
187 | } |
188 | |
189 | private: |
190 | // counter to record recursive rewrite depth. |
191 | int recur_depth_{0}; |
192 | // maximum number of recursion allowed during a single pass. |
193 | static const constexpr int kMaxRecurDepth = 5; |
194 | }; |
195 | |
196 | namespace transform { |
197 | |
198 | Pass RenormalizeSplitPattern() { |
199 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
200 | auto* n = f.CopyOnWrite(); |
201 | arith::Analyzer analyzer; |
202 | n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body)); |
203 | return f; |
204 | }; |
205 | return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern" , {}); |
206 | } |
207 | |
208 | TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern" ) |
209 | .set_body_typed(RenormalizeSplitPattern); |
210 | |
211 | } // namespace transform |
212 | |
213 | } // namespace tir |
214 | } // namespace tvm |
215 | |