1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Registration of transform operators |
22 | * \file transform.cc |
23 | */ |
24 | #include <tvm/runtime/packed_func.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/topi/einsum.h> |
27 | #include <tvm/topi/transform.h> |
28 | #include <tvm/topi/utils.h> |
29 | |
30 | namespace tvm { |
31 | namespace topi { |
32 | |
33 | using namespace tvm; |
34 | using namespace tvm::runtime; |
35 | |
36 | TVM_REGISTER_GLOBAL("topi.expand_dims" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
37 | *rv = expand_dims(args[0], args[1], args[2]); |
38 | }); |
39 | |
40 | TVM_REGISTER_GLOBAL("topi.transpose" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
41 | *rv = transpose(args[0], args[1]); |
42 | }); |
43 | |
44 | TVM_REGISTER_GLOBAL("topi.flip" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
45 | // pass empty seq_lengths tensor to reverse_sequence |
46 | *rv = reverse_sequence(args[0], Tensor(), args[1]); |
47 | }); |
48 | |
49 | TVM_REGISTER_GLOBAL("topi.reverse_sequence" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
50 | *rv = reverse_sequence(args[0], args[1], args[2], args[3]); |
51 | }); |
52 | |
53 | TVM_REGISTER_GLOBAL("topi.reshape" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
54 | *rv = reshape(args[0], args[1]); |
55 | }); |
56 | |
57 | TVM_REGISTER_GLOBAL("topi.sliding_window" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
58 | *rv = sliding_window(args[0], args[1], args[2], args[3]); |
59 | }); |
60 | |
61 | TVM_REGISTER_GLOBAL("topi.squeeze" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
62 | *rv = squeeze(args[0], ArrayOrInt(args[1])); |
63 | }); |
64 | |
65 | TVM_REGISTER_GLOBAL("topi.concatenate" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
66 | *rv = concatenate(args[0], args[1]); |
67 | }); |
68 | |
69 | TVM_REGISTER_GLOBAL("topi.stack" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
70 | *rv = stack(args[0], args[1]); |
71 | }); |
72 | |
73 | TVM_REGISTER_GLOBAL("topi.shape" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
74 | *rv = shape(args[0], args[1]); |
75 | }); |
76 | |
77 | TVM_REGISTER_GLOBAL("topi.ndarray_size" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
78 | *rv = ndarray_size(args[0], args[1]); |
79 | }); |
80 | |
81 | TVM_REGISTER_GLOBAL("topi.split" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
82 | if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { |
83 | *rv = split_sections(args[0], args[1], args[2]); |
84 | } else { |
85 | *rv = split(args[0], args[1], args[2]); |
86 | } |
87 | }); |
88 | |
89 | TVM_REGISTER_GLOBAL("topi.layout_transform" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
90 | *rv = layout_transform(args[0], args[1], args[2]); |
91 | }); |
92 | |
93 | TVM_REGISTER_GLOBAL("topi.take" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
94 | if (args.size() == 4) { |
95 | std::string mode = args[3]; |
96 | int batch_dims = args[2]; |
97 | *rv = take(args[0], args[1], batch_dims, mode); |
98 | } else { |
99 | ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments" ; |
100 | int batch_dims = args[2]; |
101 | int axis = args[3]; |
102 | std::string mode = args[4]; |
103 | *rv = take(args[0], args[1], batch_dims, axis, mode); |
104 | } |
105 | }); |
106 | |
107 | TVM_REGISTER_GLOBAL("topi.sequence_mask" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
108 | double pad_val = args[2]; |
109 | int axis = args[3]; |
110 | *rv = sequence_mask(args[0], args[1], pad_val, axis); |
111 | }); |
112 | |
113 | TVM_REGISTER_GLOBAL("topi.where" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
114 | *rv = where(args[0], args[1], args[2]); |
115 | }); |
116 | |
117 | TVM_REGISTER_GLOBAL("topi.arange" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
118 | *rv = arange(args[0], args[1], args[2], args[3]); |
119 | }); |
120 | |
121 | TVM_REGISTER_GLOBAL("topi.meshgrid" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
122 | *rv = meshgrid(args[0], args[1]); |
123 | }); |
124 | |
125 | TVM_REGISTER_GLOBAL("topi.repeat" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
126 | *rv = repeat(args[0], args[1], args[2]); |
127 | }); |
128 | |
129 | TVM_REGISTER_GLOBAL("topi.tile" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
130 | *rv = tile(args[0], args[1]); |
131 | }); |
132 | |
133 | TVM_REGISTER_GLOBAL("topi.gather" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
134 | *rv = gather(args[0], args[1], args[2]); |
135 | }); |
136 | |
137 | TVM_REGISTER_GLOBAL("topi.gather_nd" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
138 | *rv = gather_nd(args[0], args[1]); |
139 | }); |
140 | |
141 | TVM_REGISTER_GLOBAL("topi.unravel_index" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
142 | *rv = unravel_index(args[0], args[1]); |
143 | }); |
144 | |
145 | TVM_REGISTER_GLOBAL("topi.sparse_to_dense" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
146 | *rv = sparse_to_dense(args[0], args[1], args[2], args[3]); |
147 | }); |
148 | |
149 | TVM_REGISTER_GLOBAL("topi.matmul" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
150 | switch (args.size()) { |
151 | case 2: |
152 | *rv = matmul(args[0], args[1]); |
153 | break; |
154 | case 3: |
155 | *rv = matmul(args[0], args[1], args[2]); |
156 | break; |
157 | case 4: |
158 | *rv = matmul(args[0], args[1], args[2], args[3]); |
159 | break; |
160 | default: |
161 | ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments" ; |
162 | } |
163 | }); |
164 | |
165 | TVM_REGISTER_GLOBAL("topi.tensordot" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
166 | if (args.size() == 2) { |
167 | *rv = tensordot(args[0], args[1]); |
168 | } else if (args.size() == 3) { |
169 | *rv = tensordot(args[0], args[1], args[2]); |
170 | } else { |
171 | Array<PrimExpr> axes = args[3]; |
172 | *rv = tensordot(args[0], args[1], args[2], axes); |
173 | } |
174 | }); |
175 | |
176 | TVM_REGISTER_GLOBAL("topi.strided_slice" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
177 | Tensor x = args[0]; |
178 | Array<PrimExpr> begin = args[1]; |
179 | Array<PrimExpr> end = args[2]; |
180 | Array<PrimExpr> strides = args[3]; |
181 | if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { |
182 | Array<Integer> begin_static = args[1]; |
183 | Array<Integer> end_static = args[2]; |
184 | Array<Integer> strides_static = args[3]; |
185 | Array<Integer> axes = args[4]; |
186 | std::string slice_mode = args[5]; |
187 | if (axes.size() > 0) { |
188 | *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); |
189 | } else { |
190 | *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); |
191 | } |
192 | } else { |
193 | *rv = dynamic_strided_slice(x, begin, end, strides); |
194 | } |
195 | }); |
196 | |
197 | TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
198 | te::Tensor begin = args[1]; |
199 | te::Tensor end = args[2]; |
200 | te::Tensor strides = args[3]; |
201 | *rv = dynamic_strided_slice(args[0], begin, end, strides); |
202 | }); |
203 | |
204 | TVM_REGISTER_GLOBAL("topi.one_hot" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
205 | int depth = args[3]; |
206 | int axis = args[4]; |
207 | DataType dtype = args[5]; |
208 | *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); |
209 | }); |
210 | |
211 | TVM_REGISTER_GLOBAL("topi.matrix_set_diag" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
212 | int k1 = args[2]; |
213 | int k2 = args[3]; |
214 | bool super_diag_right_align = args[4]; |
215 | bool sub_diag_right_align = args[5]; |
216 | *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align); |
217 | }); |
218 | |
219 | TVM_REGISTER_GLOBAL("topi.adv_index" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
220 | *rv = adv_index(args[0], args[1]); |
221 | }); |
222 | |
223 | } // namespace topi |
224 | } // namespace tvm |
225 | |