1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file intrin_rule_default.cc
22 * \brief Default intrinsic rules.
23 */
24#include "intrin_rule.h"
25
26#include <tvm/tir/op.h>
27#include <tvm/tir/op_attr_types.h>
28
29namespace tvm {
30namespace codegen {
31namespace intrin {
32using tir::FLowerIntrinsic;
33
34TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
35 DispatchPureExtern<FloatSuffix>);
36
37TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
38 DispatchPureExtern<FloatSuffix>);
39
40TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
41 DispatchPureExtern<FloatSuffix>);
42
43TVM_REGISTER_OP("tir.log2")
44 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
45
46TVM_REGISTER_OP("tir.log10")
47 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
48
49TVM_REGISTER_OP("tir.log1p")
50 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
51
52TVM_REGISTER_OP("tir.tanh")
53 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
54
55TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
56 DispatchPureExtern<FloatSuffix>);
57
58TVM_REGISTER_OP("tir.atan")
59 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
60
61TVM_REGISTER_OP("tir.atanh")
62 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
63
64TVM_REGISTER_OP("tir.atan2")
65 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
66
67TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
68 DispatchPureExtern<FloatSuffix>);
69
70TVM_REGISTER_OP("tir.acos")
71 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
72
73TVM_REGISTER_OP("tir.cosh")
74 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
75
76TVM_REGISTER_OP("tir.acosh")
77 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
78
79TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
80 DispatchPureExtern<FloatSuffix>);
81
82TVM_REGISTER_OP("tir.asin")
83 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
84
85TVM_REGISTER_OP("tir.sinh")
86 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
87
88TVM_REGISTER_OP("tir.asinh")
89 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
90
91TVM_REGISTER_OP("tir.hypot")
92 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
93
94TVM_REGISTER_OP("tir.nextafter")
95 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
96
97TVM_REGISTER_OP("tir.copysign")
98 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
99
100TVM_REGISTER_OP("tir.ldexp")
101 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
102
103TVM_REGISTER_OP("tir.sqrt")
104 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
105
106TVM_REGISTER_OP("tir.floor")
107 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
108
109TVM_REGISTER_OP("tir.ceil")
110 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
111
112TVM_REGISTER_OP("tir.round")
113 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
114
115TVM_REGISTER_OP("tir.nearbyint")
116 .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
117
118TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
119 DispatchPureExtern<FloatSuffix>);
120
121} // namespace intrin
122
123namespace legalize {
124
125using namespace tir;
126
127TVM_REGISTER_OP("tir.rsqrt")
128 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
129 const CallNode* call = e.as<CallNode>();
130 ICHECK(call != nullptr);
131 auto one = make_const(call->args[0].dtype(), 1);
132 return one / sqrt(call->args[0]);
133 });
134
135TVM_REGISTER_OP("tir.sigmoid")
136 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
137 const CallNode* call = e.as<CallNode>();
138 ICHECK(call != nullptr);
139 auto one = make_const(call->args[0].dtype(), 1);
140 return one / (one + exp(-call->args[0]));
141 });
142
143TVM_REGISTER_OP("tir.isfinite")
144 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
145 const CallNode* call = e.as<CallNode>();
146 ICHECK(call != nullptr);
147 return isfinite(call->args[0]);
148 });
149
150TVM_REGISTER_OP("tir.isinf")
151 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
152 const CallNode* call = e.as<CallNode>();
153 ICHECK(call != nullptr);
154 return isinf(call->args[0]);
155 });
156
157/*!
158 * \brief Makes fixed point multiplication.
159 * \param x Input tensor.
160 * \param y Integer multiplier.
161 * \param left_shift Integer left shift.
162 * \param right_shift Integer right shift.
163 * \param is_left_shift_required Flag whether we need to do left shift or not.
164 * \return Calculated expression.
165 */
166static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift,
167 PrimExpr right_shift, PrimExpr is_left_shift_required) {
168 // Only int32 types are supported (any number of lanes is allowed)
169 ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
170 ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && left_shift.dtype().bits() == 32);
171 ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && right_shift.dtype().bits() == 32);
172
173 DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
174 DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
175
176 // 1) Cast and Multiply the integer multiplier
177 PrimExpr one = make_const(hp_dtype, 1);
178 x = cast(hp_dtype, x);
179 y = cast(hp_dtype, y);
180 x = tir::Select(is_left_shift_required, x << left_shift, x);
181
182 // 2) Perform the multiplication in higher precision.
183 x = x * y;
184
185 // 3) Find the rounding scalar
186 PrimExpr total_right_shift = right_shift + q;
187 PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
188 x = x + pos_rounding_value;
189
190 // 4) Simply right shift the result to get the final output.
191 x = x >> total_right_shift;
192
193 // 5) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
194 return cast(lp_dtype, x);
195}
196
197TVM_REGISTER_OP("tir.q_multiply_shift")
198 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
199 using tir::make_const;
200
201 const tir::CallNode* call = e.as<tir::CallNode>();
202 ICHECK(call != nullptr);
203
204 PrimExpr x = call->args[0];
205 PrimExpr y = call->args[1];
206 PrimExpr q = call->args[2];
207 PrimExpr s = call->args[3];
208
209 // Lambda function to extract the int value from PrimExpr
210 auto get_int_value = [](const PrimExpr node) {
211 if (auto int_node = node.as<IntImmNode>()) {
212 return int_node->value;
213 }
214 auto broadcast_node = node.as<BroadcastNode>();
215 CHECK(broadcast_node != nullptr);
216 auto int_node = broadcast_node->value.as<IntImmNode>();
217 CHECK(int_node != nullptr);
218 return int_node->value;
219 };
220 // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of
221 // 2, fixed point multiplier will represent a float value of 0.5. In fixed point, this is
222 // represented by 1 << 30.
223 if (get_int_value(y) == (1 << 30)) {
224 PrimExpr exp = s - 1;
225 int exp_val = get_int_value(s) - 1;
226 if (exp_val > 0) {
227 // power of 2 is greater than 0, apply left shift.
228 return x << exp;
229 } else {
230 // power of 2 is less than 0, round and then apply right shift.
231 DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
232 PrimExpr one = make_const(lp_dtype, 1);
233 exp = -exp;
234 PrimExpr rounding_factor = one << (exp - 1);
235 PrimExpr rounded_t = x + rounding_factor;
236 return rounded_t >> exp;
237 }
238 } else {
239 // Only int32 types are supported (any number of lanes is allowed)
240 ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
241
242 // Calculating integer shifts
243 PrimExpr zero = make_const(s.dtype(), 0);
244 PrimExpr left_shift = tir::Select(s > zero, s, zero);
245 PrimExpr right_shift = tir::Select(s > zero, zero, -s);
246 PrimExpr is_left_shift_required = (left_shift != zero);
247
248 return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required);
249 }
250 });
251
252TVM_REGISTER_OP("tir.q_multiply_shift_per_axis")
253 .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
254 const tir::CallNode* call = e.as<tir::CallNode>();
255 ICHECK(call != nullptr);
256
257 PrimExpr x = call->args[0];
258 PrimExpr y = call->args[1];
259 PrimExpr left_shift = call->args[2];
260 PrimExpr right_shift = call->args[3];
261 PrimExpr q = call->args[4];
262 PrimExpr is_lshift_required = call->args[5];
263 // Note, 7th argument is "is_rshift_required" flag, but we don't need that here.
264 // PrimExpr is_rshift_required = call->args[6];
265
266 return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required);
267 });
268} // namespace legalize
269} // namespace codegen
270} // namespace tvm
271