1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
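/*!
 * \file intrin_rule_llvm.cc
 * \brief LLVM rules for TIR intrinsics: direct lowerings to LLVM intrinsics
 *        ("llvm.FLowerIntrinsic") and legalizations into simpler TIR ("llvm.FLegalize").
 */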
23#ifdef TVM_LLVM_VERSION
24
25#include "intrin_rule_llvm.h"
26
27#include <llvm/IR/Intrinsics.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/op_attr_types.h>
30
31namespace tvm {
32namespace codegen {
33namespace llvm {
34namespace intrin {
35using tir::FLowerIntrinsic;
36
37TVM_REGISTER_OP("tir.prefetch")
38 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
39 DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
40
41TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
42 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
43
44TVM_REGISTER_OP("tir.exp2")
45 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
46 DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
47
48TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
49 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
50
51TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
52 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
53
54TVM_REGISTER_OP("tir.log2")
55 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
56 DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
57
58TVM_REGISTER_OP("tir.log10")
59 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
60 DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
61
62TVM_REGISTER_OP("tir.sqrt")
63 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
64 DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
65
66TVM_REGISTER_OP("tir.floor")
67 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
68 DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
69
70TVM_REGISTER_OP("tir.ceil")
71 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
72 DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
73
74TVM_REGISTER_OP("tir.trunc")
75 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
76 DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
77
78TVM_REGISTER_OP("tir.fabs")
79 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
80 DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
81
82TVM_REGISTER_OP("tir.round")
83 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
84 DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
85
86TVM_REGISTER_OP("tir.nearbyint")
87 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
88 DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
89
90TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
91 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
92
93TVM_REGISTER_OP("tir.popcount")
94 .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
95 DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
96
97TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
98 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
99
100TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
101 "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
102} // namespace intrin
103
104namespace legalize {
105using tir::FLegalize;
106
107TVM_REGISTER_OP("tir.exp10")
108 .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
109 using tir::make_const;
110 using tir::make_zero;
111 const tir::CallNode* call = e.as<tir::CallNode>();
112 ICHECK(call != nullptr);
113 const PrimExpr& x = call->args[0];
114 PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
115 PrimExpr ret = exp(x * ln10);
116 return ret;
117 });
118
119TVM_REGISTER_OP("tir.tanh")
120 .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
121 using tir::make_const;
122 using tir::make_zero;
123 const tir::CallNode* call = e.as<tir::CallNode>();
124 ICHECK(call != nullptr);
125 const PrimExpr& x = call->args[0];
126 PrimExpr one = make_const(x.dtype(), 1);
127 PrimExpr two = make_const(x.dtype(), 2);
128 PrimExpr neg_two = make_const(x.dtype(), -2);
129
130 PrimExpr exp_neg2x = exp(neg_two * x);
131 PrimExpr exp_pos2x = exp(two * x);
132
133 PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
134 PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
135 return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
136 });
137
138TVM_REGISTER_OP("tir.tan").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
139 const tir::CallNode* call = e.as<tir::CallNode>();
140 ICHECK(call != nullptr);
141 const PrimExpr& x = call->args[0];
142 PrimExpr tan_x = sin(x) / cos(x);
143 return tan_x;
144});
145
146TVM_REGISTER_OP("tir.cosh")
147 .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
148 using tir::make_const;
149 using tir::make_zero;
150 const tir::CallNode* call = e.as<tir::CallNode>();
151 ICHECK(call != nullptr);
152 const PrimExpr& x = call->args[0];
153 PrimExpr two = make_const(x.dtype(), 2);
154 PrimExpr neg_one = make_const(x.dtype(), -1);
155 PrimExpr exp_negx = exp(neg_one * x);
156 PrimExpr exp_posx = exp(x);
157 PrimExpr ret = (exp_posx + exp_negx) / two;
158 return ret;
159 });
160
161TVM_REGISTER_OP("tir.sinh")
162 .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
163 using tir::make_const;
164 using tir::make_zero;
165 const tir::CallNode* call = e.as<tir::CallNode>();
166 ICHECK(call != nullptr);
167 const PrimExpr& x = call->args[0];
168 PrimExpr two = make_const(x.dtype(), 2);
169 PrimExpr neg_one = make_const(x.dtype(), -1);
170 PrimExpr exp_negx = exp(neg_one * x);
171 PrimExpr exp_posx = exp(x);
172 PrimExpr ret = (exp_posx - exp_negx) / two;
173 return ret;
174 });
175
176TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
177 const tir::CallNode* call = e.as<tir::CallNode>();
178 ICHECK(call != nullptr);
179 ICHECK_EQ(call->args.size(), 1);
180 Array<PrimExpr> cargs;
181 cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
182 cargs.push_back(IntImm(DataType::UInt(32), 2));
183 cargs.push_back(call->args[0]);
184 cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
185 // LLVM requires that the return type must match the first argument type
186 auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
187 return cast(call->dtype, clz);
188});
189
190} // namespace legalize
191} // namespace llvm
192} // namespace codegen
193} // namespace tvm
194
#endif  // TVM_LLVM_VERSION
196