1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file topk.cc
22 * \brief TopK operators
23 */
24#include <tvm/relay/attrs/algorithm.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/data_layout.h>
27#include <tvm/tir/op.h>
28
29#include "../../transforms/infer_layout_utils.h"
30
31namespace tvm {
32namespace relay {
33
// Register the TopKAttrs node type with TVM's object registry so that the
// attribute object used by the `topk` operator below can be created by type.
TVM_REGISTER_NODE_TYPE(TopKAttrs);
35
36InferCorrectLayoutOutput TopKInferCorrectLayout(const Attrs& attrs,
37 const Array<Layout>& new_in_layouts,
38 const Array<Layout>& old_in_layouts,
39 const Array<tvm::relay::Type>& old_in_types) {
40 const auto* attrs_ptr = attrs.as<TopKAttrs>();
41 ICHECK(attrs_ptr);
42 ObjectPtr<TopKAttrs> param = make_object<TopKAttrs>(*attrs_ptr);
43
44 Array<Array<IndexExpr>> old_in_shapes;
45 for (auto old_in_t : old_in_types) {
46 ICHECK(old_in_t.as<TensorTypeNode>());
47 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
48 }
49
50 size_t axis =
51 param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
52
53 Layout ret = Layout::Undef();
54
55 // If new_in_layouts are defined, this code tries to modify the layout.
56 if (new_in_layouts.defined() && old_in_layouts.defined()) {
57 const auto& sp_dim = old_in_layouts[0][axis];
58 auto new_index = new_in_layouts[0].IndexOf(sp_dim);
59 param->axis = new_index;
60 ret = new_in_layouts[0];
61 } else if (old_in_layouts.defined()) {
62 ret = old_in_layouts[0];
63 }
64
65 // TopK has 2 outputs, Values and Indices
66 return InferCorrectLayoutOutput({ret}, {ret, ret}, Attrs(param));
67}
68
69bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
70 const TypeReporter& reporter) {
71 // `types` contains: [data, result]
72 const TopKAttrs* param = attrs.as<TopKAttrs>();
73 ICHECK_EQ(types.size(), 2);
74 const auto* data = types[0].as<TensorTypeNode>();
75 if (data == nullptr) return false;
76 int ndim = data->shape.size();
77 int axis = param->axis;
78 if (axis < 0) {
79 axis += ndim;
80 }
81 ICHECK(axis >= 0 && axis < ndim);
82 Array<IndexExpr> out_shape;
83 for (int i = 0; i < ndim; ++i) {
84 if (i != axis) {
85 out_shape.push_back(data->shape[i]);
86 } else {
87 const Integer& ck = param->k.value();
88 if (ck->value < 1) {
89 out_shape.push_back(data->shape[i]);
90 } else {
91 out_shape.push_back(ck);
92 }
93 }
94 }
95 auto values_ty = TensorType(out_shape, data->dtype);
96 auto indices_ty = TensorType(out_shape, param->dtype);
97 if (param->ret_type == "both") {
98 reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
99 } else if (param->ret_type == "values") {
100 reporter->Assign(types[1], values_ty);
101 } else if (param->ret_type == "indices") {
102 reporter->Assign(types[1], indices_ty);
103 } else {
104 LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
105 }
106 return true;
107}
108
109Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
110 auto attrs = make_object<TopKAttrs>();
111 attrs->k = Integer(k);
112 attrs->axis = axis;
113 attrs->ret_type = ret_type;
114 attrs->is_ascend = is_ascend;
115 attrs->dtype = dtype;
116 static const Op& op = Op::Get("topk");
117 return Call(op, {data}, Attrs(attrs), {});
118}
119
// Register MakeTopK as the global packed function "relay.op._make.topk".
TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);

// Register the static `topk` operator. It takes one tensor input, uses TopKAttrs
// as its attribute type, TopKInferCorrectLayout for layout inference and
// TopKRel for type inference.
RELAY_REGISTER_OP("topk")
    .describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<TopKAttrs>()
    .add_argument("data", "Tensor", "Input data.")
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", TopKInferCorrectLayout)
    .set_support_level(6)
    .add_type_rel("TopK", TopKRel);
131
132} // namespace relay
133} // namespace tvm
134