1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file renormalize_split_pattern.cc
22 * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/op.h>
27#include <tvm/tir/stmt.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include "../../arith/ir_mutator_with_analyzer.h"
32#include "../../arith/pattern_match.h"
33
34namespace tvm {
35namespace tir {
36
37using namespace arith;
38
// Rewrite helper macros for use inside visitor methods.
//
// Both macros rely on names provided by the expansion site:
//   - a local variable named `ret` holding the expression to match against;
//   - (TRY_RECURSIVE_REWRITE_IF only) a callable `RecursiveRewrite` in scope.
// On a successful match they `return` from the enclosing function; otherwise
// control falls through to the next statement.
//
// The bodies are wrapped in `do { ... } while (0)` so that an invocation followed
// by `;` is exactly one statement and composes safely with unbraced if/else.

// macro for doing simple rewrite
#define TRY_REWRITE(SrcExpr, ResExpr) \
  do {                                \
    if ((SrcExpr).Match(ret)) {       \
      return (ResExpr).Eval();        \
    }                                 \
  } while (0)

// macro rewrite + recursive_rewrite only if CondExpr is true after match.
// CondExpr is evaluated only after the match succeeded, so it may read the
// pattern variables bound by SrcExpr.
#define TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
  do {                                                       \
    if ((SrcExpr).Match(ret) && (CondExpr)) {                \
      return RecursiveRewrite((ResExpr).Eval());             \
    }                                                        \
  } while (0)
50
51class SplitPatternReNormalizer : public IRMutatorWithAnalyzer {
52 public:
53 explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}
54
55 using IRMutatorWithAnalyzer::VisitExpr_;
56
57 PrimExpr VisitExpr_(const FloorDivNode* op) final {
58 PrimExpr a = VisitExpr(op->a);
59 PrimExpr b = VisitExpr(op->b);
60 PrimExpr ret = floordiv(a, b);
61 // Pattern var to match any expression
62 PVar<PrimExpr> x, y, z;
63 // Pattern var match IntImm
64 PVar<IntImm> c1, c2, c3;
65 // Pattern var for lanes in broadcast and ramp
66 PVar<int> lanes;
67
68 // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1)
69 TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2),
70 floormod(floordiv(x, c2), floordiv(c3, c2)),
71 c3.Eval()->value % c2.Eval()->value == 0);
72 TRY_RECURSIVE_REWRITE_IF(
73 floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)),
74 floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)),
75 c3.Eval()->value % c2.Eval()->value == 0);
76
77 // floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2)
78 if ((floordiv(x * c1 + y, c2)).Match(ret)) {
79 int64_t c1_val = c1.Eval()->value;
80 int64_t c2_val = c2.Eval()->value;
81 if (c1_val > 0 && c2_val > 0) {
82 int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
83 if (c3 > 1) {
84 IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
85 IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
86 return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div));
87 }
88 }
89 }
90 if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) {
91 int64_t c1_val = c1.Eval()->value;
92 int64_t c2_val = c2.Eval()->value;
93 if (c1_val > 0 && c2_val > 0) {
94 int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
95 if (c3 > 1) {
96 IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
97 IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
98 return RecursiveRewrite(floordiv(
99 x.Eval() * Broadcast(c1_div, lanes.Eval()) +
100 floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())),
101 Broadcast(c2_div, lanes.Eval())));
102 }
103 }
104 }
105
106 // floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2)
107 if ((floordiv(x * c1 + y + z, c2)).Match(ret)) {
108 int64_t c1_val = c1.Eval()->value;
109 int64_t c2_val = c2.Eval()->value;
110 if (c1_val > 0 && c2_val > 0) {
111 int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
112 if (c3 > 1) {
113 IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
114 IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
115 return RecursiveRewrite(
116 floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div));
117 }
118 }
119 }
120 if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) {
121 int64_t c1_val = c1.Eval()->value;
122 int64_t c2_val = c2.Eval()->value;
123 if (c1_val > 0 && c2_val > 0) {
124 int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
125 if (c3 > 1) {
126 IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
127 IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
128 return RecursiveRewrite(
129 floordiv(x.Eval() * Broadcast(c1_div, lanes.Eval()) +
130 floordiv(y.Eval() + z.Eval(),
131 Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())),
132 Broadcast(c2_div, lanes.Eval())));
133 }
134 }
135 }
136
137 return ret;
138 }
139
140 PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); }
141
142 PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); }
143
144 PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); }
145
146 PrimExpr VisitExpr_(const LTNode* op) {
147 PrimExpr a = VisitExpr(op->a);
148 PrimExpr b = VisitExpr(op->b);
149 PrimExpr ret = tir::LT(a, b);
150 // Pattern var to match any expression
151 PVar<PrimExpr> x;
152 // Pattern var match IntImm
153 PVar<IntImm> c1, c2;
154 // x < c2 <=> x/c2 < 1 <=> floor(x / c2) < 1
155 TRY_RECURSIVE_REWRITE_IF(x<c2, floordiv(x, c2) < 1, c2.Eval()->value> 0); // NOLINT
156 return ret;
157 }
158
159 PrimExpr VisitExpr_(const NotNode* op) {
160 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
161 // Pattern var to match any expression
162 PVar<PrimExpr> x, y;
163 TRY_REWRITE(!(!x), x);
164 TRY_REWRITE(!(x <= y), y < x);
165 TRY_REWRITE(!(x >= y), x < y);
166 TRY_REWRITE(!(x < y), y <= x);
167 TRY_REWRITE(!(x > y), x <= y);
168 return ret;
169 }
170
171 Stmt VisitStmt_(const ForNode* op) final {
172 analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
173 With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
174 With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
175 return IRMutatorWithAnalyzer::VisitStmt_(op);
176 }
177
178 // Recursive rewrite x
179 // we limit maximum depth of recursive rewrite allowed to
180 // avoid infinite loop
181 PrimExpr RecursiveRewrite(const PrimExpr& x) {
182 if (recur_depth_ >= kMaxRecurDepth) return x;
183 ++recur_depth_;
184 PrimExpr res = this->VisitExpr(x);
185 --recur_depth_;
186 return res;
187 }
188
189 private:
190 // counter to record recursive rewrite depth.
191 int recur_depth_{0};
192 // maximum number of recursion allowed during a single pass.
193 static const constexpr int kMaxRecurDepth = 5;
194};
195
196namespace transform {
197
198Pass RenormalizeSplitPattern() {
199 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
200 auto* n = f.CopyOnWrite();
201 arith::Analyzer analyzer;
202 n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body));
203 return f;
204 };
205 return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {});
206}
207
208TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern")
209 .set_body_typed(RenormalizeSplitPattern);
210
211} // namespace transform
212
213} // namespace tir
214} // namespace tvm
215