1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Registration of transform operators
22 * \file transform.cc
23 */
24#include <tvm/runtime/packed_func.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/topi/einsum.h>
27#include <tvm/topi/transform.h>
28#include <tvm/topi/utils.h>
29
30namespace tvm {
31namespace topi {
32
33using namespace tvm;
34using namespace tvm::runtime;
35
36TVM_REGISTER_GLOBAL("topi.expand_dims").set_body([](TVMArgs args, TVMRetValue* rv) {
37 *rv = expand_dims(args[0], args[1], args[2]);
38});
39
40TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv) {
41 *rv = transpose(args[0], args[1]);
42});
43
44TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) {
45 // pass empty seq_lengths tensor to reverse_sequence
46 *rv = reverse_sequence(args[0], Tensor(), args[1]);
47});
48
49TVM_REGISTER_GLOBAL("topi.reverse_sequence").set_body([](TVMArgs args, TVMRetValue* rv) {
50 *rv = reverse_sequence(args[0], args[1], args[2], args[3]);
51});
52
53TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) {
54 *rv = reshape(args[0], args[1]);
55});
56
57TVM_REGISTER_GLOBAL("topi.sliding_window").set_body([](TVMArgs args, TVMRetValue* rv) {
58 *rv = sliding_window(args[0], args[1], args[2], args[3]);
59});
60
61TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) {
62 *rv = squeeze(args[0], ArrayOrInt(args[1]));
63});
64
65TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* rv) {
66 *rv = concatenate(args[0], args[1]);
67});
68
69TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) {
70 *rv = stack(args[0], args[1]);
71});
72
73TVM_REGISTER_GLOBAL("topi.shape").set_body([](TVMArgs args, TVMRetValue* rv) {
74 *rv = shape(args[0], args[1]);
75});
76
77TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* rv) {
78 *rv = ndarray_size(args[0], args[1]);
79});
80
81TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
82 if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
83 *rv = split_sections(args[0], args[1], args[2]);
84 } else {
85 *rv = split(args[0], args[1], args[2]);
86 }
87});
88
89TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) {
90 *rv = layout_transform(args[0], args[1], args[2]);
91});
92
93TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
94 if (args.size() == 4) {
95 std::string mode = args[3];
96 int batch_dims = args[2];
97 *rv = take(args[0], args[1], batch_dims, mode);
98 } else {
99 ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments";
100 int batch_dims = args[2];
101 int axis = args[3];
102 std::string mode = args[4];
103 *rv = take(args[0], args[1], batch_dims, axis, mode);
104 }
105});
106
107TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body([](TVMArgs args, TVMRetValue* rv) {
108 double pad_val = args[2];
109 int axis = args[3];
110 *rv = sequence_mask(args[0], args[1], pad_val, axis);
111});
112
113TVM_REGISTER_GLOBAL("topi.where").set_body([](TVMArgs args, TVMRetValue* rv) {
114 *rv = where(args[0], args[1], args[2]);
115});
116
117TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) {
118 *rv = arange(args[0], args[1], args[2], args[3]);
119});
120
121TVM_REGISTER_GLOBAL("topi.meshgrid").set_body([](TVMArgs args, TVMRetValue* rv) {
122 *rv = meshgrid(args[0], args[1]);
123});
124
125TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) {
126 *rv = repeat(args[0], args[1], args[2]);
127});
128
129TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) {
130 *rv = tile(args[0], args[1]);
131});
132
133TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
134 *rv = gather(args[0], args[1], args[2]);
135});
136
137TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
138 *rv = gather_nd(args[0], args[1]);
139});
140
141TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) {
142 *rv = unravel_index(args[0], args[1]);
143});
144
145TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
146 *rv = sparse_to_dense(args[0], args[1], args[2], args[3]);
147});
148
149TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
150 switch (args.size()) {
151 case 2:
152 *rv = matmul(args[0], args[1]);
153 break;
154 case 3:
155 *rv = matmul(args[0], args[1], args[2]);
156 break;
157 case 4:
158 *rv = matmul(args[0], args[1], args[2], args[3]);
159 break;
160 default:
161 ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
162 }
163});
164
165TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) {
166 if (args.size() == 2) {
167 *rv = tensordot(args[0], args[1]);
168 } else if (args.size() == 3) {
169 *rv = tensordot(args[0], args[1], args[2]);
170 } else {
171 Array<PrimExpr> axes = args[3];
172 *rv = tensordot(args[0], args[1], args[2], axes);
173 }
174});
175
176TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
177 Tensor x = args[0];
178 Array<PrimExpr> begin = args[1];
179 Array<PrimExpr> end = args[2];
180 Array<PrimExpr> strides = args[3];
181 if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) {
182 Array<Integer> begin_static = args[1];
183 Array<Integer> end_static = args[2];
184 Array<Integer> strides_static = args[3];
185 Array<Integer> axes = args[4];
186 std::string slice_mode = args[5];
187 if (axes.size() > 0) {
188 *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode);
189 } else {
190 *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
191 }
192 } else {
193 *rv = dynamic_strided_slice(x, begin, end, strides);
194 }
195});
196
197TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
198 te::Tensor begin = args[1];
199 te::Tensor end = args[2];
200 te::Tensor strides = args[3];
201 *rv = dynamic_strided_slice(args[0], begin, end, strides);
202});
203
204TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
205 int depth = args[3];
206 int axis = args[4];
207 DataType dtype = args[5];
208 *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
209});
210
211TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
212 int k1 = args[2];
213 int k2 = args[3];
214 bool super_diag_right_align = args[4];
215 bool sub_diag_right_align = args[5];
216 *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align);
217});
218
219TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
220 *rv = adv_index(args[0], args[1]);
221});
222
223} // namespace topi
224} // namespace tvm
225