/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file
 * \brief Sorting operators (sort, argsort, topk) implemented with standard C++ library calls.
 */

#include <dlpack/dlpack.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "../../../../3rdparty/compiler-rt/builtin_fp16.h"

namespace tvm {
namespace contrib {

using namespace runtime;

template <typename DType, bool stable_comparison = false>
bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    if (lhs.second == rhs.second) {
      return lhs.first < rhs.first;
    }
  }

  return lhs.second < rhs.second;
}

template <typename DType, bool stable_comparison = false>
bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    if (lhs.second == rhs.second) {
      return lhs.first < rhs.first;
    }
  }

  return lhs.second > rhs.second;
}
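
// The comparators above break ties on equal values by the original index when
// stable_comparison is true, so equal elements keep their input order. This matters for
// the heap-based top-k below, where they are used as heap predicates and std::stable_sort
// cannot provide that guarantee on its own. For example, sorting (index, value) pairs with
// values {3, 1, 3} using CompareAscend<int, true> yields (1,1), (0,3), (2,3).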

struct float16 {
  uint16_t bits;
  float to_float() const {
    return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
  }

  inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); }
  inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); }
  inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); }
  inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); }
  inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); }
  inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); }
};
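
// float16 stores raw IEEE 754 half-precision bits; every comparison first widens both
// operands to float via the compiler-rt __extendXfYf2__ helper from builtin_fp16.h.
// This software path handles float16 data wherever the kernels below do not dispatch to
// the native __fp16 type (see the __ARM_FEATURE_FP16_SCALAR_ARITHMETIC guards).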

// Argsort for NMS, implemented with the standard C++ library sort.
// Returns the indices that sort the input tensor along the given axis.
// By default, the last axis is used for sorting.
// sort_num specifies the number of elements to be sorted in each slice.
// If the input tensor has shape (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and the sort axis is dk, then sort_num should have shape
// (d0, d1, ..., d(k-1), d(k+1), ..., d(n-1)).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* sort_num = args[1];
  DLTensor* output = args[2];
  int32_t axis = args[3];
  bool is_ascend = args[4];

  auto dtype = input->dtype;
  auto data_ptr = static_cast<float*>(input->data);
  auto sort_num_ptr = static_cast<int32_t*>(sort_num->data);
  std::vector<std::pair<int32_t, float>> sorter;
  int64_t axis_mul_before = 1;
  int64_t axis_mul_after = 1;

  if (axis < 0) {
    axis = input->ndim + axis;
  }

  // Currently only float input is supported (float32, or float16 when native __fp16 is available).
  ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
                              "to be float.";
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1)
  ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
                               "to be float32.";
#endif
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }
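
  // The products computed above give the flattened layout: an element at position p along
  // the sort axis, in the slice selected by (i, j), lives at flat offset
  //   i * shape[axis] * axis_mul_after + p * axis_mul_after + j.
  // For example, for shape (2, 5, 3) and axis = 1, axis_mul_before = 2, axis_mul_after = 3,
  // and element (i, p, j) sits at offset i * 15 + p * 3 + j; this is the base_idx / full_idx
  // arithmetic used in the loops below.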

  for (int64_t i = 0; i < axis_mul_before; ++i) {
    for (int64_t j = 0; j < axis_mul_after; ++j) {
      sorter.clear();
      int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j);
      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
      for (int64_t k = 0; k < current_sort_num; ++k) {
        int64_t full_idx = base_idx + k * axis_mul_after;
        sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
      }
      if (is_ascend) {
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        if (dtype.bits == 16) {
          std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
        } else {
#endif
          std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        }
#endif
      } else {
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        if (dtype.bits == 16) {
          std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
        } else {
#endif
          std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
        }
#endif
      }
      for (int32_t k = 0; k < input->shape[axis]; ++k) {
        *(static_cast<int32_t*>(output->data) + base_idx + k * axis_mul_after) =
            k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
      }
    }
  }
});
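
// Illustrative sketch (not part of the library): the packed function registered above can
// be looked up and called directly from C++, assuming the caller has prepared DLTensors
// (hypothetically named input, sort_num, output here) with the expected dtypes
// (float32 data, int32 sort_num and output). In practice it is usually invoked through the
// generated extern call emitted by the TE/TIR schedules.
//
//   const runtime::PackedFunc* argsort_nms = runtime::Registry::Get("tvm.contrib.sort.argsort_nms");
//   ICHECK(argsort_nms != nullptr);
//   (*argsort_nms)(input, sort_num, output, /*axis=*/-1, /*is_ascend=*/false);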

template <typename DataType, typename OutType>
void sort_impl(
    DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
    std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
  auto data_ptr = static_cast<DataType*>(input->data);
  auto out_ptr = static_cast<OutType*>(output->data);
  std::vector<std::pair<int64_t, DataType>> sorter;

  int axis_mul_before = 1;
  int axis_mul_after = 1;
  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }

  for (int i = 0; i < axis_mul_before; ++i) {
    for (int j = 0; j < axis_mul_after; ++j) {
      sorter.clear();
      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
      for (int64_t k = 0; k < input->shape[axis]; ++k) {
        int64_t full_idx = base_idx + k * axis_mul_after;
        sorter.emplace_back(std::make_pair(k, data_ptr[full_idx]));
      }
      if (is_ascend) {
        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
      } else {
        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
      }
      for (int64_t k = 0; k < input->shape[axis]; ++k) {
        epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]);
      }
    }
  }
}

template <typename DataType, typename OutType>
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
  return sort_impl<DataType, OutType>(
      input, output, axis, is_ascend,
      [](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
        out_ptr[index] = static_cast<OutType>(sort_pair.first);
      });
}

template <typename DataType>
void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
  return sort_impl<DataType, DataType>(
      input, output, axis, is_ascend,
      [](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
        out_ptr[index] = sort_pair.second;
      });
}
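
// argsort and sort share sort_impl and differ only in the epilogue: argsort writes the
// pre-sort index of each element, sort writes the value itself. The std::function epilogue
// keeps the per-slice iteration and stable_sort logic in one place at the cost of an
// indirect call per output element.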

// Argsort implemented with the standard C++ library sort.
// Returns the indices that sort the input tensor along the given axis.
// By default, the last axis is used for sorting.
// The whole axis is sorted; unlike argsort_nms above, there is no sort_num argument.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* output = args[1];
  int32_t axis = args[2];
  bool is_ascend = args[3];
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(output->dtype);

  if (data_dtype == "float32") {
    if (out_dtype == "int32") {
      argsort<float, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<float, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<float, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<float, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float64") {
    if (out_dtype == "int32") {
      argsort<double, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<double, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<double, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<double, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
  } else if (data_dtype == "float16") {
    if (out_dtype == "float16") {
      argsort<__fp16, __fp16>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
#endif
  } else if (data_dtype == "int32") {
    if (out_dtype == "int32") {
      argsort<int32_t, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<int32_t, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<int32_t, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<int32_t, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int64") {
    if (out_dtype == "int32") {
      argsort<int64_t, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<int64_t, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<int64_t, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<int64_t, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float16") {
    if (out_dtype == "int32") {
      argsort<float16, int32_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "int64") {
      argsort<float16, int64_t>(input, output, axis, is_ascend);
    } else if (out_dtype == "float32") {
      argsort<float16, float>(input, output, axis, is_ascend);
    } else if (out_dtype == "float64") {
      argsort<float16, double>(input, output, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});

// Sort implemented with the standard C++ library sort.
// Returns the sorted tensor.
// By default, the last axis is used for sorting.
// The whole axis is sorted; the output has the same shape and dtype as the input.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* output = args[1];
  int32_t axis = args[2];
  bool is_ascend = args[3];
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK_LT(axis, input->ndim) << "Axis out of boundary for "
                                  "input ndim "
                               << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(output->dtype);

  ICHECK_EQ(data_dtype, out_dtype);

  if (data_dtype == "float32") {
    sort<float>(input, output, axis, is_ascend);
  } else if (data_dtype == "float64") {
    sort<double>(input, output, axis, is_ascend);
#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
  } else if (data_dtype == "float16") {
    sort<__fp16>(input, output, axis, is_ascend);
#endif
  } else if (data_dtype == "int32") {
    sort<int32_t>(input, output, axis, is_ascend);
  } else if (data_dtype == "int64") {
    sort<int64_t>(input, output, axis, is_ascend);
  } else if (data_dtype == "float16") {
    sort<float16>(input, output, axis, is_ascend);
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});

template <typename DataType, typename IndicesType>
void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis,
          bool is_ascend) {
  DataType* data_ptr = static_cast<DataType*>(input->data);
  DataType* values_ptr =
      (out_values == nullptr) ? nullptr : static_cast<DataType*>(out_values->data);
  IndicesType* indices_ptr =
      (out_indices == nullptr) ? nullptr : static_cast<IndicesType*>(out_indices->data);

  // Maintain a min/max heap containing the current top-k elements
  std::vector<std::pair<int64_t, DataType>> running_heap;

  // Need +1 because a new element is pushed before the heap invariant is re-established
  running_heap.reserve(k + 1);

  int axis_mul_before = 1;
  int axis_mul_after = 1;
  for (int i = 0; i < input->ndim; ++i) {
    if (i < axis) {
      axis_mul_before *= input->shape[i];
    } else if (i > axis) {
      axis_mul_after *= input->shape[i];
    }
  }
  if (k < 1) {
    k = input->shape[axis];
  }

  for (int i = 0; i < axis_mul_before; ++i) {
    for (int j = 0; j < axis_mul_after; ++j) {
      running_heap.clear();
      int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
      int64_t dst_base_idx = i * k * axis_mul_after + j;

      // Start by building a min/max heap from the first k elements
      int cur_axis_index = 0;
      for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) {
        int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after;
        running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx]));
      }
      if (!is_ascend) {
        std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
      } else {
        std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
      }
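      // After make_heap, running_heap.front() is the "worst" of the k kept elements
      // (the smallest when selecting the largest k, the largest when selecting the
      // smallest k), so each remaining element only has to be compared against the front.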

      // Iterate through all remaining elements, updating the heap along the way
      for (; cur_axis_index < input->shape[axis]; cur_axis_index++) {
        int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after;
        std::pair<int64_t, DataType> cur_val = {cur_axis_index, data_ptr[full_idx]};

        // Equivalent to cur_val.second > running_heap[0].second (descending top-k) or
        // cur_val.second < running_heap[0].second (ascending top-k), with ties broken by index
        if (!is_ascend && CompareDescend<DataType, true>(cur_val, running_heap[0])) {
          running_heap.push_back(cur_val);
          std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
          std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
          running_heap.pop_back();
        } else if (is_ascend && CompareAscend<DataType, true>(cur_val, running_heap[0])) {
          running_heap.push_back(cur_val);
          std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
          std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
          running_heap.pop_back();
        }
      }

      // Finally, sort the heap and deliver the results
      if (is_ascend) {
        std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
      } else {
        std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
      }

      for (uint32_t kk = 0; kk < running_heap.size(); ++kk) {
        if (indices_ptr != nullptr) {
          indices_ptr[dst_base_idx + kk * axis_mul_after] =
              static_cast<IndicesType>(running_heap[kk].first);
        }
        if (values_ptr != nullptr) {
          values_ptr[dst_base_idx + kk * axis_mul_after] =
              static_cast<DataType>(running_heap[kk].second);
        }
      }
    }
  }
}
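
// Per slice, the heap-based selection above runs in O(n log k) time plus an O(k log k)
// final sort, instead of the O(n log n) full sort used by the argsort/sort kernels.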

// Top-k implemented with a running heap plus the standard C++ library sort.
// Returns the k largest (is_ascend = false) or smallest (is_ascend = true) values and/or
// their indices along the given axis; k < 1 selects the whole axis.
// Arguments are packed as: input, [values_out], [indices_out], k, axis, ret_type, is_ascend,
// where ret_type is one of "both", "values" or "indices" and determines which output
// tensors are present.
TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) {
  DLTensor* input = args[0];
  DLTensor* values_out = nullptr;
  DLTensor* indices_out = nullptr;
  int k = args[args.num_args - 4];
  int axis = args[args.num_args - 3];
  std::string ret_type = args[args.num_args - 2];
  bool is_ascend = args[args.num_args - 1];
  if (ret_type == "both") {
    values_out = args[1];
    indices_out = args[2];
  } else if (ret_type == "values") {
    values_out = args[1];
  } else if (ret_type == "indices") {
    indices_out = args[1];
  } else {
    LOG(FATAL) << "Unsupported ret type: " << ret_type;
  }
  if (axis < 0) {
    axis = input->ndim + axis;
  }
  ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataType2String(indices_out->dtype);

  if (data_dtype == "float32") {
    if (out_dtype == "int32") {
      topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<float, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<float, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float64") {
    if (out_dtype == "int32") {
      topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<double, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<double, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "uint8") {
    if (out_dtype == "uint8") {
      topk<uint8_t, uint8_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int32") {
      topk<uint8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<uint8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<uint8_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<uint8_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int8") {
    if (out_dtype == "int8") {
      topk<int8_t, int8_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int32") {
      topk<int8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int8_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int8_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int32") {
    if (out_dtype == "int32") {
      topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "int64") {
    if (out_dtype == "int32") {
      topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else if (data_dtype == "float16") {
    if (out_dtype == "int32") {
      topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "int64") {
      topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float32") {
      topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
    } else if (out_dtype == "float64") {
      topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
    } else {
      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
    }
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
});
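
// Illustrative sketch (not part of the library): calling the top-k packed function for
// ret_type = "both", with hypothetical DLTensors input, values_out and indices_out prepared
// by the caller, selecting the 5 largest elements along the last axis.
//
//   const runtime::PackedFunc* topk_fn = runtime::Registry::Get("tvm.contrib.sort.topk");
//   ICHECK(topk_fn != nullptr);
//   (*topk_fn)(input, values_out, indices_out, /*k=*/5, /*axis=*/-1, "both", /*is_ascend=*/false);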

}  // namespace contrib
}  // namespace tvm