1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Lower intrinsic calls and ops to device specific ir when possible. |
22 | * \file lower_intrin.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/target/target.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include <limits> |
31 | #include <unordered_set> |
32 | |
33 | #include "../../arith/ir_mutator_with_analyzer.h" |
34 | #include "../../arith/pattern_match.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { |
40 | public: |
41 | using IRMutatorWithAnalyzer::VisitExpr_; |
42 | using IRMutatorWithAnalyzer::VisitStmt_; |
43 | using FLowerGeneral = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>; |
44 | |
45 | IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "" ) |
46 | : IRMutatorWithAnalyzer(analyzer) { |
47 | std::vector<std::string> patterns; |
48 | patterns.push_back(target + ".FLowerIntrinsic" ); |
49 | patterns.push_back(target + ".FLegalize" ); |
50 | bool is_llvm_aarch64 = (mtriple.find("aarch64" ) != std::string::npos); |
51 | if (is_llvm_aarch64) { |
52 | patterns.push_back(target + ".aarch64.FLowerIntrinsic" ); |
53 | patterns.push_back(target + ".aarch64.FLegalize" ); |
54 | } |
55 | patterns.push_back("default.FLowerIntrinsic" ); |
56 | patterns.push_back("default.FLegalize" ); |
57 | |
58 | for (const std::string& pattern : patterns) |
59 | if (Op::HasAttrMap(pattern)) { |
60 | attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern)); |
61 | if (fma_ == nullptr) { |
62 | fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma" ), nullptr); |
63 | } |
64 | } |
65 | } |
66 | |
67 | PrimExpr VisitExpr_(const CallNode* op) final { |
68 | if (auto* ptr_op = op->op.as<OpNode>()) { |
69 | for (const auto& f_attr_map : attr_maps_) { |
70 | FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr); |
71 | if (f != nullptr) { |
72 | PrimExpr e = GetRef<PrimExpr>(op); |
73 | PrimExpr r = f(e); |
74 | ICHECK(r.defined()) << "intrinsic rule must always return valid Expr" ; |
75 | if (!r.same_as(e)) { |
76 | r = this->VisitExpr(r); |
77 | if (r.defined()) { |
78 | return r; |
79 | } |
80 | } |
81 | } |
82 | } |
83 | } |
84 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
85 | } |
86 | |
87 | PrimExpr VisitExpr_(const AddNode* op) final { |
88 | if (const MulNode* mb = op->b.as<MulNode>()) { |
89 | return MakeFMA(mb->a, mb->b, op->a, op); |
90 | } else if (const MulNode* ma = op->a.as<MulNode>()) { |
91 | return MakeFMA(ma->a, ma->b, op->b, op); |
92 | } |
93 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
94 | } |
95 | |
96 | // We use floordiv for integer analysis, |
97 | // but will need to lower them to native truncdiv instructions |
98 | PrimExpr VisitExpr_(const FloorDivNode* op) final { |
99 | auto e = GetRef<PrimExpr>(op); |
100 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
101 | op = ret.as<FloorDivNode>(); |
102 | if (op == nullptr) return ret; |
103 | int shift; |
104 | const DataType& dtype = op->dtype; |
105 | ICHECK(dtype.is_int() || dtype.is_uint()); |
106 | |
107 | if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { |
108 | // lower to right shift if possible. |
109 | return op->a >> make_const(dtype, shift); |
110 | } |
111 | |
112 | if (analyzer_->CanProveGreaterEqual(op->b, 0)) { |
113 | // Common path, positive divisor |
114 | if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { |
115 | return truncdiv(op->a, op->b); |
116 | } |
117 | |
118 | // If the numerator's lower bound is known, express the floordiv |
119 | // in terms of truncdiv using only positive operands. |
120 | arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); |
121 | if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && |
122 | const_int_bound->min_value < 0 && |
123 | const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) { |
124 | // The goal is to write floordiv(a,b) in terms of truncdiv, without using |
125 | // negative operands. |
126 | // |
127 | // For any integer c |
128 | // |
129 | // floordiv(a,b) == floordiv(a + b*c - b*c, b) |
130 | // == floordiv(a + b*c, b) - c |
131 | // |
132 | // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of |
133 | // truncdiv as follows. |
134 | // |
135 | // c == ceildiv(-a_min,b) |
136 | // == floordiv(-a_min + (b-1), b) |
137 | // == truncdiv(-a_min + (b-1), b) |
138 | // |
139 | // When substituted into `a + b*c`, this results in a positive argument. |
140 | // |
141 | // a + b*c |
142 | // == a + b*ceildiv(-a_min,b) |
143 | // == a - b*floordiv(a_min,b) |
144 | // >= a - b*floordiv(a,b) |
145 | // == floormod(a, b) |
146 | // >= 0 |
147 | // |
148 | // Since the argument is positive, this allows floordiv to be written as |
149 | // followed. |
150 | // |
151 | // floordiv(a,b) |
152 | // == floordiv(a + b*c, b) - c |
153 | // == truncdiv(a + b*c, b) - c |
154 | IntImm min(op->a->dtype, const_int_bound->min_value); |
155 | PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); |
156 | PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); |
157 | return truncdiv(offset_numerator, op->b) - ceildiv; |
158 | } |
159 | |
160 | DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident" ; |
161 | PrimExpr rdiv = truncdiv(op->a, op->b); |
162 | PrimExpr rmod = truncmod(op->a, op->b); |
163 | // condition on b >= 0. |
164 | // truncmod(a, b) < 0 will implies ceildiv, |
165 | // So we need to correct these cases. |
166 | if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { |
167 | // equivalent to rdiv + (rmod >= 0 ? 0: -1); |
168 | return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); |
169 | } else { |
170 | return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); |
171 | } |
172 | |
173 | } else { |
174 | if (dtype.is_float()) { |
175 | // floor(a / b) |
176 | return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()); |
177 | } else { |
178 | // uncommon case |
179 | DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor" ; |
180 | auto rmod = tir::Var("rmod" , dtype); |
181 | auto rdiv = tir::Var("rdiv" , dtype); |
182 | // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) |
183 | // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) |
184 | PrimExpr let_rdiv = |
185 | tir::Let(rdiv, truncdiv(op->a, op->b), |
186 | tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, |
187 | rdiv - make_const(dtype, 1))); |
188 | return Let(rmod, truncmod(op->a, op->b), let_rdiv); |
189 | } |
190 | } |
191 | } |
192 | |
193 | PrimExpr VisitExpr_(const FloorModNode* op) final { |
194 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
195 | op = ret.as<FloorModNode>(); |
196 | if (op == nullptr) return ret; |
197 | // Lower floordiv to native truncdiv. |
198 | int shift; |
199 | const DataType& dtype = op->dtype; |
200 | ICHECK(dtype.is_int() || dtype.is_uint()); |
201 | |
202 | if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { |
203 | // lower to masking if possible. |
204 | int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1; |
205 | return op->a & make_const(dtype, mask); |
206 | } |
207 | |
208 | if (analyzer_->CanProveGreaterEqual(op->b, 0)) { |
209 | // Common pass, positive divisor |
210 | if (analyzer_->CanProveGreaterEqual(op->a, 0)) { |
211 | return truncmod(op->a, op->b); |
212 | } |
213 | |
214 | // If the numerator's lower bound is known, express the floormod |
215 | // in terms of truncmod using only positive operands. |
216 | arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); |
217 | if (const_int_bound->min_value != arith::ConstIntBound::kNegInf && |
218 | const_int_bound->min_value < 0 && |
219 | const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) { |
220 | // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, |
221 | // without using negative operands. |
222 | // |
223 | // For any integer c |
224 | // |
225 | // floormod(a, b) == floormod(a + b*c, b) |
226 | // |
227 | // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of |
228 | // truncdiv as follows. |
229 | // |
230 | // c == ceildiv(-a_min,b) |
231 | // == floordiv(-a_min + (b-1), b) |
232 | // == truncdiv(-a_min + (b-1), b) |
233 | // |
234 | // When substituted into `a + b*c`, this results in a positive argument. |
235 | // |
236 | // a + b*c |
237 | // == a + b*ceildiv(-a_min,b) |
238 | // == a - b*floordiv(a_min,b) |
239 | // >= a - b*floordiv(a,b) |
240 | // == floormod(a, b) |
241 | // >= 0 |
242 | // |
243 | // Since the argument is positive, this allows floordiv to be written as |
244 | // followed. |
245 | // |
246 | // floormod(a,b) |
247 | // == floormod(a + b*c, b) |
248 | // == truncmod(a + b*c, b) |
249 | IntImm min(op->a->dtype, const_int_bound->min_value); |
250 | PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); |
251 | PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); |
252 | return truncmod(offset_numerator, op->b); |
253 | } |
254 | |
255 | DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident" ; |
256 | // NOTE:condition on b >= 0. |
257 | // mod(a, b) < 0 will imply we are doing ceildiv, |
258 | // So we need to correct these cases. |
259 | PrimExpr rmod = truncmod(op->a, op->b); |
260 | if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { |
261 | // (rmod >> shift) & b |
262 | // -> (rmod >= 0 ? 0: -1) & b |
263 | // -> rmod >= 0 ? 0 : b |
264 | return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); |
265 | } else { |
266 | return tir::Select(rmod >= 0, rmod, rmod + op->b); |
267 | } |
268 | |
269 | } else { |
270 | if (dtype.is_float()) { |
271 | // a - floor(a / b) * b |
272 | return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b); |
273 | } else { |
274 | // uncommon case |
275 | DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident" ; |
276 | auto rmod = tir::Var("rmod" , dtype); |
277 | // b > 0 && rmod >= 0 -> rmod |
278 | // b > 0 && rmod < 0 -> rmod + b |
279 | // b < 0 && rmod < 0 -> rmod |
280 | // b < 0 && rmod > 0 -> rmod + b |
281 | return Let( |
282 | rmod, truncmod(op->a, op->b), |
283 | Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b)); |
284 | } |
285 | } |
286 | } |
287 | |
288 | PrimExpr VisitExpr_(const MaxNode* op) final { |
289 | using namespace arith; |
290 | PVar<PrimExpr> x, y; |
291 | PVar<IntImm> c; |
292 | auto e = GetRef<PrimExpr>(op); |
293 | if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && |
294 | analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { |
295 | return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); |
296 | } |
297 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
298 | } |
299 | |
300 | PrimExpr VisitExpr_(const EQNode* op) final { |
301 | using namespace arith; |
302 | PVar<PrimExpr> x, y; |
303 | auto e = GetRef<PrimExpr>(op); |
304 | if ((floormod(x, y) == 0).Match(e)) { |
305 | return VisitExpr((truncmod(x, y) == 0).Eval()); |
306 | } |
307 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
308 | } |
309 | |
310 | PrimExpr VisitExpr_(const NENode* op) final { |
311 | using namespace arith; |
312 | PVar<PrimExpr> x, y; |
313 | auto e = GetRef<PrimExpr>(op); |
314 | if ((floormod(x, y) != 0).Match(e)) { |
315 | return VisitExpr((truncmod(x, y) != 0).Eval()); |
316 | } |
317 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
318 | } |
319 | |
320 | private: |
321 | PrimExpr SwapBroadcastCast(const PrimExpr& e) { |
322 | // Try to change broadcast(cast(x)) to cast(broadcast(x)) |
323 | // For some targets, LLVM will generate more efficient FMA |
324 | // instruction with the latter. For example, vmla vs. vmlal |
325 | // on ARM. |
326 | if (const BroadcastNode* bcast = e.as<BroadcastNode>()) { |
327 | if (const CastNode* cast = bcast->value.as<CastNode>()) { |
328 | auto should_swap = [&]() { |
329 | // Maintain behaviour (int8 -> int16, fp16 -> fp32). |
330 | if (cast->dtype.bits() == cast->value.dtype().bits() * 2) { |
331 | return true; |
332 | } |
333 | // Check both operands are integer-like. |
334 | if (!cast->dtype.is_uint() && !cast->dtype.is_int()) { |
335 | return false; |
336 | } |
337 | if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) { |
338 | return false; |
339 | } |
340 | // If both are integer-like, swap if we have a widening cast. |
341 | return cast->dtype.bits() > cast->value.dtype().bits(); |
342 | }; |
343 | |
344 | if (should_swap()) { |
345 | PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes); |
346 | return Cast(bcast->dtype, new_bcast); |
347 | } |
348 | } |
349 | } |
350 | return e; |
351 | } |
352 | |
353 | PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) { |
354 | // emit fma instruction: a * b + c |
355 | PrimExpr lhs = SwapBroadcastCast(a); |
356 | PrimExpr rhs = SwapBroadcastCast(b); |
357 | |
358 | if (fma_ != nullptr && op->dtype.is_float()) { |
359 | PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); |
360 | if (r.defined()) return this->VisitExpr(r); |
361 | } else { |
362 | if (!lhs.same_as(a) || !rhs.same_as(b)) { |
363 | PrimExpr mul = this->VisitExpr(Mul(lhs, rhs)); |
364 | return Add(mul, this->VisitExpr(c)); |
365 | } |
366 | } |
367 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
368 | } |
369 | |
370 | // attribute maps, shared only when FLegalize == FLowerIntrinsic |
371 | std::vector<OpAttrMap<FLowerGeneral>> attr_maps_; |
372 | FLowerGeneral fma_{nullptr}; |
373 | bool support_bitwise_op_{true}; |
374 | }; |
375 | |
376 | Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { |
377 | arith::Analyzer analyzer; |
378 | return IntrinInjecter(&analyzer, target)(std::move(stmt)); |
379 | } |
380 | |
381 | namespace transform { |
382 | |
383 | Pass LowerIntrin() { |
384 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
385 | auto* n = f.CopyOnWrite(); |
386 | auto target = f->GetAttr<Target>(tvm::attr::kTarget); |
387 | ICHECK(target.defined()) << "LowerIntrin: Require the target attribute" ; |
388 | arith::Analyzer analyzer; |
389 | auto mtriple = target.value()->GetAttr<runtime::String>("mtriple" , "" ); |
390 | n->body = |
391 | IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); |
392 | return f; |
393 | }; |
394 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin" , {}); |
395 | } |
396 | |
397 | TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin" ).set_body_typed(LowerIntrin); |
398 | |
399 | } // namespace transform |
400 | |
401 | } // namespace tir |
402 | } // namespace tvm |
403 | |