1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Lower intrinsic calls and ops to device specific ir when possible.
22 * \file lower_intrin.cc
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/target/target.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/transform.h>
29
30#include <limits>
31#include <unordered_set>
32
33#include "../../arith/ir_mutator_with_analyzer.h"
34#include "../../arith/pattern_match.h"
35
36namespace tvm {
37namespace tir {
38
39class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
40 public:
41 using IRMutatorWithAnalyzer::VisitExpr_;
42 using IRMutatorWithAnalyzer::VisitStmt_;
43 using FLowerGeneral = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
44
45 IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
46 : IRMutatorWithAnalyzer(analyzer) {
47 std::vector<std::string> patterns;
48 patterns.push_back(target + ".FLowerIntrinsic");
49 patterns.push_back(target + ".FLegalize");
50 bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
51 if (is_llvm_aarch64) {
52 patterns.push_back(target + ".aarch64.FLowerIntrinsic");
53 patterns.push_back(target + ".aarch64.FLegalize");
54 }
55 patterns.push_back("default.FLowerIntrinsic");
56 patterns.push_back("default.FLegalize");
57
58 for (const std::string& pattern : patterns)
59 if (Op::HasAttrMap(pattern)) {
60 attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
61 if (fma_ == nullptr) {
62 fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr);
63 }
64 }
65 }
66
67 PrimExpr VisitExpr_(const CallNode* op) final {
68 if (auto* ptr_op = op->op.as<OpNode>()) {
69 for (const auto& f_attr_map : attr_maps_) {
70 FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
71 if (f != nullptr) {
72 PrimExpr e = GetRef<PrimExpr>(op);
73 PrimExpr r = f(e);
74 ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
75 if (!r.same_as(e)) {
76 r = this->VisitExpr(r);
77 if (r.defined()) {
78 return r;
79 }
80 }
81 }
82 }
83 }
84 return IRMutatorWithAnalyzer::VisitExpr_(op);
85 }
86
87 PrimExpr VisitExpr_(const AddNode* op) final {
88 if (const MulNode* mb = op->b.as<MulNode>()) {
89 return MakeFMA(mb->a, mb->b, op->a, op);
90 } else if (const MulNode* ma = op->a.as<MulNode>()) {
91 return MakeFMA(ma->a, ma->b, op->b, op);
92 }
93 return IRMutatorWithAnalyzer::VisitExpr_(op);
94 }
95
96 // We use floordiv for integer analysis,
97 // but will need to lower them to native truncdiv instructions
98 PrimExpr VisitExpr_(const FloorDivNode* op) final {
99 auto e = GetRef<PrimExpr>(op);
100 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
101 op = ret.as<FloorDivNode>();
102 if (op == nullptr) return ret;
103 int shift;
104 const DataType& dtype = op->dtype;
105 ICHECK(dtype.is_int() || dtype.is_uint());
106
107 if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
108 // lower to right shift if possible.
109 return op->a >> make_const(dtype, shift);
110 }
111
112 if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
113 // Common path, positive divisor
114 if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
115 return truncdiv(op->a, op->b);
116 }
117
118 // If the numerator's lower bound is known, express the floordiv
119 // in terms of truncdiv using only positive operands.
120 arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
121 if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
122 const_int_bound->min_value < 0 &&
123 const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
124 // The goal is to write floordiv(a,b) in terms of truncdiv, without using
125 // negative operands.
126 //
127 // For any integer c
128 //
129 // floordiv(a,b) == floordiv(a + b*c - b*c, b)
130 // == floordiv(a + b*c, b) - c
131 //
132 // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
133 // truncdiv as follows.
134 //
135 // c == ceildiv(-a_min,b)
136 // == floordiv(-a_min + (b-1), b)
137 // == truncdiv(-a_min + (b-1), b)
138 //
139 // When substituted into `a + b*c`, this results in a positive argument.
140 //
141 // a + b*c
142 // == a + b*ceildiv(-a_min,b)
143 // == a - b*floordiv(a_min,b)
144 // >= a - b*floordiv(a,b)
145 // == floormod(a, b)
146 // >= 0
147 //
148 // Since the argument is positive, this allows floordiv to be written as
149 // followed.
150 //
151 // floordiv(a,b)
152 // == floordiv(a + b*c, b) - c
153 // == truncdiv(a + b*c, b) - c
154 IntImm min(op->a->dtype, const_int_bound->min_value);
155 PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
156 PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
157 return truncdiv(offset_numerator, op->b) - ceildiv;
158 }
159
160 DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
161 PrimExpr rdiv = truncdiv(op->a, op->b);
162 PrimExpr rmod = truncmod(op->a, op->b);
163 // condition on b >= 0.
164 // truncmod(a, b) < 0 will implies ceildiv,
165 // So we need to correct these cases.
166 if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
167 // equivalent to rdiv + (rmod >= 0 ? 0: -1);
168 return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
169 } else {
170 return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
171 }
172
173 } else {
174 if (dtype.is_float()) {
175 // floor(a / b)
176 return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
177 } else {
178 // uncommon case
179 DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
180 auto rmod = tir::Var("rmod", dtype);
181 auto rdiv = tir::Var("rdiv", dtype);
182 // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
183 // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
184 PrimExpr let_rdiv =
185 tir::Let(rdiv, truncdiv(op->a, op->b),
186 tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
187 rdiv - make_const(dtype, 1)));
188 return Let(rmod, truncmod(op->a, op->b), let_rdiv);
189 }
190 }
191 }
192
193 PrimExpr VisitExpr_(const FloorModNode* op) final {
194 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
195 op = ret.as<FloorModNode>();
196 if (op == nullptr) return ret;
197 // Lower floordiv to native truncdiv.
198 int shift;
199 const DataType& dtype = op->dtype;
200 ICHECK(dtype.is_int() || dtype.is_uint());
201
202 if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
203 // lower to masking if possible.
204 int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
205 return op->a & make_const(dtype, mask);
206 }
207
208 if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
209 // Common pass, positive divisor
210 if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
211 return truncmod(op->a, op->b);
212 }
213
214 // If the numerator's lower bound is known, express the floormod
215 // in terms of truncmod using only positive operands.
216 arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
217 if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
218 const_int_bound->min_value < 0 &&
219 const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
220 // The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
221 // without using negative operands.
222 //
223 // For any integer c
224 //
225 // floormod(a, b) == floormod(a + b*c, b)
226 //
227 // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
228 // truncdiv as follows.
229 //
230 // c == ceildiv(-a_min,b)
231 // == floordiv(-a_min + (b-1), b)
232 // == truncdiv(-a_min + (b-1), b)
233 //
234 // When substituted into `a + b*c`, this results in a positive argument.
235 //
236 // a + b*c
237 // == a + b*ceildiv(-a_min,b)
238 // == a - b*floordiv(a_min,b)
239 // >= a - b*floordiv(a,b)
240 // == floormod(a, b)
241 // >= 0
242 //
243 // Since the argument is positive, this allows floordiv to be written as
244 // followed.
245 //
246 // floormod(a,b)
247 // == floormod(a + b*c, b)
248 // == truncmod(a + b*c, b)
249 IntImm min(op->a->dtype, const_int_bound->min_value);
250 PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
251 PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
252 return truncmod(offset_numerator, op->b);
253 }
254
255 DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
256 // NOTE:condition on b >= 0.
257 // mod(a, b) < 0 will imply we are doing ceildiv,
258 // So we need to correct these cases.
259 PrimExpr rmod = truncmod(op->a, op->b);
260 if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
261 // (rmod >> shift) & b
262 // -> (rmod >= 0 ? 0: -1) & b
263 // -> rmod >= 0 ? 0 : b
264 return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
265 } else {
266 return tir::Select(rmod >= 0, rmod, rmod + op->b);
267 }
268
269 } else {
270 if (dtype.is_float()) {
271 // a - floor(a / b) * b
272 return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
273 } else {
274 // uncommon case
275 DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
276 auto rmod = tir::Var("rmod", dtype);
277 // b > 0 && rmod >= 0 -> rmod
278 // b > 0 && rmod < 0 -> rmod + b
279 // b < 0 && rmod < 0 -> rmod
280 // b < 0 && rmod > 0 -> rmod + b
281 return Let(
282 rmod, truncmod(op->a, op->b),
283 Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
284 }
285 }
286 }
287
288 PrimExpr VisitExpr_(const MaxNode* op) final {
289 using namespace arith;
290 PVar<PrimExpr> x, y;
291 PVar<IntImm> c;
292 auto e = GetRef<PrimExpr>(op);
293 if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
294 analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
295 return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
296 }
297 return IRMutatorWithAnalyzer::VisitExpr_(op);
298 }
299
300 PrimExpr VisitExpr_(const EQNode* op) final {
301 using namespace arith;
302 PVar<PrimExpr> x, y;
303 auto e = GetRef<PrimExpr>(op);
304 if ((floormod(x, y) == 0).Match(e)) {
305 return VisitExpr((truncmod(x, y) == 0).Eval());
306 }
307 return IRMutatorWithAnalyzer::VisitExpr_(op);
308 }
309
310 PrimExpr VisitExpr_(const NENode* op) final {
311 using namespace arith;
312 PVar<PrimExpr> x, y;
313 auto e = GetRef<PrimExpr>(op);
314 if ((floormod(x, y) != 0).Match(e)) {
315 return VisitExpr((truncmod(x, y) != 0).Eval());
316 }
317 return IRMutatorWithAnalyzer::VisitExpr_(op);
318 }
319
320 private:
321 PrimExpr SwapBroadcastCast(const PrimExpr& e) {
322 // Try to change broadcast(cast(x)) to cast(broadcast(x))
323 // For some targets, LLVM will generate more efficient FMA
324 // instruction with the latter. For example, vmla vs. vmlal
325 // on ARM.
326 if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
327 if (const CastNode* cast = bcast->value.as<CastNode>()) {
328 auto should_swap = [&]() {
329 // Maintain behaviour (int8 -> int16, fp16 -> fp32).
330 if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
331 return true;
332 }
333 // Check both operands are integer-like.
334 if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
335 return false;
336 }
337 if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
338 return false;
339 }
340 // If both are integer-like, swap if we have a widening cast.
341 return cast->dtype.bits() > cast->value.dtype().bits();
342 };
343
344 if (should_swap()) {
345 PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes);
346 return Cast(bcast->dtype, new_bcast);
347 }
348 }
349 }
350 return e;
351 }
352
353 PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
354 // emit fma instruction: a * b + c
355 PrimExpr lhs = SwapBroadcastCast(a);
356 PrimExpr rhs = SwapBroadcastCast(b);
357
358 if (fma_ != nullptr && op->dtype.is_float()) {
359 PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
360 if (r.defined()) return this->VisitExpr(r);
361 } else {
362 if (!lhs.same_as(a) || !rhs.same_as(b)) {
363 PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
364 return Add(mul, this->VisitExpr(c));
365 }
366 }
367 return IRMutatorWithAnalyzer::VisitExpr_(op);
368 }
369
// Attribute maps consulted, in order, when lowering a call; filled by the
// constructor from the "<target>.*", optional "<target>.aarch64.*" and
// "default.*" FLowerIntrinsic / FLegalize registrations that exist.
// attribute maps, shared only when FLegalize == FLowerIntrinsic
std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
// Lowering rule for "tir.fma": the first non-null entry found while the
// constructor scans the maps above; stays nullptr when none provides one.
FLowerGeneral fma_{nullptr};
// Gates the shift / mask based lowerings of floordiv and floormod. It is
// initialised to true and nothing visible in this class changes it.
bool support_bitwise_op_{true};
374};
375
376Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
377 arith::Analyzer analyzer;
378 return IntrinInjecter(&analyzer, target)(std::move(stmt));
379}
380
381namespace transform {
382
383Pass LowerIntrin() {
384 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
385 auto* n = f.CopyOnWrite();
386 auto target = f->GetAttr<Target>(tvm::attr::kTarget);
387 ICHECK(target.defined()) << "LowerIntrin: Require the target attribute";
388 arith::Analyzer analyzer;
389 auto mtriple = target.value()->GetAttr<runtime::String>("mtriple", "");
390 n->body =
391 IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body));
392 return f;
393 };
394 return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
395}
396
397TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin);
398
399} // namespace transform
400
401} // namespace tir
402} // namespace tvm
403