1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifdef TVM_LLVM_VERSION
21
22#include <llvm/IR/Intrinsics.h>
23#include <tvm/tir/op.h>
24#include <tvm/tir/op_attr_types.h>
25
26#include "intrin_rule_llvm.h"
27
// Registers the Hexagon lowering rule for the TIR op "tir.<OP_NAME>".
//
// The expansion does two things:
//   1. defines a std::string variable named tvm_qhl_ahf_<OP_NAME>, initialised with QHL_SYMBOL;
//   2. installs DispatchTVMQHLWrapperFp16<tvm_qhl_ahf_<OP_NAME>, ::llvm::Intrinsic::<OP_NAME>,
//      SIGNATURE_COUNT> as the op's "hexagon.FLowerIntrinsic" attribute.
//
// A named string object is needed (rather than passing the literal) because the dispatcher takes
// the wrapper name as a `std::string&` template argument.
// The expansion already ends with ';', so invocations are written without a trailing semicolon.
#define TVM_REGISTER_QHL_OP_FP16(OP_NAME, QHL_SYMBOL, SIGNATURE_COUNT) \
  std::string tvm_qhl_ahf_##OP_NAME = QHL_SYMBOL; \
  TVM_REGISTER_OP("tir." #OP_NAME) \
      .set_attr<FLowerIntrinsic>( \
          "hexagon.FLowerIntrinsic", \
          DispatchTVMQHLWrapperFp16<tvm_qhl_ahf_##OP_NAME, ::llvm::Intrinsic::OP_NAME, \
                                    SIGNATURE_COUNT>);
35
36namespace tvm {
37namespace codegen {
38namespace llvm {
39using tir::FLowerIntrinsic;
40
41inline PrimExpr TVMExternCall(const tir::CallNode* call, const std::string& fname) {
42 Array<PrimExpr> new_args = {tir::StringImm(fname)};
43 for (PrimExpr arg : call->args) {
44 new_args.push_back(arg);
45 }
46 return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args);
47}
48
49template <std::string& tvm_wrapper, unsigned id, int num_sign>
50inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {
51 using namespace tir;
52 const CallNode* call = e.as<CallNode>();
53 ICHECK(call != nullptr);
54 Array<PrimExpr> new_args;
55#if ENABLE_QHL
56 // Check target for qfloat enablement
57 const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
58 ICHECK(f != nullptr);
59 const auto ret = (*f)(true);
60 const Target t = ret.AsObjectRef<Target>();
61 bool useqhl = true;
62 if (t.defined()) {
63 const std::string tstring = t->str();
64 useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
65 }
66
67 // Enable QHL library for FP16 data type
68 const PrimExpr& x = call->args[0];
69 if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
70 return TVMExternCall(call, tvm_wrapper);
71 }
72#endif
73 new_args.push_back(IntImm(DataType::UInt(32), id));
74 new_args.push_back(IntImm(DataType::UInt(32), num_sign));
75 new_args.insert(new_args.end(), call->args.begin(), call->args.end());
76 return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args);
77}
78
79TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
80 "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
81
82TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
83 "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
84
85TVM_REGISTER_OP("tir.trunc")
86 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
87 DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
88
89TVM_REGISTER_OP("tir.fabs")
90 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
91 DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
92
93TVM_REGISTER_OP("tir.round")
94 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
95 DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
96
97TVM_REGISTER_OP("tir.ctpop")
98 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
99 DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
100TVM_REGISTER_OP("tir.tanh")
101 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
102 const tir::CallNode* call = e.as<tir::CallNode>();
103 ICHECK(call != nullptr);
104 const PrimExpr& x = call->args[0];
105
106#if ENABLE_QHL
107 // Check target for qfloat enablement
108 const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
109 ICHECK(f != nullptr);
110 const auto ret = (*f)(true);
111 const Target t = ret.AsObjectRef<Target>();
112 bool useqhl = true;
113 if (t.defined()) {
114 const std::string tstring = t->str();
115 useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
116 }
117
118 // Enable QHL library for FP16 data type
119 if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
120 std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
121 return TVMExternCall(call, tvm_wrapper);
122 }
123#endif
124 PrimExpr one = tir::make_const(x.dtype(), 1);
125 PrimExpr two = tir::make_const(x.dtype(), 2);
126 PrimExpr neg_two = tir::make_const(x.dtype(), -2);
127
128 PrimExpr exp_neg2x = exp(neg_two * x);
129 PrimExpr exp_pos2x = exp(two * x);
130
131 PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
132 PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
133 PrimExpr tanh_x = tir::Select(x >= tir::make_zero(x.dtype()), tanh_pos, tanh_neg);
134 return tanh_x;
135 });
136
137TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
138 "hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
139 const tir::CallNode* call = e.as<tir::CallNode>();
140 ICHECK(call != nullptr);
141 const PrimExpr& x = call->args[0];
142#if ENABLE_QHL
143 // Check target for qfloat enablement
144 const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
145 ICHECK(f != nullptr);
146 const auto ret = (*f)(true);
147 const Target t = ret.AsObjectRef<Target>();
148 bool useqhl = true;
149 if (t.defined()) {
150 const std::string tstring = t->str();
151 useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
152 }
153
154 // Enable QHL library for FP16 data type
155 if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
156 std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
157 return TVMExternCall(call, tvm_wrapper);
158 }
159#endif
160 PrimExpr tan_x = sin(x) / cos(x);
161 return tan_x;
162 });
163
164TVM_REGISTER_OP("tir.nearbyint")
165 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
166 DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
167
168TVM_REGISTER_OP("tir.sigmoid")
169 .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic", [](const PrimExpr& e) {
170 const tir::CallNode* call = e.as<tir::CallNode>();
171 ICHECK(call != nullptr);
172 const PrimExpr& x = call->args[0];
173#if ENABLE_QHL
174 // Check target for qfloat enablement
175 const auto* f = tvm::runtime::Registry::Get("target.TargetCurrent");
176 ICHECK(f != nullptr);
177 const auto ret = (*f)(true);
178 const Target t = ret.AsObjectRef<Target>();
179 bool useqhl = true;
180 if (t.defined()) {
181 const std::string tstring = t->str();
182 useqhl = tstring.find("+hvx-qfloat") != std::string::npos;
183 }
184
185 PrimExpr MinBound = tir::make_const(x.dtype(), -8);
186 PrimExpr MaxBound = tir::make_const(x.dtype(), 8);
187 const PrimExpr v1 = tir::Max(x, MinBound);
188 const PrimExpr v2 = tir::Min(v1, MaxBound);
189
190 Array<tvm::PrimExpr> new_args = {v2};
191 const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);
192
193 // Enable QHL library for FP16 data type
194 if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
195 std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
196 return TVMExternCall(new_call.get(), tvm_wrapper);
197 }
198#endif
199 PrimExpr one = tir::make_const(x.dtype(), 1);
200 return one / (one + exp(-x));
201 });
202
203TVM_REGISTER_QHL_OP_FP16(ceil, "tvm_vect_qhmath_hvx_ceil_ahf", 1)
204
205TVM_REGISTER_QHL_OP_FP16(cos, "tvm_vect_qhmath_hvx_cos_ahf", 1)
206
207TVM_REGISTER_QHL_OP_FP16(exp, "tvm_vect_qhmath_hvx_exp_ahf", 1)
208
209TVM_REGISTER_QHL_OP_FP16(floor, "tvm_vect_qhmath_hvx_floor_ahf", 1)
210
211TVM_REGISTER_QHL_OP_FP16(sin, "tvm_vect_qhmath_hvx_sin_ahf", 1)
212
213TVM_REGISTER_QHL_OP_FP16(pow, "tvm_vect_qhmath_hvx_pow_ahf", 2)
214
215TVM_REGISTER_QHL_OP_FP16(sqrt, "tvm_vect_qhmath_hvx_sqrt_ahf", 1)
216
217} // namespace llvm
218} // namespace codegen
219} // namespace tvm
220
221#endif // TVM_LLVM_VERSION
222