1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file topi/transform.h |
22 | * \brief Transform op constructors |
23 | */ |
24 | #ifndef TVM_TOPI_TRANSFORM_H_ |
25 | #define TVM_TOPI_TRANSFORM_H_ |
26 | |
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/strided_slice.h>
#include <tvm/topi/detail/tensor_utils.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>
44 | |
45 | namespace tvm { |
46 | namespace topi { |
47 | |
48 | using namespace tvm::te; |
49 | using namespace topi::detail; |
50 | |
51 | /*! |
52 | * \brief Creates an operation to slide a window over the input x. |
53 | * |
54 | * \param x The input tensor. |
55 | * \param axis What axis the window begins sliding over. Window will be slid |
56 | * over this axis and all following axes. The axis value determines the window |
57 | * shape (and thus, the number of strides): window shape and strides must both |
58 | * be of length `data.ndim-axis`. |
59 | * \param window_shape The window shape to form over the input. Window shape |
60 | * must be of length `data.ndim-axis`. |
61 | * \param strides How to stride the window along each dimension. Strides must be |
62 | * of length `data.ndim-axis`. |
63 | * \param name The name of the operation |
64 | * \param tag The tag to mark the operation |
65 | * |
66 | * \return A Tensor whose op member is the sliding_window operation |
67 | */ |
68 | inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape, |
69 | Array<Integer> strides, std::string name = "T_sliding_window" , |
70 | std::string tag = "" ) { |
71 | CHECK_GE(axis, 0); |
72 | auto _axis = size_t(axis); |
73 | CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x." ; |
74 | CHECK_EQ(x->shape.size() - _axis, window_shape.size()) |
75 | << "There must be a window shape for every dimension of x " |
76 | << "over which we are sliding the window." ; |
77 | CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length." ; |
78 | |
79 | // Compute the new shape. |
80 | Array<PrimExpr> new_shape; |
81 | // Dimensions up until `axis` remain the same. |
82 | for (size_t i = 0; i < _axis; ++i) { |
83 | new_shape.push_back(x->shape[i]); |
84 | } |
85 | |
86 | // New dimensions which result from sliding the window in each dimension. One new dimension per |
87 | // window dimension. |
88 | for (size_t i = 0; i < window_shape.size(); ++i) { |
89 | // Length of the shape along this dimension. |
90 | auto dim_len = x->shape[_axis + i]; |
91 | // Length of the window along this dimension. |
92 | auto window_len = window_shape[i]; |
93 | // Strides along this dimension. |
94 | auto stride = strides[i]; |
95 | |
96 | new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride)); |
97 | } |
98 | |
99 | // Dimensions comprising the window. |
100 | for (size_t i = 0; i < window_shape.size(); ++i) { |
101 | new_shape.push_back(window_shape[i]); |
102 | } |
103 | |
104 | ICHECK(new_shape.size() == _axis + 2 * window_shape.size()); |
105 | |
106 | return compute( |
107 | new_shape, |
108 | [&](const Array<Var>& indices) { |
109 | // The index at which to index the old tensor x. |
110 | Array<PrimExpr> idx; |
111 | |
112 | // Dimensions up until `axis` remain the same. |
113 | for (size_t i = 0; i < _axis; ++i) { |
114 | idx.push_back(indices[i]); |
115 | } |
116 | |
117 | for (size_t i = 0; i < window_shape.size(); ++i) { |
118 | // Which window in this dimension we are indexing. |
119 | auto window_idx = indices[_axis + i]; |
120 | // Which index within the window we are indexing. |
121 | auto idx_within_window = indices[_axis + window_shape.size() + i]; |
122 | // Stride value for this dimension. |
123 | auto stride = strides[i]; |
124 | |
125 | idx.push_back(window_idx * stride + idx_within_window); |
126 | } |
127 | |
128 | ICHECK(idx.size() == x->shape.size()); |
129 | |
130 | return x(idx); |
131 | }, |
132 | name, tag); |
133 | } |
134 | |
135 | /*! |
136 | * \brief Creates an operation to insert new dimensions of length 1 |
137 | * |
138 | * \param x The input tensor |
139 | * \param axis The index of the first new dimension (allows negative |
140 | * indices as offsets from the last dimension) |
141 | * \param num_newaxis The number of new dimensions to insert |
142 | * \param name The name of the operation |
143 | * \param tag The tag to mark the operation |
144 | * |
145 | * \return A Tensor whose op member is the dim expansion operation |
146 | */ |
147 | inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, |
148 | std::string name = "T_expand_dims" , std::string tag = kBroadcast) { |
149 | int ndim = static_cast<int>(x->shape.size()); |
150 | ICHECK(-ndim - 1 <= axis && axis <= ndim) |
151 | << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" |
152 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
153 | ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" |
154 | << ", but got num_newaxis = " << num_newaxis; |
155 | if (axis < 0) { |
156 | // Calculate offset from last dimension |
157 | axis = ndim + axis + 1; |
158 | } |
159 | Array<PrimExpr> new_shape; |
160 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
161 | new_shape.push_back(x->shape[i]); |
162 | } |
163 | for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) { |
164 | new_shape.push_back(1); |
165 | } |
166 | for (size_t i = axis; i < x->shape.size(); ++i) { |
167 | new_shape.push_back(x->shape[i]); |
168 | } |
169 | |
170 | return compute( |
171 | new_shape, |
172 | [&](const Array<Var>& indices) { |
173 | Array<PrimExpr> idx; |
174 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
175 | idx.push_back(indices[i]); |
176 | } |
177 | for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { |
178 | idx.push_back(indices[i]); |
179 | } |
180 | return x(idx); |
181 | }, |
182 | name, tag); |
183 | } |
184 | |
185 | /*! |
186 | * \brief Permute the dimensions of an array |
187 | * |
188 | * \param x The input tensor |
189 | * \param axes The indices of the permutation. If this is empty, |
190 | * the dimensions will be reversed. |
191 | * \param name The name of the operation |
192 | * \param tag The tag to mark the operation |
193 | * |
194 | * \return A Tensor whose op member is the transpose operation |
195 | */ |
196 | inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose" , |
197 | std::string tag = kInjective) { |
198 | if (!axes.defined() || axes.size() == 0) { |
199 | axes = Array<Integer>(); |
200 | for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) { |
201 | axes.push_back(i); |
202 | } |
203 | } |
204 | |
205 | Array<PrimExpr> new_shape; |
206 | for (size_t i = 0; i < axes.size(); ++i) { |
207 | int axis = static_cast<int>(axes[i]->value); |
208 | int new_axis = axis; |
209 | if (axis < 0) { |
210 | new_axis = static_cast<int>(x->shape.size()) + axis; |
211 | axes.Set(i, new_axis); |
212 | } |
213 | ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size()))) |
214 | << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size()) |
215 | << "-dimensional input tensor" ; |
216 | |
217 | for (size_t j = 0; j < axes.size(); ++j) { |
218 | if (i != j) { |
219 | ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose" ; |
220 | } |
221 | } |
222 | new_shape.push_back(x->shape[new_axis]); |
223 | } |
224 | |
225 | return compute( |
226 | new_shape, |
227 | [&](const Array<Var>& indices) { |
228 | std::vector<PrimExpr> idx; |
229 | for (size_t i = 0; i < axes.size(); ++i) { |
230 | idx.push_back(1); |
231 | } |
232 | for (size_t i = 0; i < axes.size(); ++i) { |
233 | int axis = static_cast<int>(axes[i]->value); |
234 | idx[axis] = indices[i]; |
235 | } |
236 | return x(idx); |
237 | }, |
238 | name, tag); |
239 | } |
240 | |
241 | /*! |
242 | * \brief Reverse the tensor for variable length slices. |
243 | * Input is first sliced along batch axis and then elements are reversed along seq axis. |
244 | * |
245 | * \param x The input tensor |
246 | * \param seq_lengths A 1D Tensor with length x.dims[batch_axis]. Optional Tensor() can be passed. |
247 | * If not defined batch axis is ignored and tensor is reversed along seq_axis. |
 * \param seq_axis The axis along which the elements will be reversed
249 | * \param batch_axis The axis along which the tensor will be sliced |
250 | * \param name The name of the operation |
251 | * \param tag The tag to mark the operation |
252 | * |
253 | * \return A Tensor whose op member is the reverse_sequence operation |
254 | */ |
255 | inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1, |
256 | int batch_axis = 0, std::string name = "T_reverse_sequence" , |
257 | std::string tag = kInjective) { |
258 | size_t src_tensor_dim = x->shape.size(); |
259 | int seq_axis_inp = seq_axis; |
260 | |
261 | if (seq_lengths.defined()) { |
262 | size_t seq_lengths_dim = seq_lengths->shape.size(); |
263 | int batch_axis_inp = batch_axis; |
264 | if (batch_axis < 0) { |
265 | batch_axis = static_cast<int>(x->shape.size()) + batch_axis; |
266 | } |
267 | |
268 | ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector" ; |
269 | |
270 | ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis])) |
271 | << "For reverse_sequnece seq_lengths size should match with dimension of batch axis" |
272 | << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis]) |
273 | << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]); |
274 | |
275 | ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size()))) |
276 | << "batch_axis=" << batch_axis_inp << " is invalid for the " |
277 | << static_cast<int>(x->shape.size()) << "-dimensional input tensor" ; |
278 | } |
279 | |
280 | if (seq_axis < 0) { |
281 | seq_axis = static_cast<int>(x->shape.size()) + seq_axis; |
282 | } |
283 | ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size()))) |
284 | << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size()) |
285 | << "-dimensional input tensor" ; |
286 | |
287 | auto func = [&](const Array<Var>& indices) { |
288 | Array<PrimExpr> real_indices; |
289 | for (size_t i = 0; i < src_tensor_dim; ++i) { |
290 | if (i == static_cast<size_t>(seq_axis)) { |
291 | if (seq_lengths.defined()) { |
292 | auto len = seq_lengths(indices[batch_axis]); |
293 | auto idx = if_then_else( |
294 | len <= 1 || len <= indices[i], indices[i], |
295 | if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i])); |
296 | real_indices.push_back(idx); |
297 | } else { |
298 | real_indices.push_back(x->shape[i] - 1 - indices[i]); |
299 | } |
300 | } else { |
301 | real_indices.push_back(indices[i]); |
302 | } |
303 | } |
304 | return x(real_indices); |
305 | }; |
306 | |
307 | return compute(x->shape, func, name, tag); |
308 | } |
309 | |
310 | /*! |
311 | * \brief Reshape a tensor |
312 | * |
313 | * \param x The input tensor |
314 | * \param newshape The new shape |
315 | * \param name The name of the operation |
316 | * \param tag The tag to mark the operation |
317 | * |
318 | * \return A Tensor whose op member is the reshape operation |
319 | */ |
320 | inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape" , |
321 | std::string tag = kInjective) { |
322 | auto x_shape = x->shape; |
323 | Array<PrimExpr> target_shape; |
324 | |
325 | for (const auto& ele : newshape) { |
326 | if (ele.as<IntImmNode>()) { |
327 | target_shape.push_back(cast(DataType::Int(32), ele)); |
328 | } else { |
329 | target_shape.push_back(ele); |
330 | } |
331 | } |
332 | |
333 | // If either the input shape or the target shape contains a zero, return an empty tensor. |
334 | if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) { |
335 | return compute( |
336 | target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag); |
337 | } else { |
338 | return compute( |
339 | target_shape, |
340 | [&](const Array<Var>& indices) { |
341 | return x(UnravelIndex( |
342 | RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape)); |
343 | }, |
344 | name, tag); |
345 | } |
346 | } |
347 | |
348 | /*! |
349 | * \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays |
350 | * |
351 | * \param x The input tensor having indices. |
352 | * \param shape The shape tensor |
353 | * \param name The name of the operation |
354 | * \param tag The tag to mark the operation |
355 | * |
356 | * \return A Tensor of coordinate arrays. |
357 | */ |
358 | |
inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
                            std::string tag = kInjective) {
  auto x_shape = x->shape;
  auto shape_shape = shape->shape;

  // Output is (ndim, [num_indices]): one coordinate row per dimension of
  // `shape`, and one column per flat index in x (x may also be a scalar).
  Array<PrimExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (x_shape.size() != 0) {
    oshape.push_back(x_shape[0]);
  }

  auto func = [&](const Array<Var>& indices) {
    auto i = indices[0];  // which coordinate (dimension) is being produced
    std::vector<PrimExpr> indices_divs;
    PrimExpr ret = 0;
    PrimExpr cur_val = 0;
    PrimExpr index_val = 0;

    // The flat index to decompose: scalar x has no index dimension.
    if (x_shape.size() != 0) {
      index_val = x[indices[1]];
    } else {
      index_val = x();
    }
    // Walk dimensions from last to first, repeatedly taking mod/div by each
    // extent; select the remainder belonging to dimension `i` via
    // if_then_else, since `i` is symbolic at this point.
    indices_divs.push_back(index_val);
    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
      cur_val = indexdiv(indices_divs.back(), shape[v]);
      indices_divs.push_back(cur_val);
    }
    return ret;
  };

  return compute(oshape, func, name, tag);
}
393 | |
394 | /*! |
395 | * \brief Remove size 1 dimensions from the shape of a tensor. |
396 | * The removed dimensions must have a constant size of 1. |
397 | * |
398 | * \param x The input tensor |
399 | * \param axis Indices of the dimensions to remove. If this is None, |
400 | * all entries with a constant size of 1 will be removed. |
401 | * \param atleast1d Whether the output need to be atleast1d. |
402 | * \param name The name of the operation |
403 | * \param tag The tag to mark the operation |
404 | * |
405 | * \return A Tensor whose op member is the squeeze operation |
406 | */ |
407 | inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false, |
408 | std::string name = "T_squeeze" , std::string tag = kInjective) { |
409 | auto ndim = x->shape.size(); |
410 | std::vector<int> axis_val; |
411 | if (!axis.defined()) { |
412 | for (size_t i = 0; i < ndim; ++i) { |
413 | if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) { |
414 | axis_val.push_back(static_cast<int>(i)); |
415 | } |
416 | } |
417 | } else { |
418 | for (size_t i = 0; i < axis.size(); ++i) { |
419 | int64_t val = axis[i]->value; |
420 | if (val < 0) { |
421 | val += static_cast<int>(x->shape.size()); |
422 | } |
423 | if (IsConstInt(x->shape[val])) { |
424 | ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1" ; |
425 | } |
426 | axis_val.push_back(val); |
427 | } |
428 | } |
429 | |
430 | std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end()); |
431 | |
432 | Array<PrimExpr> out_shape; |
433 | for (size_t i = 0; i < ndim; ++i) { |
434 | if (axis_set.count(static_cast<int>(i)) == 0) { |
435 | out_shape.push_back(x->shape[i]); |
436 | } |
437 | } |
438 | if (out_shape.size() == 0 && atleast1d) { |
439 | out_shape.push_back(1); |
440 | } |
441 | |
442 | return compute( |
443 | out_shape, |
444 | [&](const Array<Var>& indices) { |
445 | Array<PrimExpr> real_indices; |
446 | int flag = 0; |
447 | for (size_t i = 0; i < ndim; ++i) { |
448 | if (axis_set.count(static_cast<int>(i)) == 0) { |
449 | real_indices.push_back(indices[i - flag]); |
450 | } else { |
451 | real_indices.push_back(0); |
452 | flag += 1; |
453 | } |
454 | } |
455 | return x(real_indices); |
456 | }, |
457 | name, tag); |
458 | } |
459 | |
460 | /*! |
461 | * \brief Join a sequence of tensors along an existing axis |
462 | * |
463 | * \param inputs The input tensors |
464 | * \param axis The axis along which the tensors will be joined |
465 | * \param name The name of the operation |
466 | * \param tag The tag to mark the operation |
467 | * |
468 | * \return A Tensor whose op member is the concatenate operation |
469 | */ |
470 | inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat" , |
471 | std::string tag = kInjective) { |
472 | int ndim = static_cast<int>(inputs[0]->shape.size()); |
473 | ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" |
474 | << ", but got axis = " << axis << ", and ndim = " << ndim; |
475 | if (axis < 0) { |
476 | axis += ndim; |
477 | } |
478 | ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds" ; |
479 | |
480 | Array<PrimExpr> axis_sizes; |
481 | for (auto t : inputs) { |
482 | axis_sizes.push_back(t->shape[axis]); |
483 | } |
484 | arith::Analyzer analyzer; |
485 | PrimExpr join_size = axis_sizes[0]; |
486 | for (size_t i = 1; i < axis_sizes.size(); ++i) { |
487 | join_size += axis_sizes[i]; |
488 | } |
489 | join_size = analyzer.Simplify(join_size); |
490 | Array<PrimExpr> out_shape; |
491 | for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { |
492 | out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]); |
493 | } |
494 | |
495 | return compute( |
496 | out_shape, |
497 | [&](const Array<Var>& indices) { |
498 | auto ret = inputs[0](indices); |
499 | auto ind = indices[axis]; |
500 | for (size_t i = 0; i < inputs.size() - 1; ++i) { |
501 | ind -= axis_sizes[i]; |
502 | |
503 | Array<PrimExpr> idx; |
504 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
505 | idx.push_back(indices[i]); |
506 | } |
507 | idx.push_back(ind); |
508 | for (size_t i = axis + 1; i < indices.size(); ++i) { |
509 | idx.push_back(indices[i]); |
510 | } |
511 | |
512 | ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret); |
513 | } |
514 | return ret; |
515 | }, |
516 | name, tag); |
517 | } |
518 | |
519 | /*! |
520 | * \brief Join a sequence of tensors along a new axis. |
521 | * |
522 | * \param inputs The input tensors |
523 | * \param axis The axis along which the tensors will be stacked |
524 | * \param name The name of the operation |
525 | * \param tag The tag to mark the operation |
526 | * |
527 | * \return A Tensor whose op member is the stack operation |
528 | */ |
529 | inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack" , |
530 | std::string tag = kInjective) { |
531 | int ndim = static_cast<int>(inputs[0]->shape.size()); |
532 | ICHECK(-ndim - 1 <= axis && axis <= ndim) |
533 | << "stack only accepts `axis` in [-ndim, ndim)" |
534 | << ", but got axis = " << axis << ", and ndim = " << ndim; |
535 | if (axis < 0) { |
536 | axis += ndim + 1; |
537 | } |
538 | ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds" ; |
539 | |
540 | const int stack_size = static_cast<int>(inputs.size()); |
541 | Array<PrimExpr> out_shape; |
542 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); |
543 | out_shape.push_back(stack_size); |
544 | for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i) |
545 | out_shape.push_back(inputs[0]->shape[i]); |
546 | |
547 | return compute( |
548 | out_shape, |
549 | [&](const Array<Var>& indices) { |
550 | Array<PrimExpr> idx; |
551 | for (size_t i = 0; i < indices.size(); ++i) |
552 | if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]); |
553 | auto ind = indices[axis]; |
554 | auto ret = inputs[0](idx); |
555 | for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) { |
556 | ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret); |
557 | } |
558 | return ret; |
559 | }, |
560 | name, tag); |
561 | } |
562 | |
563 | /*! |
564 | * \brief Split a tensor into multiple sub-tensors |
565 | * |
566 | * \param x The input tensor |
567 | * \param split_indices The indices to split the input at. This must be in ascending |
568 | * order. |
569 | * \param axis The axis to split along. |
570 | * \param name The name of the operation |
571 | * \param tag The tag to mark the operation |
572 | * |
573 | * \return A Tensor whose op member is the split operation |
574 | */ |
575 | inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis, |
576 | std::string name = "T_split" , std::string tag = kInjective) { |
577 | if (axis < 0) { |
578 | axis += static_cast<int>(x->shape.size()); |
579 | } |
580 | ICHECK_LT(axis, x->shape.size()) << "axis out of bounds" ; |
581 | |
582 | auto src_axis_size = x->shape[axis]; |
583 | std::vector<PrimExpr> begin_ids; |
584 | begin_ids.push_back(0); |
585 | |
586 | for (auto idx : split_indices) { |
587 | auto idx_node = idx.as<IntImmNode>(); |
588 | auto back_node = begin_ids.back().as<IntImmNode>(); |
589 | if (idx_node && back_node) { |
590 | ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted" ; |
591 | } |
592 | begin_ids.push_back(idx); |
593 | } |
594 | |
595 | Array<Array<PrimExpr>> out_shapes; |
596 | for (size_t i = 0; i < begin_ids.size(); ++i) { |
597 | PrimExpr out_axis_size; |
598 | if (i == begin_ids.size() - 1) { |
599 | out_axis_size = src_axis_size - begin_ids[i]; |
600 | } else { |
601 | out_axis_size = begin_ids[i + 1] - begin_ids[i]; |
602 | } |
603 | |
604 | Array<PrimExpr> shape; |
605 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
606 | shape.push_back(x->shape[i]); |
607 | } |
608 | shape.push_back(out_axis_size); |
609 | for (size_t i = axis + 1; i < x->shape.size(); ++i) { |
610 | shape.push_back(x->shape[i]); |
611 | } |
612 | |
613 | out_shapes.push_back(shape); |
614 | } |
615 | |
616 | Array<Tensor> result; |
617 | for (size_t i = 0; i < begin_ids.size(); ++i) { |
618 | result.push_back(compute( |
619 | out_shapes[i], |
620 | [&](const Array<Var>& indices) { |
621 | auto begin = begin_ids[i]; |
622 | Array<PrimExpr> real_indices; |
623 | for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { |
624 | real_indices.push_back(indices[j]); |
625 | } |
626 | real_indices.push_back(indices[axis] + begin); |
627 | for (size_t j = axis + 1; j < indices.size(); ++j) { |
628 | real_indices.push_back(indices[j]); |
629 | } |
630 | |
631 | return x(real_indices); |
632 | }, |
633 | name, tag)); |
634 | } |
635 | |
636 | return result; |
637 | } |
638 | |
639 | /*! |
640 | * \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic |
641 | * |
642 | * \param x The input tensor |
643 | * \param begin The indices to begin with in the slicing |
644 | * \param end Indices indicating end of the slice |
645 | * \param strides Specifies the stride values, it can be negative |
646 | * in that case, the input tensor will be reversed in that particular axis |
647 | * \param name The name of the operation |
648 | * \param tag The tag to mark the operation |
649 | * |
650 | * \return A Tensor whose op member is the dynamic_strided_slice operation |
651 | */ |
inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
                                    std::string name = "T_dynamic_strided_slice",
                                    std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  // begin/end/strides must agree in length and cover at most the leading
  // dimensions; trailing dimensions are passed through unsliced.
  ICHECK_LE(begin.size(), src_tensor_dim);
  ICHECK_LE(end.size(), src_tensor_dim);
  ICHECK_LE(strides.size(), src_tensor_dim);
  ICHECK_EQ(begin.size(), end.size());
  ICHECK_EQ(begin.size(), strides.size());

  const size_t num_slice_axes = begin.size();
  Array<PrimExpr> out_shape;

  for (size_t i = 0; i < num_slice_axes; ++i) {
    // Extent of the slice: (end - begin) / stride, rounded via indexdiv.
    auto d = indexdiv(end[i] - begin[i], strides[i]);
    if (d->IsInstance<tvm::IntImmNode>()) {
      // Preserve static dimension if possible
      out_shape.push_back(d);
    } else {
      // Extent not statically known: emit a fresh symbolic dimension.
      out_shape.push_back(tvm::tir::Var("dim"));
    }
  }

  // Unsliced trailing dimensions keep their original extents.
  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
    out_shape.push_back(x->shape[i]);
  }

  return te::compute(
      out_shape,
      [&](const Array<tvm::tir::Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < num_slice_axes; ++i) {
          // begin is clamped to the last valid index so a dynamic begin past
          // the end of the axis cannot produce an out-of-range read.
          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
        }
        // keep input dim
        for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
          real_indices.push_back(indices[i]);
        }
        return x(real_indices);
      },
      name, tag);
}
695 | |
696 | /*! |
697 | * \brief strided_slice of a tensor with dynamic begin/end/stride |
698 | * |
699 | * \param x The input tensor |
700 | * \param begin The indices to begin with in the slicing |
701 | * \param end Indices indicating end of the slice |
702 | * \param strides Specifies the stride values, it can be negative |
703 | * in that case, the input tensor will be reversed in that particular axis |
704 | * \param name The name of the operation |
705 | * \param tag The tag to mark the operation |
706 | * |
707 | * \return A Tensor whose op member is the dynamic_strided_slice operation |
708 | */ |
709 | inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, |
710 | const te::Tensor& end, const te::Tensor& strides, |
711 | std::string name = "T_strided_slice_dynamic" , |
712 | std::string tag = topi::kInjective) { |
713 | const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value; |
714 | ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes); |
715 | ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes); |
716 | |
717 | Array<PrimExpr> begin_expr, end_expr, strides_expr; |
718 | for (int64_t i = 0; i < num_dynamic_axes; ++i) { |
719 | auto i64_ind = IntImm(DataType::Int(64), i); |
720 | begin_expr.push_back(begin(i64_ind)); |
721 | end_expr.push_back(end(i64_ind)); |
722 | strides_expr.push_back(strides(i64_ind)); |
723 | } |
724 | return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); |
725 | } |
726 | |
727 | /*! |
 * \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
729 | * |
730 | * \param ishape The input tensor shape |
731 | * \param begin The indices to begin with in the slicing |
732 | * \param end Indices indicating end of the slice |
733 | * \param strides Specifies the stride values, it can be negative |
734 | * in that case, the input tensor will be reversed in that particular axis |
735 | * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end, |
736 | * strides, and axes argument must be equal |
737 | * \param slice_mode Specifies the slice mode |
738 | * |
739 | * \return The output shape of strided_slice using the arguments above |
740 | */ |
inline Array<PrimExpr> StridedSliceOutputShape(
    const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
    const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
  // begin/end/strides are given per entry of `axes`, so all four must match.
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
  // Lower the Integer arrays to plain int64 vectors (interpretation of `end`
  // depends on slice_mode — see ConvertToVec in detail/strided_slice.h).
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
  // Normalize begin offsets (e.g. negative indices) before computing extents.
  // NOTE(review): begin[0]->dtype assumes begin is non-empty — presumably
  // callers guarantee at least one sliced axis; confirm against call sites.
  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
                                                           begin[0]->dtype, slice_mode);
  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
                                 begin_canonicalized, true);
}
752 | |
753 | /*! |
754 | * \brief strided_slice of a tensor |
755 | * |
756 | * \param x The input tensor |
757 | * \param begin The indices to begin with in the slicing |
758 | * \param end Indices indicating end of the slice |
759 | * \param strides Specifies the stride values, it can be negative |
760 | * in that case, the input tensor will be reversed in that particular axis |
761 | * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end, |
762 | * strides, and axes argument must be equal |
763 | * \param slice_mode Specifies the slice mode |
764 | * \param name The name of the operation |
765 | * \param tag The tag to mark the operation |
766 | * |
 * \return A Tensor whose op member is the strided_slice operation
768 | */ |
inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
                                      const Array<Integer>& end, const Array<Integer>& strides,
                                      const Array<Integer>& axes, std::string slice_mode = "end",
                                      std::string name = "T_strided_slice_with_axes",
                                      std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  // Each sliced axis needs its own begin/end/stride entry.
  ICHECK(axes.size() <= src_tensor_dim);
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

  // Lower slice parameters to int64 vectors (slice_mode governs how `end`
  // is interpreted — see detail/strided_slice.h helpers).
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);

  // Canonicalize begins (handles negative offsets) and derive the output shape.
  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
                                                  begin[0]->dtype, slice_mode);
  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
                                           slice_mode, begin_expr);

  return te::compute(
      out_shape,
      [&](const Array<tir::Var>& indices) {
        // Start with the identity mapping, then overwrite the sliced axes
        // with index * stride + begin.
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
        for (size_t i = 0; i < axes.size(); ++i) {
          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
          PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
          real_indices.Set(axes[i].IntValue(), ind);
        }
        return x(real_indices);
      },
      name, tag);
}
800 | |
801 | /*! |
802 | * \brief strided_slice of a tensor |
803 | * |
804 | * \param x The input tensor |
805 | * \param begin The indices to begin with in the slicing |
806 | * \param end Indices indicating end of the slice |
807 | * \param strides Specifies the stride values, it can be negative |
808 | * in that case, the input tensor will be reversed in that particular axis |
809 | * \param slice_mode Specifies the slice mode |
810 | * \param name The name of the operation |
811 | * \param tag The tag to mark the operation |
812 | * |
813 | * \return A Tensor whose op member is the strided_slice operation |
814 | */ |
815 | inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end, |
816 | const Array<Integer>& strides, std::string slice_mode = "end" , |
817 | std::string name = "T_strided_slice" , std::string tag = kInjective) { |
818 | size_t src_tensor_dim = static_cast<size_t>(x->shape.size()); |
819 | Array<Integer> axes; |
820 | for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); |
821 | Array<Integer> begin_full(begin); |
822 | Array<Integer> end_full(end); |
823 | Array<Integer> strides_full(strides); |
824 | |
825 | const IntImm one = IntImm(DataType::Int(64), 1); |
826 | const IntImm zero = IntImm(DataType::Int(64), 0); |
827 | const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max()); |
828 | |
829 | for (size_t i = strides.size(); i < src_tensor_dim; ++i) { |
830 | strides_full.push_back(one); |
831 | } |
832 | for (size_t i = begin.size(); i < src_tensor_dim; ++i) { |
833 | begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range); |
834 | } |
835 | for (size_t i = end.size(); i < src_tensor_dim; ++i) { |
836 | end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range); |
837 | } |
838 | |
839 | return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, |
840 | tag); |
841 | } |
842 | |
843 | /*! |
844 | * \brief Split a tensor into a number of sub-tensors |
845 | * |
846 | * \param x The input tensor |
847 | * \param num_sections The number of sections to split the tensor into. |
848 | * this must be an integer factor of the size of the axis being split. |
849 | * \param axis The axis to split along. |
850 | * \param name The name of the operation |
851 | * \param tag The tag to mark the operation |
852 | * |
853 | * \return A Tensor whose op member is the split operation |
854 | */ |
855 | inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis, |
856 | std::string name = "T_split_sections" , |
857 | std::string tag = kInjective) { |
858 | if (axis < 0) { |
859 | axis += static_cast<int>(x->shape.size()); |
860 | } |
861 | ICHECK_LT(axis, x->shape.size()) << "axis out of bounds" ; |
862 | |
863 | auto src_axis_size = x->shape[axis]; |
864 | |
865 | ICHECK_GT(num_sections, 0) << "Slice count must be > 0" ; |
866 | |
867 | if (auto node = src_axis_size.as<IntImmNode>()) { |
868 | ICHECK_EQ(node->value % num_sections, 0) |
869 | << "num_sections must be an integer factor of the size of axis " << axis << " (" |
870 | << node->value << ")" ; |
871 | } |
872 | |
873 | Array<PrimExpr> split_indices; |
874 | auto seg_size = indexdiv(src_axis_size, num_sections); |
875 | for (int i = 0; i < num_sections; ++i) { |
876 | // region at index 0 is added by split() |
877 | if (i != 0) { |
878 | split_indices.push_back(seg_size * i); |
879 | } |
880 | } |
881 | |
882 | return split(x, split_indices, axis, name, tag); |
883 | } |
884 | |
885 | /*! |
 * \brief Take elements from a flattened input array when axis is None.
887 | * |
888 | * \param a The source array. |
889 | * \param indices The indices of the values to extract. |
890 | * \param batch_dims The number of batch dimensions. |
 * \param mode The mode for handling out of bound indices.
 * \param name The name of the operation.
894 | * \param tag The tag to mark the operation. |
895 | * |
896 | * \return A Tensor whose op member is the take operation |
897 | */ |
898 | inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, |
899 | std::string mode = "clip" , std::string name = "T_take" , |
900 | std::string tag = kInjective) { |
901 | Array<PrimExpr> a_shape = a->shape; |
902 | Array<PrimExpr> out_shape = indices->shape; |
903 | PrimExpr a_size = 1; |
904 | for (size_t i = 0; i < a_shape.size(); ++i) { |
905 | a_size = a_size * a_shape[i]; |
906 | } |
907 | |
908 | if (mode == "clip" ) { |
909 | return compute( |
910 | out_shape, |
911 | [&](const Array<Var>& out_index) { |
912 | auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); |
913 | return a(UnravelIndex(idx, a_shape)); |
914 | }, |
915 | name, tag); |
916 | } else if (mode == "fast" ) { |
917 | LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " |
918 | "Make sure input indices are in bound" ; |
919 | return compute( |
920 | out_shape, |
921 | [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); }, |
922 | name, tag); |
923 | } else { // mode == "wrap" |
924 | return compute( |
925 | out_shape, |
926 | [&](const Array<Var>& out_index) { |
927 | auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size); |
928 | return a(UnravelIndex(idx, a_shape)); |
929 | }, |
930 | name, tag); |
931 | } |
932 | } |
933 | |
934 | /*! |
935 | * \brief Mask the out-of-boundary elements of each sequence. |
936 | * |
937 | * \param data The source array. |
938 | * \param valid_length The real length of each sequence. |
939 | * \param mask_value The masking value. |
940 | * \param axis The axis of the temporal dimension of the sequence |
941 | * \param name The name of the operation. |
942 | * \param tag The tag to mark the operation. |
943 | * |
944 | * \return A Tensor whose op member is the sequence_mask operation |
945 | */ |
946 | inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value, |
947 | int axis, std::string name = "T_sequence_mask" , |
948 | std::string tag = kInjective) { |
949 | ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1" ; |
950 | ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)." ; |
951 | auto length_dim = data->shape[axis]; |
952 | auto batch_dim = data->shape[1 - axis]; |
953 | Array<PrimExpr> out_shape = data->shape; |
954 | Tensor out = compute( |
955 | out_shape, |
956 | [&](const Array<Var>& out_index) { |
957 | Array<PrimExpr> len_index; |
958 | auto tid = out_index[axis]; |
959 | auto bid = out_index[1 - axis]; |
960 | len_index.push_back(bid); |
961 | PrimExpr ret = |
962 | tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), |
963 | tvm::tir::make_const(data->dtype, mask_value), data(out_index)); |
964 | return ret; |
965 | }, |
966 | name, tag); |
967 | return out; |
968 | } |
969 | |
970 | /*! |
971 | * \brief Take elements from an array along an axis. |
972 | * |
973 | * \param a The source array. |
974 | * \param indices The indices of the values to extract. |
975 | * \param batch_dims The number of batch dimensions. By default is 0. |
976 | * \param axis The axis over which to select values. By default, |
977 | * the flattened input array is used. |
978 | * \param mode The mode for handling out of bound indices. |
979 | * \param name The name of the operation. |
980 | * \param tag The tag to mark the operation. |
981 | * |
982 | * \return A Tensor whose op member is the take operation |
983 | */ |
984 | inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis, |
985 | std::string mode = "clip" , std::string name = "T_take" , |
986 | std::string tag = kInjective) { |
987 | if (axis < 0) { |
988 | axis += static_cast<int>(a->shape.size()); |
989 | } |
990 | ICHECK_GE(axis, 0) << "axis out of bounds" ; |
991 | ICHECK_LT(axis, a->shape.size()) << "axis out of bounds" ; |
992 | auto axis_dim = a->shape[axis]; |
993 | int indices_len = static_cast<int>(indices->shape.size()); |
994 | |
995 | int batch_dims_ = batch_dims; |
996 | if (batch_dims_ != 0) { |
997 | ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds" ; |
998 | ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds" ; |
999 | |
1000 | if (batch_dims_ < 0) { |
1001 | batch_dims_ = indices->shape.size() + batch_dims_; |
1002 | } |
1003 | |
1004 | ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds" ; |
1005 | ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis" ; |
1006 | for (int i = 0; i < batch_dims_; ++i) { |
1007 | auto addr1 = a->shape[i]; |
1008 | auto addr2 = indices->shape[i]; |
1009 | auto v1 = static_cast<IntImm*>(&addr1)->get()->value; |
1010 | auto v2 = static_cast<IntImm*>(&addr2)->get()->value; |
1011 | ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]" ; |
1012 | } |
1013 | } |
1014 | |
1015 | // The result shape is a.shape[:axis] + indices.shape[batch_dims:] + |
1016 | // a.shape[axis + 1:]. |
1017 | |
1018 | Array<PrimExpr> out_shape; |
1019 | for (int i = 0; i < batch_dims_; ++i) { |
1020 | out_shape.push_back(a->shape[i]); |
1021 | } |
1022 | for (int i = batch_dims_; i < axis; ++i) { |
1023 | out_shape.push_back(a->shape[i]); |
1024 | } |
1025 | for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) { |
1026 | out_shape.push_back(indices->shape[i]); |
1027 | } |
1028 | for (size_t i = axis + 1; i < a->shape.size(); ++i) { |
1029 | out_shape.push_back(a->shape[i]); |
1030 | } |
1031 | |
1032 | if (mode == "clip" ) { |
1033 | if (batch_dims_ == 0) { |
1034 | return compute( |
1035 | out_shape, |
1036 | [&](const Array<Var>& out_index) { |
1037 | Array<PrimExpr> indices_position; |
1038 | for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) { |
1039 | indices_position.push_back(out_index[j]); |
1040 | } |
1041 | Array<PrimExpr> real_indices; |
1042 | for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { |
1043 | real_indices.push_back(out_index[j]); |
1044 | } |
1045 | auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1); |
1046 | real_indices.push_back(idx); |
1047 | for (size_t j = axis + indices_len; j < out_index.size(); ++j) { |
1048 | real_indices.push_back(out_index[j]); |
1049 | } |
1050 | return a(real_indices); |
1051 | }, |
1052 | name, tag); |
1053 | } else { |
1054 | return compute( |
1055 | out_shape, |
1056 | [&](const Array<Var>& out_index) { |
1057 | Array<PrimExpr> indices_position; |
1058 | for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) { |
1059 | indices_position.push_back(out_index[j]); |
1060 | } |
1061 | for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) { |
1062 | indices_position.push_back(out_index[j]); |
1063 | } |
1064 | Array<PrimExpr> real_indices; |
1065 | for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { |
1066 | real_indices.push_back(out_index[j]); |
1067 | } |
1068 | auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1); |
1069 | real_indices.push_back(idx); |
1070 | for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) { |
1071 | real_indices.push_back(out_index[j]); |
1072 | } |
1073 | return a(real_indices); |
1074 | }, |
1075 | name, tag); |
1076 | } |
1077 | } else if (mode == "fast" ) { |
1078 | LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " |
1079 | "Make sure input indices are in bound" ; |
1080 | return compute( |
1081 | out_shape, |
1082 | [&](const Array<Var>& out_index) { |
1083 | Array<PrimExpr> indices_position; |
1084 | for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) { |
1085 | indices_position.push_back(out_index[j]); |
1086 | } |
1087 | Array<PrimExpr> real_indices; |
1088 | for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { |
1089 | real_indices.push_back(out_index[j]); |
1090 | } |
1091 | real_indices.push_back(indices(indices_position)); |
1092 | for (size_t j = axis + indices_len; j < out_index.size(); ++j) { |
1093 | real_indices.push_back(out_index[j]); |
1094 | } |
1095 | return a(real_indices); |
1096 | }, |
1097 | name, tag); |
1098 | } else { // mode == "wrap" |
1099 | return compute( |
1100 | out_shape, |
1101 | [&](const Array<Var>& out_index) { |
1102 | Array<PrimExpr> indices_position; |
1103 | for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) { |
1104 | indices_position.push_back(out_index[j]); |
1105 | } |
1106 | Array<PrimExpr> real_indices; |
1107 | for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { |
1108 | real_indices.push_back(out_index[j]); |
1109 | } |
1110 | auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim); |
1111 | real_indices.push_back(idx); |
1112 | for (size_t j = axis + indices_len; j < out_index.size(); ++j) { |
1113 | real_indices.push_back(out_index[j]); |
1114 | } |
1115 | return a(real_indices); |
1116 | }, |
1117 | name, tag); |
1118 | } |
1119 | } |
1120 | |
1121 | /*! |
1122 | * \brief Return the elements, either from x or y, depending on the condition. |
1123 | * |
1124 | * \param condition The condition array. |
1125 | * \param x First array to be selected. |
1126 | * \param y Second array to be selected. |
1127 | * \param name The name of the operation. |
1128 | * \param tag The tag to mark the operation. |
1129 | * |
1130 | * \return A Tensor selected from x or y depending on condition. |
1131 | */ |
1132 | inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, |
1133 | std::string name = "T_where" , std::string tag = kBroadcast) { |
1134 | ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " |
1135 | << y->dtype; |
1136 | auto get_out_shape = [&]() { |
1137 | auto bh1 = detail::BroadcastShape(x->shape, y->shape); |
1138 | Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); |
1139 | auto bh2 = detail::BroadcastShape(condition->shape, common_shape1); |
1140 | Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end()); |
1141 | return common_shape2; |
1142 | }; |
1143 | |
1144 | auto oshape = get_out_shape(); |
1145 | |
1146 | auto c_bh = detail::BroadcastShape(condition->shape, oshape); |
1147 | auto x_bh = detail::BroadcastShape(x->shape, oshape); |
1148 | auto y_bh = detail::BroadcastShape(y->shape, oshape); |
1149 | |
1150 | auto select = [&](tvm::Array<tvm::tir::Var> ovars) { |
1151 | auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars)); |
1152 | auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars)); |
1153 | auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars)); |
1154 | return tvm::tir::Select(c != 0, true_val, false_val); |
1155 | }; |
1156 | |
1157 | return compute(oshape, select, name, tag); |
1158 | } |
1159 | |
1160 | /*! |
1161 | * \brief Creates an operation to repeat elements of an array |
1162 | * |
1163 | * \param x The input tensor |
1164 | * \param repeats The number of repetitions for each element |
1165 | * \param axis The axis along which to repeat values (allows |
1166 | * negative indices as offsets from the last dimension) |
1167 | * \param name The name of the operation |
1168 | * \param tag The tag to mark the operation |
1169 | * |
1170 | * \return A Tensor whose op member is the repeat operation |
1171 | */ |
1172 | inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat" , |
1173 | std::string tag = kBroadcast) { |
1174 | int ndim = static_cast<int>(x->shape.size()); |
1175 | ICHECK(-ndim - 1 <= axis && axis <= ndim) |
1176 | << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" |
1177 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
1178 | ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" |
1179 | << ", but got repeats = " << repeats; |
1180 | if (axis < 0) { |
1181 | // Calculate offset from last dimension |
1182 | axis += ndim; |
1183 | } |
1184 | Array<PrimExpr> new_shape; |
1185 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
1186 | new_shape.push_back(x->shape[i]); |
1187 | } |
1188 | new_shape.push_back(repeats * x->shape[axis]); |
1189 | for (size_t i = axis + 1; i < x->shape.size(); ++i) { |
1190 | new_shape.push_back(x->shape[i]); |
1191 | } |
1192 | |
1193 | return compute( |
1194 | new_shape, |
1195 | [&](const Array<Var>& indices) { |
1196 | Array<PrimExpr> idx; |
1197 | for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { |
1198 | idx.push_back(indices[i]); |
1199 | } |
1200 | idx.push_back(indexdiv(indices[axis], repeats)); |
1201 | for (size_t i = axis + 1; i < indices.size(); ++i) { |
1202 | idx.push_back(indices[i]); |
1203 | } |
1204 | return x(idx); |
1205 | }, |
1206 | name, tag); |
1207 | } |
1208 | |
1209 | /*! |
1210 | * \brief Creates an operation to tile elements of an array |
1211 | * |
1212 | * \param x The input tensor |
1213 | * \param reps The number of times for repeating the tensor |
1214 | * \param name The name of the operation |
1215 | * \param tag The tag to mark the operation |
1216 | * |
1217 | * \return A Tensor whose op member is the tile operation |
1218 | */ |
1219 | inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile" , |
1220 | std::string tag = kBroadcast) { |
1221 | size_t ndim = x->shape.size(); |
1222 | size_t rdim = reps.size(); |
1223 | size_t tdim = (ndim > rdim) ? ndim : rdim; |
1224 | Array<PrimExpr> data_shape; |
1225 | Array<PrimExpr> reps_shape; |
1226 | Array<PrimExpr> new_shape; |
1227 | if (ndim == rdim) { |
1228 | for (size_t i = 0; i < ndim; ++i) { |
1229 | data_shape.push_back(x->shape[i]); |
1230 | reps_shape.push_back(reps[i]); |
1231 | } |
1232 | } else if (ndim > rdim) { |
1233 | for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); |
1234 | for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); |
1235 | for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); |
1236 | } else { |
1237 | for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); |
1238 | for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); |
1239 | for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); |
1240 | } |
1241 | for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); |
1242 | |
1243 | if (is_empty_shape(new_shape)) { |
1244 | return compute( |
1245 | new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag); |
1246 | } else { |
1247 | return compute( |
1248 | new_shape, |
1249 | [&](const Array<Var>& indices) { |
1250 | Array<PrimExpr> idx; |
1251 | if (ndim >= rdim) { |
1252 | for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); |
1253 | } else { |
1254 | for (size_t i = 0; i < ndim; ++i) |
1255 | idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); |
1256 | } |
1257 | return x(idx); |
1258 | }, |
1259 | name, tag); |
1260 | } |
1261 | } |
1262 | |
1263 | /*! |
1264 | * \brief Creates an operation to tile elements of an array |
1265 | * |
1266 | * \param x The input tensor |
1267 | * \param new_shape The shape of the output after tiling |
1268 | * \param rdim The rank of the reps, provided by caller |
1269 | * \param name The name of the operation |
1270 | * \param tag The tag to mark the operation |
1271 | * |
1272 | * \return A Tensor whose op member is the tile operation |
1273 | */ |
1274 | inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim, |
1275 | std::string name = "T_tile" , std::string tag = kBroadcast) { |
1276 | size_t ndim = x->shape.size(); |
1277 | if (is_empty_shape(new_shape)) { |
1278 | return compute( |
1279 | new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag); |
1280 | } else { |
1281 | return compute( |
1282 | new_shape, |
1283 | [&](const Array<Var>& indices) { |
1284 | Array<PrimExpr> idx; |
1285 | if (ndim >= rdim) { |
1286 | for (size_t i = 0; i < ndim; ++i) { |
1287 | idx.push_back(indexmod(indices[i], x->shape[i])); |
1288 | } |
1289 | } else { |
1290 | for (size_t i = 0; i < ndim; ++i) { |
1291 | idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); |
1292 | } |
1293 | } |
1294 | return x(idx); |
1295 | }, |
1296 | name, tag); |
1297 | } |
1298 | } |
1299 | |
1300 | /*! |
1301 | * \brief Gather values along given axis from given indices. |
1302 | * |
1303 | * \param data The input data to the operator. |
1304 | * \param axis The axis along which to index. |
1305 | * \param indices The indices of values to gather. |
1306 | * \param name The name of the operation. |
1307 | * \param tag The tag to mark the operation. |
1308 | * |
1309 | * \return A Tensor whose op member is the gather operation |
1310 | */ |
1311 | inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, |
1312 | std::string name = "T_gather" , std::string tag = kInjective) { |
1313 | size_t ndim_d = data->shape.size(); |
1314 | size_t ndim_i = indices->shape.size(); |
1315 | ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar." ; |
1316 | ICHECK_EQ(ndim_d, ndim_i); |
1317 | if (axis < 0) { |
1318 | axis += ndim_d; |
1319 | } |
1320 | ICHECK_GE(axis, 0); |
1321 | ICHECK_LT(axis, ndim_d); |
1322 | if (indices->shape[axis].as<IntImmNode>()) { |
1323 | size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis])); |
1324 | ICHECK_GE(indices_dim_i, 1); |
1325 | } |
1326 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); |
1327 | |
1328 | Array<PrimExpr> out_shape; |
1329 | for (size_t i = 0; i < ndim_i; ++i) { |
1330 | out_shape.push_back(indices->shape[i]); |
1331 | } |
1332 | |
1333 | return compute( |
1334 | out_shape, |
1335 | [&](const Array<Var>& out_index) { |
1336 | Array<PrimExpr> indices_position; |
1337 | for (size_t i = 0; i < ndim_i; ++i) { |
1338 | indices_position.push_back(out_index[i]); |
1339 | } |
1340 | Array<PrimExpr> real_indices; |
1341 | for (size_t i = 0; i < ndim_i; ++i) { |
1342 | if (i == static_cast<size_t>(axis)) { |
1343 | real_indices.push_back(indices(indices_position)); |
1344 | } else { |
1345 | real_indices.push_back(indices_position[i]); |
1346 | } |
1347 | } |
1348 | return data(real_indices); |
1349 | }, |
1350 | name, tag); |
1351 | } |
1352 | |
1353 | /*! |
1354 | * \brief Gather elements from a n-dimension array. |
1355 | * |
1356 | * \param data The source array. |
1357 | * \param indices The indices of the values to extract. |
1358 | * \param batch_dims The number of batch dimensions. |
1359 | * \param name The name of the operation. |
1360 | * \param tag The tag to mark the operation. |
1361 | * |
1362 | * \return A Tensor whose op member is the gather_nd operation |
1363 | */ |
inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
                        std::string name = "T_gather_nd" , std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions" ;
  // Dim 0 of `indices` states how many leading data dimensions each index
  // tuple covers; it must be a static constant.
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                  << "than dimensions of data tensor" ;
  // Output shape: indices.shape[1:] followed by data.shape[indices_dim0 + batch_dims:].
  Array<PrimExpr> out_shape;
  for (size_t i = 1; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
    out_shape.push_back(data->shape[i]);
  }
  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        // indices_position[0] is a placeholder overwritten by Set() in the
        // loop below; the remaining entries address one index tuple.
        Array<PrimExpr> indices_position;
        indices_position.push_back(0);
        for (size_t i = 0; i < ndim_i - 1; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        // Leading batch coordinates pass through unchanged.
        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
          real_indices.push_back(out_index[i]);
        }
        // Collect the indices_dim0 coordinates stored along dim 0 of `indices`,
        // casting to int32 when the index dtype is not integral.
        for (size_t i = 0; i < indices_dim0; ++i) {
          indices_position.Set(0, make_const(DataType::Int(32), i));
          if (indices->dtype.is_int() || indices->dtype.is_uint()) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
          }
        }
        // Fully-specified lookup: the gathered coordinates already address a
        // single data element.
        if (real_indices.size() == ndim_d) {
          return data(real_indices);
        }
        // Otherwise the remaining data dimensions come verbatim from the
        // trailing output coordinates.
        for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
          real_indices.push_back(out_index[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
1409 | |
1410 | /*! |
1411 | * \brief Creates an operation that calculates a matrix multiplication |
1412 | * (row-major notation): |
1413 | * A(i, k) * B(k, j), if trans_a == trans_b |
1414 | * the usual transposed combinations, otherwise |
1415 | * |
1416 | * \param A The matrix A |
1417 | * \param B The matrix B |
1418 | * \param trans_a Is A's layout transposed? |
1419 | * \param trans_b Is B's layout transposed? |
1420 | * \param name The name of the operation |
1421 | * \param tag The tag to mark the operation |
1422 | * |
1423 | * \return A Tensor whose op member is the matmul operation |
1424 | */ |
1425 | inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, |
1426 | bool trans_a = false, bool trans_b = false, |
1427 | std::string name = "T_matmul" , std::string tag = kMatMul) { |
1428 | tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; |
1429 | auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k" ); |
1430 | auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { |
1431 | return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); |
1432 | }; |
1433 | return tvm::te::compute(output_shape, l, name, tag); |
1434 | } |
1435 | |
1436 | /*! |
1437 | * \brief A generalization of matrix multiplication to tensors. |
1438 | * |
1439 | * \param A The tensor A |
1440 | * \param B The tensor B |
1441 | * \param axes The number of the dimensions to reduce over |
1442 | * \param name The name of the operation |
1443 | * \param tag The tag to mark the operation |
1444 | * |
1445 | * \return A Tensor computing the result |
1446 | */ |
1447 | inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, |
1448 | std::string name = "T_tensordot" , std::string tag = kMatMul) { |
1449 | ICHECK_GE(A->shape.size(), axes); |
1450 | ICHECK_GE(B->shape.size(), axes); |
1451 | |
1452 | Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes)); |
1453 | for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); |
1454 | |
1455 | Array<IterVar> iter_vars; |
1456 | for (int i = 0; i < axes; ++i) |
1457 | iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); |
1458 | |
1459 | auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) { |
1460 | Array<PrimExpr> A_indices(input_indices.begin(), |
1461 | input_indices.begin() + (A->shape.size() - axes)); |
1462 | for (auto& v : iter_vars) A_indices.push_back(v); |
1463 | |
1464 | Array<PrimExpr> B_indices; |
1465 | for (auto& v : iter_vars) B_indices.push_back(v); |
1466 | |
1467 | auto it = input_indices.begin() + (A->shape.size() - axes); |
1468 | for (; it != input_indices.end(); ++it) B_indices.push_back(*it); |
1469 | |
1470 | // Some passes don't like reductions with empty axis, so avoid it here |
1471 | if (iter_vars.empty()) { |
1472 | return A(A_indices) * B(B_indices); |
1473 | } else { |
1474 | return sum(A(A_indices) * B(B_indices), iter_vars); |
1475 | } |
1476 | }; |
1477 | |
1478 | return compute(output_shape, func, name, tag); |
1479 | } |
1480 | |
1481 | /*! |
1482 | * \brief A generalization of matrix multiplication to tensors. |
1483 | * |
1484 | * \param A The tensor A |
1485 | * \param B The tensor B |
1486 | * \param A_axes The indices of the dimensions of tensor A to reduce over |
1487 | * \param B_axes The indices of the dimensions of tensor B to reduce over |
1488 | * \param name The name of the operation |
1489 | * \param tag The tag to mark the operation |
1490 | * |
1491 | * \return A Tensor computing the result |
1492 | */ |
inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
                        Array<PrimExpr> B_axes, std::string name = "T_tensordot" ,
                        std::string tag = kMatMul) {
  // A_axes[i] of A is contracted against B_axes[i] of B.
  ICHECK_EQ(A_axes.size(), B_axes.size());

  // The contracted axis positions must be constants; extract their values.
  auto A_axes_val = GetConstIntValues(A_axes, "A_axes" );
  auto B_axes_val = GetConstIntValues(B_axes, "B_axes" );

  // Output shape: A's non-contracted dims (in order), then B's.
  Array<PrimExpr> output_shape;
  for (unsigned i = 0; i < A->shape.size(); ++i)
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
      output_shape.push_back(A->shape[i]);
  for (unsigned i = 0; i < B->shape.size(); ++i)
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
      output_shape.push_back(B->shape[i]);

  // One reduction variable per contracted axis pair, sized from B's extents.
  Array<IterVar> iter_vars;
  for (unsigned i = 0; i < B_axes_val.size(); ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
    // Walk A's dims in order: non-contracted dims consume output coordinates
    // (tracked by idx_input), contracted dims take the matching reduction var.
    int idx_input = 0;
    Array<PrimExpr> A_indices;
    for (unsigned i = 0; i < A->shape.size(); ++i) {
      auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
      if (axes_pos == A_axes_val.end()) {
        A_indices.push_back(input_indices[idx_input++]);
      } else {
        A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
      }
    }

    // Same scheme for B, continuing with the remaining output coordinates.
    Array<PrimExpr> B_indices;
    for (unsigned i = 0; i < B->shape.size(); ++i) {
      auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
      if (axes_pos == B_axes_val.end()) {
        B_indices.push_back(input_indices[idx_input++]);
      } else {
        B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
      }
    }
    return sum(A(A_indices) * B(B_indices), iter_vars);
  };
  return compute(output_shape, func, name, tag);
}
1538 | |
1539 | inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step, |
1540 | DataType dtype, std::string name = "T_arange" , std::string tag = kInjective) { |
1541 | PrimExpr num_elem = tvm::cast( |
1542 | tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); |
1543 | Array<PrimExpr> shape; |
1544 | return compute( |
1545 | {num_elem}, |
1546 | [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, |
1547 | tag); |
1548 | } |
1549 | |
1550 | /*! |
1551 | * \brief Produce grids by expanding input over dimensions defined by other inputs |
1552 | * |
1553 | * \param inputs The input tensors |
1554 | * \param indexing The indexing mode, either "xy" or "ij" |
1555 | * \param name The name of the operation |
1556 | * \param tag The tag to mark the operation |
1557 | * |
1558 | * \return A Tensor whose op member is the meshgrid operation |
1559 | */ |
1560 | inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing, |
1561 | std::string name = "T_meshgrid" , std::string tag = kInjective) { |
1562 | const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2; |
1563 | Array<PrimExpr> out_shape; |
1564 | for (size_t i = 0; i < inputs.size(); ++i) { |
1565 | const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; |
1566 | out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]); |
1567 | } |
1568 | Array<Tensor> result; |
1569 | for (size_t i = 0; i < inputs.size(); ++i) { |
1570 | result.push_back(compute( |
1571 | out_shape, |
1572 | [&](const Array<Var>& indices) { |
1573 | const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; |
1574 | auto ndim = inputs[i]->GetShape().size(); |
1575 | Array<PrimExpr> real_indices = {}; |
1576 | if (ndim > 0) { |
1577 | real_indices = {indices[src_index]}; |
1578 | } |
1579 | return inputs[i](real_indices); |
1580 | }, |
1581 | name, tag)); |
1582 | } |
1583 | return result; |
1584 | } |
1585 | |
1586 | /*! |
1587 | * \brief Transform the layout according to \p src_layout and \p dst_layout |
1588 | * \param src the source input. |
1589 | * \param src_layout the source layout. |
1590 | * \param dst_layout the destination layout. |
1591 | * \param name output tensor name. |
1592 | * \param tag output tensor tag. |
1593 | * \return A tensor with shape in \p dst_layout |
1594 | */ |
1595 | inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, |
1596 | const std::string& dst_layout, |
1597 | const std::string name = "T_layout_trans" , |
1598 | const std::string tag = kInjective) { |
1599 | Layout src_layout_struct(src_layout); |
1600 | Layout dst_layout_struct(dst_layout); |
1601 | |
1602 | if (src_layout_struct.Equals(dst_layout_struct)) { |
1603 | return src; |
1604 | } |
1605 | |
1606 | ICHECK(src_layout_struct.defined() && dst_layout_struct.defined()) |
1607 | << "cannot convert from/to undefined layout" ; |
1608 | |
1609 | auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); |
1610 | ICHECK(layout_converter.defined()) |
1611 | << "cannot convert from " << src_layout << " to " << dst_layout; |
1612 | |
1613 | Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape); |
1614 | |
1615 | return compute( |
1616 | dst_shape, |
1617 | [&](const Array<Var>& dst_indices) { |
1618 | Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end()); |
1619 | Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr); |
1620 | PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true |
1621 | for (size_t i = 0; i < src.ndim(); ++i) { |
1622 | in_range = in_range && (src_indices[i] < src->shape[i]); |
1623 | } |
1624 | return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0))); |
1625 | }, |
1626 | name, tag); |
1627 | } |
1628 | |
1629 | /*! \brief Utility function for auto_scheduler_layout_transform */ |
1630 | inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape, |
1631 | std::vector<std::string>* axes) { |
1632 | int32_t factor = 0; |
1633 | std::string axis = "" ; |
1634 | for (char c : std::string(layout)) { |
1635 | if (c >= 'A' && c <= 'z') { |
1636 | axis += c; |
1637 | if (factor != 0) { |
1638 | shape->push_back(factor); |
1639 | factor = 0; |
1640 | } |
1641 | } else if (c >= '0' && c <= '9') { |
1642 | factor = factor * 10 + c - '0'; |
1643 | if (!axis.empty()) { |
1644 | axes->push_back(axis); |
1645 | axis = "" ; |
1646 | } |
1647 | } else { |
1648 | LOG(FATAL) << "Invalid layout " << layout; |
1649 | } |
1650 | } |
1651 | if (!axis.empty()) { |
1652 | axes->push_back(axis); |
1653 | } |
1654 | } |
1655 | |
1656 | /*! |
1657 | * \brief Transform the auto-scheduler generated layout according to |
1658 | * \p src_layout and \p dst_layout |
1659 | * \param src the source input. |
1660 | * \param src_layout the source layout. |
1661 | * \param dst_layout the destination layout. |
1662 | * \param name output tensor name. |
1663 | * \param tag output tensor tag. |
1664 | * \return A tensor with shape in \p dst_layout |
1665 | */ |
1666 | inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout, |
1667 | const String& dst_layout, |
1668 | const String name = "T_auto_scheduler_layout_trans" , |
1669 | const String tag = kInjective) { |
1670 | Array<PrimExpr> src_shape; |
1671 | std::vector<std::string> src_axes; |
1672 | Array<PrimExpr> dst_shape; |
1673 | std::vector<std::string> dst_axes; |
1674 | |
1675 | parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes); |
1676 | parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes); |
1677 | return compute( |
1678 | dst_shape, |
1679 | [&](const Array<Var>& dst_indices) { |
1680 | Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end()); |
1681 | Array<PrimExpr> src_indices; |
1682 | for (const std::string& src_axis : src_axes) { |
1683 | PrimExpr src_index = 0; |
1684 | CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); |
1685 | for (size_t i = 0; i < dst_axes.size(); ++i) { |
1686 | if (dst_axes[i] == src_axis) { |
1687 | src_index = src_index * dst_shape[i] + dst_indices_expr[i]; |
1688 | } |
1689 | } |
1690 | src_indices.push_back(src_index); |
1691 | } |
1692 | return src(src_indices); |
1693 | }, |
1694 | name, tag); |
1695 | } |
1696 | |
1697 | /*! |
1698 | * \brief Transform the meta-schedule generated layout according to TIR's IndexMap |
1699 | * \param src the source input. |
1700 | * \param index_map The TIR IndexMap |
1701 | * \param name output tensor name. |
1702 | * \param tag output tensor tag. |
1703 | * \return A tensor. The layout transformation method |
1704 | * \note Example: |
1705 | * |
1706 | * For the indexing pattern below: |
1707 | * |
1708 | * for i in range(32): |
1709 | * for j in range(64): |
1710 | * load A[ |
1711 | * i / 16 * 4 + j / 16, |
1712 | * i % 16 * 16 + j % 16, |
1713 | * ] |
1714 | * |
1715 | * The corresponding indexing pattern in TIR is: |
1716 | * |
1717 | * A[i, j] => A'[i / 4, j / 16, i % 4, j % 16] |
1718 | * |
1719 | * which converts the pattern to: |
1720 | * |
1721 | * for i in range(32): |
1722 | * for j in range(64): |
1723 | * load A'[ |
1724 | * i / 16 + j / 64, |
1725 | * i % 16, |
1726 | * j % 64 / 16, |
1727 | * j % 16, |
1728 | * ] |
1729 | * |
1730 | * In this case, the transformation pattern is: |
1731 | * A'[a, b, c, d] = A[a * 4 + c, b * 16 + d] |
1732 | */ |
1733 | inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map, |
1734 | const String name = "T_meta_schedule_layout_trans" , |
1735 | const String tag = kInjective) { |
1736 | Array<Range> iter_domain; |
1737 | iter_domain.reserve(src->shape.size()); |
1738 | for (const PrimExpr& e : src->shape) { |
1739 | iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e)); |
1740 | } |
1741 | Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape); |
1742 | return compute( |
1743 | post_transform_shape, |
1744 | [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr { |
1745 | return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()})); |
1746 | }, |
1747 | name, tag); |
1748 | } |
1749 | |
1750 | /*! |
1751 | * \brief Get the shape of input tensor. |
1752 | * \param src the input tensor. |
1753 | * \param dtype the type of the elements in the tensor. |
1754 | * \param name output tensor name. |
1755 | * \param tag output tensor tag. |
1756 | * \return Tensor of input shape. |
1757 | */ |
1758 | inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape" , |
1759 | const std::string tag = kInjective) { |
1760 | int ndim = static_cast<int>(src->shape.size()); |
1761 | Array<PrimExpr> out_shape{ndim}; |
1762 | return compute( |
1763 | out_shape, |
1764 | [&](const Array<Var>& indices) { |
1765 | auto idx = indices[0]; |
1766 | PrimExpr ret = 0; |
1767 | for (int i = 0; i < ndim; ++i) { |
1768 | ret = tvm::if_then_else(idx == i, src->shape[i], ret); |
1769 | } |
1770 | return tvm::cast(dtype, ret); |
1771 | }, |
1772 | name, tag); |
1773 | } |
1774 | |
1775 | /*! |
1776 | * \brief Get the size of input tensor. |
1777 | * \param src the input tensor. |
1778 | * \param dtype the type of the elements in the tensor. |
1779 | * \param name output tensor name. |
1780 | * \param tag output tensor tag. |
1781 | * \return Tensor of input shape. |
1782 | */ |
1783 | inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, |
1784 | const std::string& name = "ndarray_size" , |
1785 | const std::string& tag = kInjective) { |
1786 | int ndim = static_cast<int>(src->shape.size()); |
1787 | Array<PrimExpr> out_ndarray_size = {}; |
1788 | return compute( |
1789 | out_ndarray_size, |
1790 | [&](const Array<Var>& indices) { |
1791 | PrimExpr ret = 1; |
1792 | for (int i = 0; i < ndim; ++i) { |
1793 | ret *= src->shape[i]; |
1794 | } |
1795 | return tvm::cast(dtype, ret); |
1796 | }, |
1797 | name, tag); |
1798 | } |
1799 | |
1800 | /*! |
1801 | * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, |
1802 | other locations take value off_value. |
1803 | * \param indices locations to set to on_value. |
1804 | * \param on_value value that locations represented by indices take on. |
1805 | * \param off_value value that other locations take on. |
1806 | * \param depth depth of the one-hot dimension. |
1807 | * \param axis axis to fill. |
1808 | * \param dtype data type of the output tensor. |
1809 | * \param oshape shape of the output tensor. |
1810 | * \param name output tensor name. |
1811 | * \param tag output tensor tag. |
1812 | * \return one-hot tensor. |
1813 | */ |
1814 | inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, |
1815 | int depth, int axis, const DataType& dtype, |
1816 | Array<PrimExpr> oshape = Array<PrimExpr>(), |
1817 | const std::string name = "T_one_hot" , const std::string tag = kInjective) { |
1818 | int true_axis = (axis == -1) ? indices->shape.size() : axis; |
1819 | if (oshape.size() == 0) { |
1820 | int ndim = indices->shape.size() + 1; |
1821 | int indices_index = 0; |
1822 | for (int i = 0; i < ndim; i++) { |
1823 | if (i == true_axis) { |
1824 | oshape.push_back(Integer(depth)); |
1825 | } else { |
1826 | oshape.push_back(indices->shape[indices_index++]); |
1827 | } |
1828 | } |
1829 | } |
1830 | |
1831 | PrimExpr on_value_cast = cast(dtype, on_value); |
1832 | PrimExpr off_value_cast = cast(dtype, off_value); |
1833 | return compute( |
1834 | oshape, |
1835 | [&](const Array<Var>& iter_vars) { |
1836 | Array<Var> indices_indices; |
1837 | for (size_t i = 0; i < iter_vars.size(); i++) { |
1838 | if (static_cast<int>(i) == true_axis) { |
1839 | continue; |
1840 | } |
1841 | |
1842 | indices_indices.push_back(iter_vars[i]); |
1843 | } |
1844 | |
1845 | auto idx = iter_vars[true_axis]; |
1846 | return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast); |
1847 | }, |
1848 | name, tag); |
1849 | } |
1850 | |
1851 | /*! |
1852 | * \brief Get a dense tensor. |
1853 | * \param sparse_indices sparse_indices[i] contains sparse_values[i] will be placed. |
1854 | * \param output_shape is the shape of the dense output tensor . |
1855 | * \param sparse_values is a 0-D or 1-D tensor. Values for each row of sparse_indices. |
1856 | * \param default_value is a 0-D tensor. Defaults to zero. |
1857 | * \param name output tensor name. |
1858 | * \param tag output tensor tag. |
1859 | * \return Tensor of output_shape. |
1860 | */ |
1861 | inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape, |
1862 | const Tensor& sparse_values, const PrimExpr& default_value, |
1863 | const std::string name = "T_sparse_to_dense" , |
1864 | const std::string tag = kInjective) { |
1865 | ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values" ; |
1866 | ICHECK_LE(sparse_indices->shape.size(), 3) |
1867 | << "sparse_indices tensor should be 0D, 1D, or 2D only" ; |
1868 | ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only" ; |
1869 | |
1870 | const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size()); |
1871 | Array<PrimExpr> oshape; |
1872 | for (auto l : output_shape) { |
1873 | oshape.push_back(l); |
1874 | } |
1875 | return compute( |
1876 | oshape, |
1877 | [&](const Array<Var>& indices) { |
1878 | PrimExpr ret = default_value; |
1879 | if (0 == rank_sparse_indices) { |
1880 | ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret); |
1881 | } else if (1 == rank_sparse_indices) { |
1882 | for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { |
1883 | ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); |
1884 | } |
1885 | } else { |
1886 | for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { |
1887 | PrimExpr aggregate_condition; |
1888 | for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) { |
1889 | PrimExpr comparision = indices[k] == sparse_indices[j][k]; |
1890 | aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision; |
1891 | } |
1892 | ret = if_then_else(aggregate_condition, sparse_values[j], ret); |
1893 | } |
1894 | } |
1895 | return ret; |
1896 | }, |
1897 | name, tag); |
1898 | } |
1899 | |
1900 | /*! |
1901 | * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals. |
1902 | * \param input input tensor. |
1903 | * \param diagonal values to be filled in the diagonals. |
1904 | * \param k1 lower limit (included) of the range of diagonals. |
1905 | * \param k2 upper limit (included) of the range of diagonals. |
1906 | * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded). |
1907 | * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded). |
1908 | * \param name output tensor name. |
1909 | * \param tag output tensor tag. |
1910 | * \return new tensor with given diagonal values. |
1911 | */ |
1912 | inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2, |
1913 | bool super_diag_right_align, bool sub_diag_right_align, |
1914 | const std::string name = "T_matrix_set_diag" , |
1915 | const std::string tag = kInjective) { |
1916 | size_t ndim = input->shape.size() - 1; |
1917 | |
1918 | bool only_one_diagonal = k1 == k2; |
1919 | |
1920 | return compute( |
1921 | input->shape, |
1922 | [&](const Array<Var>& iter_vars) { |
1923 | auto get_diag = [&]() { |
1924 | Array<PrimExpr> diagonal_indices; |
1925 | PrimExpr k, offset = 0; |
1926 | for (size_t i = 0; i < ndim - 1; i++) { |
1927 | diagonal_indices.push_back(iter_vars[i]); |
1928 | } |
1929 | if (only_one_diagonal) { |
1930 | k = k1; |
1931 | } else { |
1932 | // Determining which diagonal/sub-diagonal/super-diagonal it is |
1933 | k = iter_vars[ndim] - iter_vars[ndim - 1]; |
1934 | diagonal_indices.push_back(k2 - k); |
1935 | |
1936 | // Calculating the offset in diagonal tensor for this diagonal |
1937 | auto get_offset = [&](PrimExpr M, PrimExpr N) { |
1938 | // offset = max_diagonal_length - diagonal_length |
1939 | return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N); |
1940 | }; |
1941 | offset = if_then_else( |
1942 | k >= 0, |
1943 | super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1]) |
1944 | : 0, |
1945 | sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k) |
1946 | : 0); |
1947 | } |
1948 | diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + |
1949 | offset); |
1950 | return diagonal(diagonal_indices); |
1951 | }; |
1952 | return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, |
1953 | if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, |
1954 | get_diag(), input(iter_vars)), |
1955 | input(iter_vars)); |
1956 | }, |
1957 | name, tag); |
1958 | } |
1959 | |
1960 | /*! |
1961 | * \brief Numpy style advanced indexing with tensor. |
1962 | * \param data is input data. |
1963 | * \param indices is list of indexing tensors. |
1964 | * \param name output tensor name. |
1965 | * \param tag output tensor tag. |
1966 | * \return Output tensor. |
1967 | */ |
1968 | inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices, |
1969 | const std::string name = "advanced_index" , |
1970 | const std::string tag = kInjective) { |
1971 | ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!" ; |
1972 | Array<PrimExpr> oshape; |
1973 | Array<PrimExpr> broadcast_shape; |
1974 | Array<Tensor> bindices; |
1975 | |
1976 | broadcast_shape = indices[0]->shape; |
1977 | for (size_t i = 1; i < indices.size(); ++i) { |
1978 | auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape); |
1979 | broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()); |
1980 | } |
1981 | if (indices.size() == 1) { |
1982 | // quick path |
1983 | bindices = indices; |
1984 | } else { |
1985 | // Do broadcast for indices |
1986 | for (size_t i = 0; i < indices.size(); ++i) { |
1987 | bindices.push_back(broadcast_to(indices[i], broadcast_shape)); |
1988 | } |
1989 | } |
1990 | |
1991 | for (const auto& dim : broadcast_shape) { |
1992 | oshape.push_back(dim); |
1993 | } |
1994 | for (size_t i = indices.size(); i < data->shape.size(); ++i) { |
1995 | oshape.push_back(data->shape[i]); |
1996 | } |
1997 | |
1998 | return compute( |
1999 | oshape, |
2000 | [&](const Array<Var>& iter_var) { |
2001 | Array<PrimExpr> tensor_indices; |
2002 | for (size_t i = 0; i < broadcast_shape.size(); ++i) { |
2003 | tensor_indices.push_back(iter_var[i]); |
2004 | } |
2005 | |
2006 | Array<PrimExpr> real_indices; |
2007 | for (size_t i = 0; i < bindices.size(); ++i) { |
2008 | real_indices.push_back(bindices[i](tensor_indices)); |
2009 | } |
2010 | for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) { |
2011 | real_indices.push_back(iter_var[i]); |
2012 | } |
2013 | |
2014 | return data(real_indices); |
2015 | }, |
2016 | name, tag); |
2017 | } |
2018 | |
2019 | } // namespace topi |
2020 | } // namespace tvm |
2021 | #endif // TVM_TOPI_TRANSFORM_H_ |
2022 | |