1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file
 * \brief Sort, argsort and top-k operators backed by the C++ standard library.
 */
23 | |
24 | #include <dlpack/dlpack.h> |
25 | #include <tvm/runtime/registry.h> |
26 | |
27 | #include <algorithm> |
28 | #include <vector> |
29 | |
30 | #include "../../../../3rdparty/compiler-rt/builtin_fp16.h" |
31 | |
32 | namespace tvm { |
33 | namespace contrib { |
34 | |
35 | using namespace runtime; |
36 | |
// Ascending-order comparator for (original index, value) pairs.
// When stable_comparison is set, ties on the value are broken by the
// original index so equal elements keep their input order.
template <typename DType, bool stable_comparison = false>
bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  const DType& lval = lhs.second;
  const DType& rval = rhs.second;
  if constexpr (stable_comparison) {
    if (lval == rval) return lhs.first < rhs.first;
  }
  return lval < rval;
}
47 | |
// Descending-order comparator for (original index, value) pairs.
// When stable_comparison is set, ties on the value are broken by the
// original index so equal elements keep their input order.
template <typename DType, bool stable_comparison = false>
bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  const DType& lval = lhs.second;
  const DType& rval = rhs.second;
  if constexpr (stable_comparison) {
    if (lval == rval) return lhs.first < rhs.first;
  }
  return lval > rval;
}
58 | |
// Software representation of an IEEE 754 half-precision (binary16) value,
// used on targets without native __fp16 support.  Holds the raw 16-bit
// pattern and widens to 32-bit float for every comparison.
struct float16 {
  uint16_t bits;  // raw binary16 bit pattern
  // Widen to float via the compiler-rt soft-float helper
  // (10-bit mantissa source -> 23-bit mantissa destination).
  float to_float() const {
    return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
  }

  // All comparisons are performed on the widened float values.
  inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); }
  inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); }
  inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); }
  inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); }
  inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); }
  inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); }
};
72 | |
// Argsort for NMS, implemented with the C++ standard library sort.
// Returns the indices of the sorted tensor.
// By default, the last axis is used for sorting.
// sort_num specifies the number of elements to be sorted in each slice.
// If the input tensor has dimensions (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and the sort axis is dk, sort_num should have dimensions of
// (d0, d1, ..., d(k-1), d(k+1), ..., d(n-1)).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];     // data to sort (float; see dtype checks below)
  DLTensor* sort_num = args[1];  // int32 tensor: sorted-prefix length per slice
  DLTensor* output = args[2];    // int32 tensor receiving the sorted indices
  int32_t axis = args[3];        // axis to sort along (may be negative)
  bool is_ascend = args[4];

  auto dtype = input->dtype;
  auto data_ptr = static_cast<float*>(input->data);
  auto sort_num_ptr = static_cast<int32_t*>(sort_num->data);
  // Scratch buffer of (index within slice, value) pairs, reused per slice.
  std::vector<std::pair<int32_t, float>> sorter;
  int64_t axis_mul_before = 1;  // product of dims before `axis`
  int64_t axis_mul_after = 1;   // product of dims after `axis`

  // Normalize a negative axis.
  if (axis < 0) {
    axis = input->ndim + axis;
  }

  // Currently only supports input dtype to be float32.
  ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
                              "to be float.";
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1)
  ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
                               "to be float32.";
#endif
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }

  // Iterate over every 1-D slice taken along `axis`; only the first
  // sort_num[i, j] elements of each slice participate in the sort.
  for (int64_t i = 0; i < axis_mul_before; ++i) {
    for (int64_t j = 0; j < axis_mul_after; ++j) {
      sorter.clear();
      int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j);
      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
      for (int64_t k = 0; k < current_sort_num; ++k) {
        int64_t full_idx = base_idx + k * axis_mul_after;
        sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
      }
      // On ARM with native fp16 arithmetic a 16-bit input is compared as
      // __fp16 (values are converted through the pair converting ctor);
      // otherwise comparison is on float.
      if (is_ascend) {
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        if (dtype.bits == 16) {
          std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
        } else {
#endif
          std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        }
#endif
      } else {
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        if (dtype.bits == 16) {
          std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
        } else {
#endif
          std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        }
#endif
      }
      // Emit sorted indices for the prefix; positions past the sorted
      // prefix keep their own index k.
      for (int32_t k = 0; k < input->shape[axis]; ++k) {
        *(static_cast<int32_t*>(output->data) + base_idx + k * axis_mul_after) =
            k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
      }
    }
  }
});
154 | |
155 | template <typename DataType, typename OutType> |
156 | void sort_impl( |
157 | DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, |
158 | std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) { |
159 | auto data_ptr = static_cast<DataType*>(input->data); |
160 | auto out_ptr = static_cast<OutType*>(output->data); |
161 | std::vector<std::pair<int64_t, DataType>> sorter; |
162 | |
163 | int axis_mul_before = 1; |
164 | int axis_mul_after = 1; |
165 | for (int i = 0; i < input->ndim; ++i) { |
166 | if (i < axis) { |
167 | axis_mul_before *= input->shape[i]; |
168 | } else if (i > axis) { |
169 | axis_mul_after *= input->shape[i]; |
170 | } |
171 | } |
172 | |
173 | for (int i = 0; i < axis_mul_before; ++i) { |
174 | for (int j = 0; j < axis_mul_after; ++j) { |
175 | sorter.clear(); |
176 | int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; |
177 | for (int64_t k = 0; k < input->shape[axis]; ++k) { |
178 | int64_t full_idx = base_idx + k * axis_mul_after; |
179 | sorter.emplace_back(std::make_pair(k, data_ptr[full_idx])); |
180 | } |
181 | if (is_ascend) { |
182 | std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>); |
183 | } else { |
184 | std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>); |
185 | } |
186 | for (int64_t k = 0; k < input->shape[axis]; ++k) { |
187 | epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]); |
188 | } |
189 | } |
190 | } |
191 | } |
192 | |
193 | template <typename DataType, typename OutType> |
194 | void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { |
195 | return sort_impl<DataType, OutType>( |
196 | input, output, axis, is_ascend, |
197 | [](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) { |
198 | out_ptr[index] = static_cast<OutType>(sort_pair.first); |
199 | }); |
200 | } |
201 | |
202 | template <typename DataType> |
203 | void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { |
204 | return sort_impl<DataType, DataType>( |
205 | input, output, axis, is_ascend, |
206 | [](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) { |
207 | out_ptr[index] = sort_pair.second; |
208 | }); |
209 | } |
210 | |
// Argsort implemented with the C++ standard library sort.
// Returns the indices that would sort the input tensor.
// By default, the last axis is used for sorting; a negative axis counts
// from the last dimension. The whole axis is sorted (unlike argsort_nms,
// this packed function takes no sort_num argument).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* output = args[1];  // receives the sort indices, in out_dtype
  int32_t axis = args[2];      // axis to sort along (may be negative)
  bool is_ascend = args[3];
  // Normalize a negative axis.
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(output->dtype);

  // Dispatch on the (input dtype, output dtype) combination.
  if (data_dtype == "float32") {
    if (out_dtype == "int32") {
      argsort<float, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<float, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<float, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<float, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float64") {
    if (out_dtype == "int32") {
      argsort<double, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<double, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<double, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<double, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
  } else if (data_dtype == "float16") {
    // Native half-precision path on ARM; only fp16 output is supported here.
    if (out_dtype == "float16") {
      argsort<__fp16, __fp16>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
#endif
  } else if (data_dtype == "int32") {
    if (out_dtype == "int32") {
      argsort<int32_t, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<int32_t, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<int32_t, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<int32_t, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int64") {
    if (out_dtype == "int32") {
      argsort<int64_t, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<int64_t, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<int64_t, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<int64_t, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float16") {
    // Software float16 fallback. NOTE: when the native ARM fp16 branch
    // above is compiled in, that earlier "float16" test matches first and
    // this branch is unreachable.
    if (out_dtype == "int32") {
      argsort<float16, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<float16, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<float16, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<float16, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});
305 | |
// Sort implemented with the C++ standard library sort.
// Returns the sorted tensor (values, not indices).
// By default, the last axis is used for sorting; a negative axis counts
// from the last dimension. The whole axis is sorted and the input and
// output dtypes must match.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* output = args[1];  // receives the sorted values, same dtype as input
  int32_t axis = args[2];      // axis to sort along (may be negative)
  bool is_ascend = args[3];
  // Normalize a negative axis.
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(output->dtype);

  // Values are copied through unchanged, so dtypes must match exactly.
  ICHECK_EQ(data_dtype, out_dtype);

  // Dispatch on the element dtype.
  if (data_dtype == "float32") {
    sort<float>(input, output, axis, is_ascend);
  } else if (data_dtype == "float64") {
    sort<double>(input, output, axis, is_ascend);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
  } else if (data_dtype == "float16") {
    // Native half-precision path on ARM.
    sort<__fp16>(input, output, axis, is_ascend);
#endif
  } else if (data_dtype == "int32") {
    sort<int32_t>(input, output, axis, is_ascend);
  } else if (data_dtype == "int64") {
    sort<int64_t>(input, output, axis, is_ascend);
  } else if (data_dtype == "float16") {
    // Software float16 fallback; unreachable when the native ARM branch
    // above is compiled in.
    sort<float16>(input, output, axis, is_ascend);
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});
348 | |
349 | template <typename DataType, typename IndicesType> |
350 | void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, |
351 | bool is_ascend) { |
352 | DataType* data_ptr = static_cast<DataType*>(input->data); |
353 | DataType* values_ptr = |
354 | (out_values == nullptr) ? nullptr : static_cast<DataType*>(out_values->data); |
355 | IndicesType* indices_ptr = |
356 | (out_indices == nullptr) ? nullptr : static_cast<IndicesType*>(out_indices->data); |
357 | |
358 | // Maintain a min/max containing the top-k elements |
359 | std::vector<std::pair<int64_t, DataType>> running_heap; |
360 | |
361 | // Need +1 when inserting new element before maintaining heap invariant |
362 | running_heap.reserve(k + 1); |
363 | |
364 | int axis_mul_before = 1; |
365 | int axis_mul_after = 1; |
366 | for (int i = 0; i < input->ndim; ++i) { |
367 | if (i < axis) { |
368 | axis_mul_before *= input->shape[i]; |
369 | } else if (i > axis) { |
370 | axis_mul_after *= input->shape[i]; |
371 | } |
372 | } |
373 | if (k < 1) { |
374 | k = input->shape[axis]; |
375 | } |
376 | |
377 | for (int i = 0; i < axis_mul_before; ++i) { |
378 | for (int j = 0; j < axis_mul_after; ++j) { |
379 | running_heap.clear(); |
380 | int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; |
381 | int64_t dst_base_idx = i * k * axis_mul_after + j; |
382 | |
383 | // Start by creating min/max heap with fixed-k elements |
384 | int cur_axis_index = 0; |
385 | for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { |
386 | int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; |
387 | running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); |
388 | } |
389 | if (!is_ascend) { |
390 | std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
391 | } else { |
392 | std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
393 | } |
394 | |
395 | // Iterate through all elements, adding to heap along the way |
396 | for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { |
397 | int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; |
398 | std::pair<int64_t, DataType> cur_val = {cur_axis_index, data_ptr[full_idx]}; |
399 | |
400 | // Eq. to cur_val.second > running_heap.second |
401 | if (!is_ascend && CompareDescend<DataType, true>(cur_val, running_heap[0])) { |
402 | running_heap.push_back(cur_val); |
403 | std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
404 | std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
405 | running_heap.pop_back(); |
406 | } else if (is_ascend && CompareAscend<DataType, true>(cur_val, running_heap[0])) { |
407 | running_heap.push_back(cur_val); |
408 | std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
409 | std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
410 | running_heap.pop_back(); |
411 | } |
412 | } |
413 | |
414 | // finally sort heap and deliver results |
415 | if (is_ascend) { |
416 | std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
417 | } else { |
418 | std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
419 | } |
420 | |
421 | for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { |
422 | if (indices_ptr != nullptr) { |
423 | indices_ptr[dst_base_idx + kk * axis_mul_after] = |
424 | static_cast<IndicesType>(running_heap[kk].first); |
425 | } |
426 | if (values_ptr != nullptr) { |
427 | values_ptr[dst_base_idx + kk * axis_mul_after] = |
428 | static_cast<DataType>(running_heap[kk].second); |
429 | } |
430 | } |
431 | } |
432 | } |
433 | } |
434 | |
// Top-k implemented with a bounded heap.
// Returns the top-k values and/or their indices along the given axis,
// depending on the requested ret_type ("both", "values" or "indices").
// By default, the last axis is used; a negative axis counts from the
// last dimension. k < 1 keeps all elements along the axis.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* values_out = nullptr;
  DLTensor* indices_out = nullptr;
  // The output tensors (1 or 2 of them) sit between the input and the four
  // trailing scalar arguments, so the scalars are read from the end.
  int k = args[args.num_args - 4];
  int axis = args[args.num_args - 3];
  std::string ret_type = args[args.num_args - 2];
  bool is_ascend = args[args.num_args - 1];
  // ret_type selects which outputs the caller supplied.
  if (ret_type == "both") {
    values_out = args[1];
    indices_out = args[2];
  } else if (ret_type == "values") {
    values_out = args[1];
  } else if (ret_type == "indices") {
    indices_out = args[1];
  } else {
    LOG(FATAL) << "Unsupported ret type: " << ret_type;
  }
  // Normalize a negative axis.
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  // Without an indices output, the indices dtype is irrelevant; default it.
  auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataType2String(indices_out->dtype);

  // Dispatch on the (input dtype, indices dtype) combination.
  if (data_dtype == "float32") {
    if (out_dtype == "int32") {
      topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<float, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<float, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float64") {
    if (out_dtype == "int32") {
      topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<double, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<double, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "uint8") {
    if (out_dtype == "uint8") {
      topk<uint8_t, uint8_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int32") {
      topk<uint8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<uint8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<uint8_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<uint8_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int8") {
    if (out_dtype == "int8") {
      topk<int8_t, int8_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int32") {
      topk<int8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int8_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int8_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int32") {
    if (out_dtype == "int32") {
      topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int64") {
    if (out_dtype == "int32") {
      topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float16") {
    // Software float16 input (values compared via widening to float).
    if (out_dtype == "int32") {
      topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});
560 | |
561 | } // namespace contrib |
562 | } // namespace tvm |
563 | |