1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file tvm/relay/attrs/algorithm.h
 * \brief Auxiliary attributes for algorithm operators (argsort, topk, searchsorted).
 */
24 | #ifndef TVM_RELAY_ATTRS_ALGORITHM_H_ |
25 | #define TVM_RELAY_ATTRS_ALGORITHM_H_ |
26 | |
27 | #include <tvm/ir/attrs.h> |
28 | #include <tvm/relay/base.h> |
29 | #include <tvm/relay/expr.h> |
30 | |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | /*! \brief Attributes used in argsort operators */ |
37 | struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> { |
38 | int axis; |
39 | bool is_ascend; |
40 | DataType dtype; |
41 | |
42 | TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs" ) { |
43 | TVM_ATTR_FIELD(axis).set_default(-1).describe( |
44 | "Axis along which to sort the input tensor." |
45 | "If not given, the flattened array is used." ); |
46 | TVM_ATTR_FIELD(is_ascend).set_default(true).describe( |
47 | "Whether to sort in ascending or descending order." |
48 | "By default, sort in ascending order" ); |
49 | TVM_ATTR_FIELD(dtype) |
50 | .set_default(NullValue<DataType>()) |
51 | .describe("DType of the output indices." ); |
52 | } |
53 | }; |
54 | |
55 | struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> { |
56 | Optional<Integer> k; |
57 | int axis; |
58 | bool is_ascend; |
59 | std::string ret_type; |
60 | DataType dtype; |
61 | |
62 | TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs" ) { |
63 | TVM_ATTR_FIELD(k).describe("Number of top elements to select" ); |
64 | TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor." ); |
65 | TVM_ATTR_FIELD(ret_type).set_default("both" ).describe( |
66 | "The return type [both, values, indices]." |
67 | "both - return both top k data and indices." |
68 | "values - return top k data only." |
69 | "indices - return top k indices only." ); |
70 | TVM_ATTR_FIELD(is_ascend).set_default(false).describe( |
71 | "Whether to sort in ascending or descending order." |
72 | "By default, sort in descending order" ); |
73 | TVM_ATTR_FIELD(dtype) |
74 | .set_default(NullValue<DataType>()) |
75 | .describe("Data type of the output indices." ); |
76 | } |
77 | }; |
78 | |
79 | struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> { |
80 | bool right; |
81 | DataType dtype; |
82 | |
83 | TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs" ) { |
84 | TVM_ATTR_FIELD(right).set_default(false).describe( |
85 | "Controls which index is returned if a value lands exactly on one of sorted values. If " |
86 | " false, the index of the first suitable location found is given. If true, return the " |
87 | "last such index. If there is no suitable index, return either 0 or N (where N is the " |
88 | "size of the innermost dimension)." ); |
89 | TVM_ATTR_FIELD(dtype) |
90 | .set_default(DataType::Int(32)) |
91 | .describe("Data type of the output indices." ); |
92 | } |
93 | }; |
94 | |
95 | } // namespace relay |
96 | } // namespace tvm |
97 | #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ |
98 | |