1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file searchsorted.cc
 * \brief Relay `searchsorted` operator: attributes, type relation, and registration.
23 */
24#include <tvm/relay/attrs/algorithm.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/op.h>
27
28namespace tvm {
29namespace relay {
30
// Register SearchSortedAttrs (presumably declared in tvm/relay/attrs/algorithm.h) as a
// TVM node type; MakeSearchSorted below allocates instances of it for each call.
TVM_REGISTER_NODE_TYPE(SearchSortedAttrs);
32
33bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
34 const TypeReporter& reporter) {
35 const SearchSortedAttrs* param = attrs.as<SearchSortedAttrs>();
36 ICHECK_EQ(types.size(), 3);
37 const auto* sorted_sequence = types[0].as<TensorTypeNode>();
38 const auto* values = types[1].as<TensorTypeNode>();
39 ICHECK(sorted_sequence) << "Expects TensorType in the first input";
40 ICHECK(values) << "Expects TensorType in the second input";
41 ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one";
42
43 if (sorted_sequence->shape.size() > 1) {
44 ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size())
45 << "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is "
46 "multi-dimensional.";
47
48 for (size_t i = 0; i < values->shape.size() - 1; ++i) {
49 if (sorted_sequence->shape[i].as<IntImmNode>() && values->shape[i].as<IntImmNode>()) {
50 ICHECK_EQ(sorted_sequence->shape[i].as<IntImmNode>()->value,
51 values->shape[i].as<IntImmNode>()->value)
52 << "`sorted_sequence and `values` do not have the same shape along outer axes";
53 }
54 }
55 }
56
57 reporter->Assign(types[2], TensorType(values->shape, param->dtype));
58 return true;
59}
60
61Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) {
62 auto attrs = make_object<SearchSortedAttrs>();
63 static const Op& op = Op::Get("searchsorted");
64 attrs->dtype = dtype;
65 attrs->right = right;
66 return Call(op, {sorted_sequence, values}, Attrs(attrs), {});
67}
68
// Expose MakeSearchSorted as the global packed function "relay.op._make.searchsorted".
TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted);
70
71RELAY_REGISTER_OP("searchsorted")
72 .describe(
73 R"doc(Find indices where elements should be inserted to maintain order.
74If `sorted_sequence` is N-dimensional, the innermost dimension of
75`values` are searched in the corresponding dimension of `sorted_sequence`.
76)doc" TVM_ADD_FILELINE)
77 .set_num_inputs(2)
78 .set_attrs_type<SearchSortedAttrs>()
79 .add_argument("sorted_sequence", "Tensor",
80 "Monotonically increasing sequence on the innermost dimension.")
81 .add_argument("values", "Tensor", "Values to search for.")
82 .set_support_level(6)
83 .add_type_rel("SearchSorted", SearchSortedRel);
84
85} // namespace relay
86} // namespace tvm
87