1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file intrin_rule_llvm.cc |
22 | */ |
23 | #ifdef TVM_LLVM_VERSION |
24 | |
25 | #include "intrin_rule_llvm.h" |
26 | |
27 | #include <llvm/IR/Intrinsics.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/op_attr_types.h> |
30 | |
31 | namespace tvm { |
32 | namespace codegen { |
33 | namespace llvm { |
34 | namespace intrin { |
35 | using tir::FLowerIntrinsic; |
36 | |
37 | TVM_REGISTER_OP("tir.prefetch" ) |
38 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
39 | DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); |
40 | |
41 | TVM_REGISTER_OP("tir.exp" ).set_attr<FLowerIntrinsic>( |
42 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); |
43 | |
44 | TVM_REGISTER_OP("tir.exp2" ) |
45 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
46 | DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); |
47 | |
48 | TVM_REGISTER_OP("tir.fma" ).set_attr<FLowerIntrinsic>( |
49 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); |
50 | |
51 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>( |
52 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); |
53 | |
54 | TVM_REGISTER_OP("tir.log2" ) |
55 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
56 | DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); |
57 | |
58 | TVM_REGISTER_OP("tir.log10" ) |
59 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
60 | DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); |
61 | |
62 | TVM_REGISTER_OP("tir.sqrt" ) |
63 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
64 | DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); |
65 | |
66 | TVM_REGISTER_OP("tir.floor" ) |
67 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
68 | DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); |
69 | |
70 | TVM_REGISTER_OP("tir.ceil" ) |
71 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
72 | DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); |
73 | |
74 | TVM_REGISTER_OP("tir.trunc" ) |
75 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
76 | DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); |
77 | |
78 | TVM_REGISTER_OP("tir.fabs" ) |
79 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
80 | DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); |
81 | |
82 | TVM_REGISTER_OP("tir.round" ) |
83 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
84 | DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); |
85 | |
86 | TVM_REGISTER_OP("tir.nearbyint" ) |
87 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
88 | DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); |
89 | |
90 | TVM_REGISTER_OP("tir.pow" ).set_attr<FLowerIntrinsic>( |
91 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); |
92 | |
93 | TVM_REGISTER_OP("tir.popcount" ) |
94 | .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic" , |
95 | DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); |
96 | |
97 | TVM_REGISTER_OP("tir.cos" ).set_attr<FLowerIntrinsic>( |
98 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); |
99 | |
100 | TVM_REGISTER_OP("tir.sin" ).set_attr<FLowerIntrinsic>( |
101 | "llvm.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); |
102 | } // namespace intrin |
103 | |
104 | namespace legalize { |
105 | using tir::FLegalize; |
106 | |
107 | TVM_REGISTER_OP("tir.exp10" ) |
108 | .set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
109 | using tir::make_const; |
110 | using tir::make_zero; |
111 | const tir::CallNode* call = e.as<tir::CallNode>(); |
112 | ICHECK(call != nullptr); |
113 | const PrimExpr& x = call->args[0]; |
114 | PrimExpr ln10 = make_const(x.dtype(), 2.302585093); |
115 | PrimExpr ret = exp(x * ln10); |
116 | return ret; |
117 | }); |
118 | |
119 | TVM_REGISTER_OP("tir.tanh" ) |
120 | .set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
121 | using tir::make_const; |
122 | using tir::make_zero; |
123 | const tir::CallNode* call = e.as<tir::CallNode>(); |
124 | ICHECK(call != nullptr); |
125 | const PrimExpr& x = call->args[0]; |
126 | PrimExpr one = make_const(x.dtype(), 1); |
127 | PrimExpr two = make_const(x.dtype(), 2); |
128 | PrimExpr neg_two = make_const(x.dtype(), -2); |
129 | |
130 | PrimExpr exp_neg2x = exp(neg_two * x); |
131 | PrimExpr exp_pos2x = exp(two * x); |
132 | |
133 | PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); |
134 | PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); |
135 | return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); |
136 | }); |
137 | |
138 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
139 | const tir::CallNode* call = e.as<tir::CallNode>(); |
140 | ICHECK(call != nullptr); |
141 | const PrimExpr& x = call->args[0]; |
142 | PrimExpr tan_x = sin(x) / cos(x); |
143 | return tan_x; |
144 | }); |
145 | |
146 | TVM_REGISTER_OP("tir.cosh" ) |
147 | .set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
148 | using tir::make_const; |
149 | using tir::make_zero; |
150 | const tir::CallNode* call = e.as<tir::CallNode>(); |
151 | ICHECK(call != nullptr); |
152 | const PrimExpr& x = call->args[0]; |
153 | PrimExpr two = make_const(x.dtype(), 2); |
154 | PrimExpr neg_one = make_const(x.dtype(), -1); |
155 | PrimExpr exp_negx = exp(neg_one * x); |
156 | PrimExpr exp_posx = exp(x); |
157 | PrimExpr ret = (exp_posx + exp_negx) / two; |
158 | return ret; |
159 | }); |
160 | |
161 | TVM_REGISTER_OP("tir.sinh" ) |
162 | .set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
163 | using tir::make_const; |
164 | using tir::make_zero; |
165 | const tir::CallNode* call = e.as<tir::CallNode>(); |
166 | ICHECK(call != nullptr); |
167 | const PrimExpr& x = call->args[0]; |
168 | PrimExpr two = make_const(x.dtype(), 2); |
169 | PrimExpr neg_one = make_const(x.dtype(), -1); |
170 | PrimExpr exp_negx = exp(neg_one * x); |
171 | PrimExpr exp_posx = exp(x); |
172 | PrimExpr ret = (exp_posx - exp_negx) / two; |
173 | return ret; |
174 | }); |
175 | |
176 | TVM_REGISTER_OP("tir.clz" ).set_attr<FLegalize>("llvm.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
177 | const tir::CallNode* call = e.as<tir::CallNode>(); |
178 | ICHECK(call != nullptr); |
179 | ICHECK_EQ(call->args.size(), 1); |
180 | Array<PrimExpr> cargs; |
181 | cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); |
182 | cargs.push_back(IntImm(DataType::UInt(32), 2)); |
183 | cargs.push_back(call->args[0]); |
184 | cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef |
185 | // LLVM requires that the return type must match the first argument type |
186 | auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); |
187 | return cast(call->dtype, clz); |
188 | }); |
189 | |
190 | } // namespace legalize |
191 | } // namespace llvm |
192 | } // namespace codegen |
193 | } // namespace tvm |
194 | |
#endif  // TVM_LLVM_VERSION
196 | |