1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifdef TVM_LLVM_VERSION |
21 | |
22 | #include <llvm/IR/Intrinsics.h> |
23 | #include <tvm/tir/op.h> |
24 | #include <tvm/tir/op_attr_types.h> |
25 | |
26 | #include "intrin_rule_llvm.h" |
27 | |
// Registers a Hexagon lowering ("hexagon.FLowerIntrinsic") for the TIR op "tir.<INTRIN_FUNC>".
//
// The expansion:
//   * defines a std::string global `tvm_qhl_ahf_<INTRIN_FUNC>` initialised to WRAPPER_FUNC.
//     It is passed as the `std::string&` template argument of DispatchTVMQHLWrapperFp16, which
//     is why it is a named (non-const) object rather than a literal;
//   * registers DispatchTVMQHLWrapperFp16<that string, ::llvm::Intrinsic::INTRIN_FUNC, NUM_SIGN>
//     as the lowering function.
//
// INTRIN_FUNC must therefore name both a TIR op (e.g. `exp` -> "tir.exp") and an
// `::llvm::Intrinsic` enumerator. NUM_SIGN is forwarded as the `num_sign` value that the
// dispatcher emits as the second argument of call_llvm_pure_intrin (presumably the number of
// overloaded argument types of the LLVM intrinsic -- confirm against codegen_llvm).
// The expansion already ends with ';', so call sites need no trailing semicolon. It must be
// expanded where FLowerIntrinsic and DispatchTVMQHLWrapperFp16 are visible.
// NOTE(review): the generated global has external linkage and is mutable; consider `static`
// (allowed for reference template arguments since C++11) to avoid exporting the symbol.
#define TVM_REGISTER_QHL_OP_FP16(INTRIN_FUNC, WRAPPER_FUNC, NUM_SIGN)                        \
  std::string tvm_qhl_ahf_##INTRIN_FUNC = WRAPPER_FUNC;                                       \
  TVM_REGISTER_OP("tir." #INTRIN_FUNC)                                                        \
      .set_attr<FLowerIntrinsic>(                                                             \
          "hexagon.FLowerIntrinsic",                                                          \
          DispatchTVMQHLWrapperFp16<tvm_qhl_ahf_##INTRIN_FUNC, ::llvm::Intrinsic::INTRIN_FUNC, \
                                    NUM_SIGN>);
35 | |
36 | namespace tvm { |
37 | namespace codegen { |
38 | namespace llvm { |
39 | using tir::FLowerIntrinsic; |
40 | |
41 | inline PrimExpr TVMExternCall(const tir::CallNode* call, const std::string& fname) { |
42 | Array<PrimExpr> new_args = {tir::StringImm(fname)}; |
43 | for (PrimExpr arg : call->args) { |
44 | new_args.push_back(arg); |
45 | } |
46 | return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args); |
47 | } |
48 | |
49 | template <std::string& tvm_wrapper, unsigned id, int num_sign> |
50 | inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { |
51 | using namespace tir; |
52 | const CallNode* call = e.as<CallNode>(); |
53 | ICHECK(call != nullptr); |
54 | Array<PrimExpr> new_args; |
55 | #if ENABLE_QHL |
56 | // Check target for qfloat enablement |
57 | const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent" ); |
58 | ICHECK(f != nullptr); |
59 | const auto ret = (*f)(true); |
60 | const Target t = ret.AsObjectRef<Target>(); |
61 | bool useqhl = true; |
62 | if (t.defined()) { |
63 | const std::string tstring = t->str(); |
64 | useqhl = tstring.find("+hvx-qfloat" ) != std::string::npos; |
65 | } |
66 | |
67 | // Enable QHL library for FP16 data type |
68 | const PrimExpr& x = call->args[0]; |
69 | if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { |
70 | return TVMExternCall(call, tvm_wrapper); |
71 | } |
72 | #endif |
73 | new_args.push_back(IntImm(DataType::UInt(32), id)); |
74 | new_args.push_back(IntImm(DataType::UInt(32), num_sign)); |
75 | new_args.insert(new_args.end(), call->args.begin(), call->args.end()); |
76 | return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args); |
77 | } |
78 | |
79 | TVM_REGISTER_OP("tir.fma" ).set_attr<FLowerIntrinsic>( |
80 | "hexagon.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); |
81 | |
82 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>( |
83 | "hexagon.FLowerIntrinsic" , DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); |
84 | |
85 | TVM_REGISTER_OP("tir.trunc" ) |
86 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , |
87 | DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); |
88 | |
89 | TVM_REGISTER_OP("tir.fabs" ) |
90 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , |
91 | DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); |
92 | |
93 | TVM_REGISTER_OP("tir.round" ) |
94 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , |
95 | DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); |
96 | |
97 | TVM_REGISTER_OP("tir.ctpop" ) |
98 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , |
99 | DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); |
100 | TVM_REGISTER_OP("tir.tanh" ) |
101 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , [](const PrimExpr& e) { |
102 | const tir::CallNode* call = e.as<tir::CallNode>(); |
103 | ICHECK(call != nullptr); |
104 | const PrimExpr& x = call->args[0]; |
105 | |
106 | #if ENABLE_QHL |
107 | // Check target for qfloat enablement |
108 | const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent" ); |
109 | ICHECK(f != nullptr); |
110 | const auto ret = (*f)(true); |
111 | const Target t = ret.AsObjectRef<Target>(); |
112 | bool useqhl = true; |
113 | if (t.defined()) { |
114 | const std::string tstring = t->str(); |
115 | useqhl = tstring.find("+hvx-qfloat" ) != std::string::npos; |
116 | } |
117 | |
118 | // Enable QHL library for FP16 data type |
119 | if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { |
120 | std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf" ); |
121 | return TVMExternCall(call, tvm_wrapper); |
122 | } |
123 | #endif |
124 | PrimExpr one = tir::make_const(x.dtype(), 1); |
125 | PrimExpr two = tir::make_const(x.dtype(), 2); |
126 | PrimExpr neg_two = tir::make_const(x.dtype(), -2); |
127 | |
128 | PrimExpr exp_neg2x = exp(neg_two * x); |
129 | PrimExpr exp_pos2x = exp(two * x); |
130 | |
131 | PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); |
132 | PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); |
133 | PrimExpr tanh_x = tir::Select(x >= tir::make_zero(x.dtype()), tanh_pos, tanh_neg); |
134 | return tanh_x; |
135 | }); |
136 | |
137 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLowerIntrinsic>( |
138 | "hexagon.FLowerIntrinsic" , [](const PrimExpr& e) { |
139 | const tir::CallNode* call = e.as<tir::CallNode>(); |
140 | ICHECK(call != nullptr); |
141 | const PrimExpr& x = call->args[0]; |
142 | #if ENABLE_QHL |
143 | // Check target for qfloat enablement |
144 | const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent" ); |
145 | ICHECK(f != nullptr); |
146 | const auto ret = (*f)(true); |
147 | const Target t = ret.AsObjectRef<Target>(); |
148 | bool useqhl = true; |
149 | if (t.defined()) { |
150 | const std::string tstring = t->str(); |
151 | useqhl = tstring.find("+hvx-qfloat" ) != std::string::npos; |
152 | } |
153 | |
154 | // Enable QHL library for FP16 data type |
155 | if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { |
156 | std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf" ); |
157 | return TVMExternCall(call, tvm_wrapper); |
158 | } |
159 | #endif |
160 | PrimExpr tan_x = sin(x) / cos(x); |
161 | return tan_x; |
162 | }); |
163 | |
164 | TVM_REGISTER_OP("tir.nearbyint" ) |
165 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , |
166 | DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); |
167 | |
168 | TVM_REGISTER_OP("tir.sigmoid" ) |
169 | .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic" , [](const PrimExpr& e) { |
170 | const tir::CallNode* call = e.as<tir::CallNode>(); |
171 | ICHECK(call != nullptr); |
172 | const PrimExpr& x = call->args[0]; |
173 | #if ENABLE_QHL |
174 | // Check target for qfloat enablement |
175 | const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent" ); |
176 | ICHECK(f != nullptr); |
177 | const auto ret = (*f)(true); |
178 | const Target t = ret.AsObjectRef<Target>(); |
179 | bool useqhl = true; |
180 | if (t.defined()) { |
181 | const std::string tstring = t->str(); |
182 | useqhl = tstring.find("+hvx-qfloat" ) != std::string::npos; |
183 | } |
184 | |
185 | PrimExpr MinBound = tir::make_const(x.dtype(), -8); |
186 | PrimExpr MaxBound = tir::make_const(x.dtype(), 8); |
187 | const PrimExpr v1 = tir::Max(x, MinBound); |
188 | const PrimExpr v2 = tir::Min(v1, MaxBound); |
189 | |
190 | Array<tvm::PrimExpr> new_args = {v2}; |
191 | const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); |
192 | |
193 | // Enable QHL library for FP16 data type |
194 | if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { |
195 | std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf" ); |
196 | return TVMExternCall(new_call.get(), tvm_wrapper); |
197 | } |
198 | #endif |
199 | PrimExpr one = tir::make_const(x.dtype(), 1); |
200 | return one / (one + exp(-x)); |
201 | }); |
202 | |
203 | TVM_REGISTER_QHL_OP_FP16(ceil, "tvm_vect_qhmath_hvx_ceil_ahf" , 1) |
204 | |
205 | TVM_REGISTER_QHL_OP_FP16(cos, "tvm_vect_qhmath_hvx_cos_ahf" , 1) |
206 | |
207 | TVM_REGISTER_QHL_OP_FP16(exp, "tvm_vect_qhmath_hvx_exp_ahf" , 1) |
208 | |
209 | TVM_REGISTER_QHL_OP_FP16(floor, "tvm_vect_qhmath_hvx_floor_ahf" , 1) |
210 | |
211 | TVM_REGISTER_QHL_OP_FP16(sin, "tvm_vect_qhmath_hvx_sin_ahf" , 1) |
212 | |
213 | TVM_REGISTER_QHL_OP_FP16(pow, "tvm_vect_qhmath_hvx_pow_ahf" , 2) |
214 | |
215 | TVM_REGISTER_QHL_OP_FP16(sqrt, "tvm_vect_qhmath_hvx_sqrt_ahf" , 1) |
216 | |
217 | } // namespace llvm |
218 | } // namespace codegen |
219 | } // namespace tvm |
220 | |
221 | #endif // TVM_LLVM_VERSION |
222 | |