1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file intrin_rule_default.cc |
22 | * \brief Default intrinsic rules. |
23 | */ |
24 | #include "intrin_rule.h" |
25 | |
26 | #include <tvm/tir/op.h> |
27 | #include <tvm/tir/op_attr_types.h> |
28 | |
29 | namespace tvm { |
30 | namespace codegen { |
31 | namespace intrin { |
32 | using tir::FLowerIntrinsic; |
33 | |
34 | TVM_REGISTER_OP("tir.exp" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
35 | DispatchPureExtern<FloatSuffix>); |
36 | |
37 | TVM_REGISTER_OP("tir.erf" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
38 | DispatchPureExtern<FloatSuffix>); |
39 | |
40 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
41 | DispatchPureExtern<FloatSuffix>); |
42 | |
43 | TVM_REGISTER_OP("tir.log2" ) |
44 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
45 | |
46 | TVM_REGISTER_OP("tir.log10" ) |
47 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
48 | |
49 | TVM_REGISTER_OP("tir.log1p" ) |
50 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
51 | |
52 | TVM_REGISTER_OP("tir.tanh" ) |
53 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
54 | |
55 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
56 | DispatchPureExtern<FloatSuffix>); |
57 | |
58 | TVM_REGISTER_OP("tir.atan" ) |
59 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
60 | |
61 | TVM_REGISTER_OP("tir.atanh" ) |
62 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
63 | |
64 | TVM_REGISTER_OP("tir.atan2" ) |
65 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
66 | |
67 | TVM_REGISTER_OP("tir.cos" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
68 | DispatchPureExtern<FloatSuffix>); |
69 | |
70 | TVM_REGISTER_OP("tir.acos" ) |
71 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
72 | |
73 | TVM_REGISTER_OP("tir.cosh" ) |
74 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
75 | |
76 | TVM_REGISTER_OP("tir.acosh" ) |
77 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
78 | |
79 | TVM_REGISTER_OP("tir.sin" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
80 | DispatchPureExtern<FloatSuffix>); |
81 | |
82 | TVM_REGISTER_OP("tir.asin" ) |
83 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
84 | |
85 | TVM_REGISTER_OP("tir.sinh" ) |
86 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
87 | |
88 | TVM_REGISTER_OP("tir.asinh" ) |
89 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
90 | |
91 | TVM_REGISTER_OP("tir.hypot" ) |
92 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
93 | |
94 | TVM_REGISTER_OP("tir.nextafter" ) |
95 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
96 | |
97 | TVM_REGISTER_OP("tir.copysign" ) |
98 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
99 | |
100 | TVM_REGISTER_OP("tir.ldexp" ) |
101 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
102 | |
103 | TVM_REGISTER_OP("tir.sqrt" ) |
104 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
105 | |
106 | TVM_REGISTER_OP("tir.floor" ) |
107 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
108 | |
109 | TVM_REGISTER_OP("tir.ceil" ) |
110 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
111 | |
112 | TVM_REGISTER_OP("tir.round" ) |
113 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
114 | |
115 | TVM_REGISTER_OP("tir.nearbyint" ) |
116 | .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , DispatchPureExtern<FloatSuffix>); |
117 | |
118 | TVM_REGISTER_OP("tir.pow" ).set_attr<FLowerIntrinsic>("default.FLowerIntrinsic" , |
119 | DispatchPureExtern<FloatSuffix>); |
120 | |
121 | } // namespace intrin |
122 | |
123 | namespace legalize { |
124 | |
125 | using namespace tir; |
126 | |
127 | TVM_REGISTER_OP("tir.rsqrt" ) |
128 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
129 | const CallNode* call = e.as<CallNode>(); |
130 | ICHECK(call != nullptr); |
131 | auto one = make_const(call->args[0].dtype(), 1); |
132 | return one / sqrt(call->args[0]); |
133 | }); |
134 | |
135 | TVM_REGISTER_OP("tir.sigmoid" ) |
136 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
137 | const CallNode* call = e.as<CallNode>(); |
138 | ICHECK(call != nullptr); |
139 | auto one = make_const(call->args[0].dtype(), 1); |
140 | return one / (one + exp(-call->args[0])); |
141 | }); |
142 | |
143 | TVM_REGISTER_OP("tir.isfinite" ) |
144 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
145 | const CallNode* call = e.as<CallNode>(); |
146 | ICHECK(call != nullptr); |
147 | return isfinite(call->args[0]); |
148 | }); |
149 | |
150 | TVM_REGISTER_OP("tir.isinf" ) |
151 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
152 | const CallNode* call = e.as<CallNode>(); |
153 | ICHECK(call != nullptr); |
154 | return isinf(call->args[0]); |
155 | }); |
156 | |
157 | /*! |
158 | * \brief Makes fixed point multiplication. |
159 | * \param x Input tensor. |
160 | * \param y Integer multiplier. |
161 | * \param left_shift Integer left shift. |
162 | * \param right_shift Integer right shift. |
163 | * \param is_left_shift_required Flag whether we need to do left shift or not. |
164 | * \return Calculated expression. |
165 | */ |
166 | static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift, |
167 | PrimExpr right_shift, PrimExpr is_left_shift_required) { |
168 | // Only int32 types are supported (any number of lanes is allowed) |
169 | ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); |
170 | ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && left_shift.dtype().bits() == 32); |
171 | ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && right_shift.dtype().bits() == 32); |
172 | |
173 | DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); |
174 | DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); |
175 | |
176 | // 1) Cast and Multiply the integer multiplier |
177 | PrimExpr one = make_const(hp_dtype, 1); |
178 | x = cast(hp_dtype, x); |
179 | y = cast(hp_dtype, y); |
180 | x = tir::Select(is_left_shift_required, x << left_shift, x); |
181 | |
182 | // 2) Perform the multiplication in higher precision. |
183 | x = x * y; |
184 | |
185 | // 3) Find the rounding scalar |
186 | PrimExpr total_right_shift = right_shift + q; |
187 | PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); |
188 | x = x + pos_rounding_value; |
189 | |
190 | // 4) Simply right shift the result to get the final output. |
191 | x = x >> total_right_shift; |
192 | |
193 | // 5) The fixed point multiplication keeps the value in int32 range. Casting back to int32. |
194 | return cast(lp_dtype, x); |
195 | } |
196 | |
197 | TVM_REGISTER_OP("tir.q_multiply_shift" ) |
198 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
199 | using tir::make_const; |
200 | |
201 | const tir::CallNode* call = e.as<tir::CallNode>(); |
202 | ICHECK(call != nullptr); |
203 | |
204 | PrimExpr x = call->args[0]; |
205 | PrimExpr y = call->args[1]; |
206 | PrimExpr q = call->args[2]; |
207 | PrimExpr s = call->args[3]; |
208 | |
209 | // Lambda function to extract the int value from PrimExpr |
210 | auto get_int_value = [](const PrimExpr node) { |
211 | if (auto int_node = node.as<IntImmNode>()) { |
212 | return int_node->value; |
213 | } |
214 | auto broadcast_node = node.as<BroadcastNode>(); |
215 | CHECK(broadcast_node != nullptr); |
216 | auto int_node = broadcast_node->value.as<IntImmNode>(); |
217 | CHECK(int_node != nullptr); |
218 | return int_node->value; |
219 | }; |
220 | // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of |
221 | // 2, fixed point multiplier will represent a float value of 0.5. In fixed point, this is |
222 | // represented by 1 << 30. |
223 | if (get_int_value(y) == (1 << 30)) { |
224 | PrimExpr exp = s - 1; |
225 | int exp_val = get_int_value(s) - 1; |
226 | if (exp_val > 0) { |
227 | // power of 2 is greater than 0, apply left shift. |
228 | return x << exp; |
229 | } else { |
230 | // power of 2 is less than 0, round and then apply right shift. |
231 | DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); |
232 | PrimExpr one = make_const(lp_dtype, 1); |
233 | exp = -exp; |
234 | PrimExpr rounding_factor = one << (exp - 1); |
235 | PrimExpr rounded_t = x + rounding_factor; |
236 | return rounded_t >> exp; |
237 | } |
238 | } else { |
239 | // Only int32 types are supported (any number of lanes is allowed) |
240 | ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); |
241 | |
242 | // Calculating integer shifts |
243 | PrimExpr zero = make_const(s.dtype(), 0); |
244 | PrimExpr left_shift = tir::Select(s > zero, s, zero); |
245 | PrimExpr right_shift = tir::Select(s > zero, zero, -s); |
246 | PrimExpr is_left_shift_required = (left_shift != zero); |
247 | |
248 | return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required); |
249 | } |
250 | }); |
251 | |
252 | TVM_REGISTER_OP("tir.q_multiply_shift_per_axis" ) |
253 | .set_attr<FLegalize>("default.FLegalize" , [](const PrimExpr& e) -> PrimExpr { |
254 | const tir::CallNode* call = e.as<tir::CallNode>(); |
255 | ICHECK(call != nullptr); |
256 | |
257 | PrimExpr x = call->args[0]; |
258 | PrimExpr y = call->args[1]; |
259 | PrimExpr left_shift = call->args[2]; |
260 | PrimExpr right_shift = call->args[3]; |
261 | PrimExpr q = call->args[4]; |
262 | PrimExpr is_lshift_required = call->args[5]; |
263 | // Note, 7th argument is "is_rshift_required" flag, but we don't need that here. |
264 | // PrimExpr is_rshift_required = call->args[6]; |
265 | |
266 | return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required); |
267 | }); |
268 | } // namespace legalize |
269 | } // namespace codegen |
270 | } // namespace tvm |
271 | |