1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file topk.cc |
22 | * \brief TopK operators |
23 | */ |
24 | #include <tvm/relay/attrs/algorithm.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | #include <tvm/tir/op.h> |
28 | |
29 | #include "../../transforms/infer_layout_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
// Registers the TopKAttrs node type via TVM's node-type registration macro.
// NOTE(review): presumably this is what lets TopKAttrs be created and inspected
// through TVM's reflection machinery — confirm against the macro definition.
TVM_REGISTER_NODE_TYPE(TopKAttrs);
35 | |
36 | InferCorrectLayoutOutput TopKInferCorrectLayout(const Attrs& attrs, |
37 | const Array<Layout>& new_in_layouts, |
38 | const Array<Layout>& old_in_layouts, |
39 | const Array<tvm::relay::Type>& old_in_types) { |
40 | const auto* attrs_ptr = attrs.as<TopKAttrs>(); |
41 | ICHECK(attrs_ptr); |
42 | ObjectPtr<TopKAttrs> param = make_object<TopKAttrs>(*attrs_ptr); |
43 | |
44 | Array<Array<IndexExpr>> old_in_shapes; |
45 | for (auto old_in_t : old_in_types) { |
46 | ICHECK(old_in_t.as<TensorTypeNode>()); |
47 | old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape); |
48 | } |
49 | |
50 | size_t axis = |
51 | param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis); |
52 | |
53 | Layout ret = Layout::Undef(); |
54 | |
55 | // If new_in_layouts are defined, this code tries to modify the layout. |
56 | if (new_in_layouts.defined() && old_in_layouts.defined()) { |
57 | const auto& sp_dim = old_in_layouts[0][axis]; |
58 | auto new_index = new_in_layouts[0].IndexOf(sp_dim); |
59 | param->axis = new_index; |
60 | ret = new_in_layouts[0]; |
61 | } else if (old_in_layouts.defined()) { |
62 | ret = old_in_layouts[0]; |
63 | } |
64 | |
65 | // TopK has 2 outputs, Values and Indices |
66 | return InferCorrectLayoutOutput({ret}, {ret, ret}, Attrs(param)); |
67 | } |
68 | |
69 | bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
70 | const TypeReporter& reporter) { |
71 | // `types` contains: [data, result] |
72 | const TopKAttrs* param = attrs.as<TopKAttrs>(); |
73 | ICHECK_EQ(types.size(), 2); |
74 | const auto* data = types[0].as<TensorTypeNode>(); |
75 | if (data == nullptr) return false; |
76 | int ndim = data->shape.size(); |
77 | int axis = param->axis; |
78 | if (axis < 0) { |
79 | axis += ndim; |
80 | } |
81 | ICHECK(axis >= 0 && axis < ndim); |
82 | Array<IndexExpr> out_shape; |
83 | for (int i = 0; i < ndim; ++i) { |
84 | if (i != axis) { |
85 | out_shape.push_back(data->shape[i]); |
86 | } else { |
87 | const Integer& ck = param->k.value(); |
88 | if (ck->value < 1) { |
89 | out_shape.push_back(data->shape[i]); |
90 | } else { |
91 | out_shape.push_back(ck); |
92 | } |
93 | } |
94 | } |
95 | auto values_ty = TensorType(out_shape, data->dtype); |
96 | auto indices_ty = TensorType(out_shape, param->dtype); |
97 | if (param->ret_type == "both" ) { |
98 | reporter->Assign(types[1], TupleType({values_ty, indices_ty})); |
99 | } else if (param->ret_type == "values" ) { |
100 | reporter->Assign(types[1], values_ty); |
101 | } else if (param->ret_type == "indices" ) { |
102 | reporter->Assign(types[1], indices_ty); |
103 | } else { |
104 | LOG(FATAL) << "Unsupported ret type: " << param->ret_type; |
105 | } |
106 | return true; |
107 | } |
108 | |
109 | Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) { |
110 | auto attrs = make_object<TopKAttrs>(); |
111 | attrs->k = Integer(k); |
112 | attrs->axis = axis; |
113 | attrs->ret_type = ret_type; |
114 | attrs->is_ascend = is_ascend; |
115 | attrs->dtype = dtype; |
116 | static const Op& op = Op::Get("topk" ); |
117 | return Call(op, {data}, Attrs(attrs), {}); |
118 | } |
119 | |
// Exposes MakeTopK as the global function "relay.op._make.topk".
// NOTE(review): presumably this is the entry point language frontends use to
// construct topk calls — confirm against the callers.
TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
121 | |
// Registers the "topk" operator: one input ("data"), attributes of type TopKAttrs,
// layout inference through TopKInferCorrectLayout, support level 6, and the
// "TopK" type relation implemented by TopKRel.
RELAY_REGISTER_OP("topk")
    .describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<TopKAttrs>()
    .add_argument("data", "Tensor", "Input data.")
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", TopKInferCorrectLayout)
    .set_support_level(6)
    .add_type_rel("TopK", TopKRel);
131 | |
132 | } // namespace relay |
133 | } // namespace tvm |
134 | |