1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file searchsorted.cc |
22 | * \brief SearchSorted operators |
23 | */ |
24 | #include <tvm/relay/attrs/algorithm.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
// Register SearchSortedAttrs as a node type with TVM's object registry so the attribute
// node used by the `searchsorted` operator can be constructed/reflected generically.
TVM_REGISTER_NODE_TYPE(SearchSortedAttrs);
32 | |
33 | bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
34 | const TypeReporter& reporter) { |
35 | const SearchSortedAttrs* param = attrs.as<SearchSortedAttrs>(); |
36 | ICHECK_EQ(types.size(), 3); |
37 | const auto* sorted_sequence = types[0].as<TensorTypeNode>(); |
38 | const auto* values = types[1].as<TensorTypeNode>(); |
39 | ICHECK(sorted_sequence) << "Expects TensorType in the first input" ; |
40 | ICHECK(values) << "Expects TensorType in the second input" ; |
41 | ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one" ; |
42 | |
43 | if (sorted_sequence->shape.size() > 1) { |
44 | ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) |
45 | << "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is " |
46 | "multi-dimensional." ; |
47 | |
48 | for (size_t i = 0; i < values->shape.size() - 1; ++i) { |
49 | if (sorted_sequence->shape[i].as<IntImmNode>() && values->shape[i].as<IntImmNode>()) { |
50 | ICHECK_EQ(sorted_sequence->shape[i].as<IntImmNode>()->value, |
51 | values->shape[i].as<IntImmNode>()->value) |
52 | << "`sorted_sequence and `values` do not have the same shape along outer axes" ; |
53 | } |
54 | } |
55 | } |
56 | |
57 | reporter->Assign(types[2], TensorType(values->shape, param->dtype)); |
58 | return true; |
59 | } |
60 | |
61 | Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) { |
62 | auto attrs = make_object<SearchSortedAttrs>(); |
63 | static const Op& op = Op::Get("searchsorted" ); |
64 | attrs->dtype = dtype; |
65 | attrs->right = right; |
66 | return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); |
67 | } |
68 | |
// Expose MakeSearchSorted as the global packed function "relay.op._make.searchsorted".
// NOTE(review): presumably invoked by the Python-side op constructor; confirm in the frontend.
TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted);
70 | |
71 | RELAY_REGISTER_OP("searchsorted" ) |
72 | .describe( |
73 | R"doc(Find indices where elements should be inserted to maintain order. |
74 | If `sorted_sequence` is N-dimensional, the innermost dimension of |
75 | `values` are searched in the corresponding dimension of `sorted_sequence`. |
76 | )doc" TVM_ADD_FILELINE) |
77 | .set_num_inputs(2) |
78 | .set_attrs_type<SearchSortedAttrs>() |
79 | .add_argument("sorted_sequence" , "Tensor" , |
80 | "Monotonically increasing sequence on the innermost dimension." ) |
81 | .add_argument("values" , "Tensor" , "Values to search for." ) |
82 | .set_support_level(6) |
83 | .add_type_rel("SearchSorted" , SearchSortedRel); |
84 | |
85 | } // namespace relay |
86 | } // namespace tvm |
87 | |