/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file topi/transform.h
 * \brief Transform op constructors
 */
#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_

#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/strided_slice.h>
#include <tvm/topi/detail/tensor_utils.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace topi {

using namespace tvm::te;
using namespace topi::detail;

51/*!
52 * \brief Creates an operation to slide a window over the input x.
53 *
54 * \param x The input tensor.
55 * \param axis What axis the window begins sliding over. Window will be slid
56 * over this axis and all following axes. The axis value determines the window
57 * shape (and thus, the number of strides): window shape and strides must both
58 * be of length `data.ndim-axis`.
59 * \param window_shape The window shape to form over the input. Window shape
60 * must be of length `data.ndim-axis`.
61 * \param strides How to stride the window along each dimension. Strides must be
62 * of length `data.ndim-axis`.
63 * \param name The name of the operation
64 * \param tag The tag to mark the operation
65 *
66 * \return A Tensor whose op member is the sliding_window operation
67 */
68inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
69 Array<Integer> strides, std::string name = "T_sliding_window",
70 std::string tag = "") {
71 CHECK_GE(axis, 0);
72 auto _axis = size_t(axis);
73 CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
74 CHECK_EQ(x->shape.size() - _axis, window_shape.size())
75 << "There must be a window shape for every dimension of x "
76 << "over which we are sliding the window.";
77 CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";
78
79 // Compute the new shape.
80 Array<PrimExpr> new_shape;
81 // Dimensions up until `axis` remain the same.
82 for (size_t i = 0; i < _axis; ++i) {
83 new_shape.push_back(x->shape[i]);
84 }
85
86 // New dimensions which result from sliding the window in each dimension. One new dimension per
87 // window dimension.
88 for (size_t i = 0; i < window_shape.size(); ++i) {
89 // Length of the shape along this dimension.
90 auto dim_len = x->shape[_axis + i];
91 // Length of the window along this dimension.
92 auto window_len = window_shape[i];
93 // Strides along this dimension.
94 auto stride = strides[i];
95
96 new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
97 }
98
99 // Dimensions comprising the window.
100 for (size_t i = 0; i < window_shape.size(); ++i) {
101 new_shape.push_back(window_shape[i]);
102 }
103
104 ICHECK(new_shape.size() == _axis + 2 * window_shape.size());
105
106 return compute(
107 new_shape,
108 [&](const Array<Var>& indices) {
109 // The index at which to index the old tensor x.
110 Array<PrimExpr> idx;
111
112 // Dimensions up until `axis` remain the same.
113 for (size_t i = 0; i < _axis; ++i) {
114 idx.push_back(indices[i]);
115 }
116
117 for (size_t i = 0; i < window_shape.size(); ++i) {
118 // Which window in this dimension we are indexing.
119 auto window_idx = indices[_axis + i];
120 // Which index within the window we are indexing.
121 auto idx_within_window = indices[_axis + window_shape.size() + i];
122 // Stride value for this dimension.
123 auto stride = strides[i];
124
125 idx.push_back(window_idx * stride + idx_within_window);
126 }
127
128 ICHECK(idx.size() == x->shape.size());
129
130 return x(idx);
131 },
132 name, tag);
133}
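
// Illustrative usage sketch (assumed example, not part of the upstream header); it shows how the
// output shape of sliding_window is formed for a small placeholder input:
//   te::Tensor x = te::placeholder({10, 3}, DataType::Float(32), "x");
//   // Slide a 3x3 window over both axes with strides (2, 1).
//   te::Tensor y = sliding_window(x, /*axis=*/0, {3, 3}, {2, 1});
//   // y->shape is {4, 1, 3, 3}: one "number of windows" dim per slid axis,
//   // followed by the window dims themselves.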
134
135/*!
136 * \brief Creates an operation to insert new dimensions of length 1
137 *
138 * \param x The input tensor
139 * \param axis The index of the first new dimension (allows negative
140 * indices as offsets from the last dimension)
141 * \param num_newaxis The number of new dimensions to insert
142 * \param name The name of the operation
143 * \param tag The tag to mark the operation
144 *
145 * \return A Tensor whose op member is the dim expansion operation
146 */
147inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
148 std::string name = "T_expand_dims", std::string tag = kBroadcast) {
149 int ndim = static_cast<int>(x->shape.size());
150 ICHECK(-ndim - 1 <= axis && axis <= ndim)
151 << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
152 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
153 ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
154 << ", but got num_newaxis = " << num_newaxis;
155 if (axis < 0) {
156 // Calculate offset from last dimension
157 axis = ndim + axis + 1;
158 }
159 Array<PrimExpr> new_shape;
160 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
161 new_shape.push_back(x->shape[i]);
162 }
163 for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
164 new_shape.push_back(1);
165 }
166 for (size_t i = axis; i < x->shape.size(); ++i) {
167 new_shape.push_back(x->shape[i]);
168 }
169
170 return compute(
171 new_shape,
172 [&](const Array<Var>& indices) {
173 Array<PrimExpr> idx;
174 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
175 idx.push_back(indices[i]);
176 }
177 for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
178 idx.push_back(indices[i]);
179 }
180 return x(idx);
181 },
182 name, tag);
183}
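
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({3, 4}, DataType::Float(32), "x");
//   te::Tensor y = expand_dims(x, /*axis=*/1, /*num_newaxis=*/2);
//   // y->shape is {3, 1, 1, 4}; axis = -1 would instead insert after the last dimension.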
184
185/*!
186 * \brief Permute the dimensions of an array
187 *
188 * \param x The input tensor
189 * \param axes The indices of the permutation. If this is empty,
190 * the dimensions will be reversed.
191 * \param name The name of the operation
192 * \param tag The tag to mark the operation
193 *
194 * \return A Tensor whose op member is the transpose operation
195 */
196inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
197 std::string tag = kInjective) {
198 if (!axes.defined() || axes.size() == 0) {
199 axes = Array<Integer>();
200 for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
201 axes.push_back(i);
202 }
203 }
204
205 Array<PrimExpr> new_shape;
206 for (size_t i = 0; i < axes.size(); ++i) {
207 int axis = static_cast<int>(axes[i]->value);
208 int new_axis = axis;
209 if (axis < 0) {
210 new_axis = static_cast<int>(x->shape.size()) + axis;
211 axes.Set(i, new_axis);
212 }
213 ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
214 << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
215 << "-dimensional input tensor";
216
217 for (size_t j = 0; j < axes.size(); ++j) {
218 if (i != j) {
219 ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
220 }
221 }
222 new_shape.push_back(x->shape[new_axis]);
223 }
224
225 return compute(
226 new_shape,
227 [&](const Array<Var>& indices) {
228 std::vector<PrimExpr> idx;
229 for (size_t i = 0; i < axes.size(); ++i) {
230 idx.push_back(1);
231 }
232 for (size_t i = 0; i < axes.size(); ++i) {
233 int axis = static_cast<int>(axes[i]->value);
234 idx[axis] = indices[i];
235 }
236 return x(idx);
237 },
238 name, tag);
239}
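
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
//   te::Tensor y = transpose(x, {1, 2, 0});  // y->shape is {3, 4, 2}
//   te::Tensor z = transpose(x, {});         // empty axes reverse the dims: {4, 3, 2}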
240
/*!
 * \brief Reverse the tensor for variable length slices.
 * Input is first sliced along the batch axis and then elements are reversed along the seq axis.
 *
 * \param x The input tensor
 * \param seq_lengths A 1D Tensor of length x.dims[batch_axis]. An undefined Tensor() may be
 * passed; in that case the batch axis is ignored and the tensor is reversed along seq_axis.
 * \param seq_axis The axis along which the elements will be reversed
 * \param batch_axis The axis along which the tensor will be sliced
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the reverse_sequence operation
 */
255inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
256 int batch_axis = 0, std::string name = "T_reverse_sequence",
257 std::string tag = kInjective) {
258 size_t src_tensor_dim = x->shape.size();
259 int seq_axis_inp = seq_axis;
260
261 if (seq_lengths.defined()) {
262 size_t seq_lengths_dim = seq_lengths->shape.size();
263 int batch_axis_inp = batch_axis;
264 if (batch_axis < 0) {
265 batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
266 }
267
268 ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";
269
    ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
        << "For reverse_sequence, the size of seq_lengths should match the dimension of the "
           "batch axis, but got dimension of batch_axis = "
        << GetConstInt(x->shape[batch_axis])
        << ", and seq_lengths size = " << GetConstInt(seq_lengths->shape[0]);
274
275 ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
276 << "batch_axis=" << batch_axis_inp << " is invalid for the "
277 << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
278 }
279
280 if (seq_axis < 0) {
281 seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
282 }
283 ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
284 << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
285 << "-dimensional input tensor";
286
287 auto func = [&](const Array<Var>& indices) {
288 Array<PrimExpr> real_indices;
289 for (size_t i = 0; i < src_tensor_dim; ++i) {
290 if (i == static_cast<size_t>(seq_axis)) {
291 if (seq_lengths.defined()) {
292 auto len = seq_lengths(indices[batch_axis]);
293 auto idx = if_then_else(
294 len <= 1 || len <= indices[i], indices[i],
295 if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
296 real_indices.push_back(idx);
297 } else {
298 real_indices.push_back(x->shape[i] - 1 - indices[i]);
299 }
300 } else {
301 real_indices.push_back(indices[i]);
302 }
303 }
304 return x(real_indices);
305 };
306
307 return compute(x->shape, func, name, tag);
308}
309
310/*!
311 * \brief Reshape a tensor
312 *
313 * \param x The input tensor
314 * \param newshape The new shape
315 * \param name The name of the operation
316 * \param tag The tag to mark the operation
317 *
318 * \return A Tensor whose op member is the reshape operation
319 */
320inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
321 std::string tag = kInjective) {
322 auto x_shape = x->shape;
323 Array<PrimExpr> target_shape;
324
325 for (const auto& ele : newshape) {
326 if (ele.as<IntImmNode>()) {
327 target_shape.push_back(cast(DataType::Int(32), ele));
328 } else {
329 target_shape.push_back(ele);
330 }
331 }
332
333 // If either the input shape or the target shape contains a zero, return an empty tensor.
334 if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
335 return compute(
336 target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
337 } else {
338 return compute(
339 target_shape,
340 [&](const Array<Var>& indices) {
341 return x(UnravelIndex(
342 RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
343 },
344 name, tag);
345 }
346}
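
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
//   te::Tensor y = reshape(x, {6, 4});  // the same 24 elements, now viewed as a 6x4 tensor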
347
348/*!
349 * \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays
350 *
351 * \param x The input tensor having indices.
352 * \param shape The shape tensor
353 * \param name The name of the operation
354 * \param tag The tag to mark the operation
355 *
356 * \return A Tensor of coordinate arrays.
357 */
358
359inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
360 std::string tag = kInjective) {
361 auto x_shape = x->shape;
362 auto shape_shape = shape->shape;
363
364 Array<PrimExpr> oshape;
365 oshape.push_back(shape_shape[0]);
366 if (x_shape.size() != 0) {
367 oshape.push_back(x_shape[0]);
368 }
369
370 auto func = [&](const Array<Var>& indices) {
371 auto i = indices[0];
372 std::vector<PrimExpr> indices_divs;
373 PrimExpr ret = 0;
374 PrimExpr cur_val = 0;
375 PrimExpr index_val = 0;
376
377 if (x_shape.size() != 0) {
378 index_val = x[indices[1]];
379 } else {
380 index_val = x();
381 }
382 indices_divs.push_back(index_val);
383 for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
384 ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
385 cur_val = indexdiv(indices_divs.back(), shape[v]);
386 indices_divs.push_back(cur_val);
387 }
388 return ret;
389 };
390
391 return compute(oshape, func, name, tag);
392}
393
/*!
 * \brief Remove size 1 dimensions from the shape of a tensor.
 * The removed dimensions must have a constant size of 1.
 *
 * \param x The input tensor
 * \param axis Indices of the dimensions to remove. If this is not defined,
 * all entries with a constant size of 1 will be removed.
 * \param atleast1d Whether the output needs to keep at least one dimension.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the squeeze operation
 */
407inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
408 std::string name = "T_squeeze", std::string tag = kInjective) {
409 auto ndim = x->shape.size();
410 std::vector<int> axis_val;
411 if (!axis.defined()) {
412 for (size_t i = 0; i < ndim; ++i) {
413 if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
414 axis_val.push_back(static_cast<int>(i));
415 }
416 }
417 } else {
418 for (size_t i = 0; i < axis.size(); ++i) {
419 int64_t val = axis[i]->value;
420 if (val < 0) {
421 val += static_cast<int>(x->shape.size());
422 }
423 if (IsConstInt(x->shape[val])) {
424 ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
425 }
426 axis_val.push_back(val);
427 }
428 }
429
430 std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
431
432 Array<PrimExpr> out_shape;
433 for (size_t i = 0; i < ndim; ++i) {
434 if (axis_set.count(static_cast<int>(i)) == 0) {
435 out_shape.push_back(x->shape[i]);
436 }
437 }
438 if (out_shape.size() == 0 && atleast1d) {
439 out_shape.push_back(1);
440 }
441
442 return compute(
443 out_shape,
444 [&](const Array<Var>& indices) {
445 Array<PrimExpr> real_indices;
446 int flag = 0;
447 for (size_t i = 0; i < ndim; ++i) {
448 if (axis_set.count(static_cast<int>(i)) == 0) {
449 real_indices.push_back(indices[i - flag]);
450 } else {
451 real_indices.push_back(0);
452 flag += 1;
453 }
454 }
455 return x(real_indices);
456 },
457 name, tag);
458}
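
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({1, 3, 1, 4}, DataType::Float(32), "x");
//   te::Tensor y = squeeze(x, {0, 2});  // drop dims 0 and 2 -> y->shape is {3, 4}
//   // Passing an undefined axis array instead removes every dim with a constant size of 1.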
459
460/*!
461 * \brief Join a sequence of tensors along an existing axis
462 *
463 * \param inputs The input tensors
464 * \param axis The axis along which the tensors will be joined
465 * \param name The name of the operation
466 * \param tag The tag to mark the operation
467 *
468 * \return A Tensor whose op member is the concatenate operation
469 */
470inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
471 std::string tag = kInjective) {
472 int ndim = static_cast<int>(inputs[0]->shape.size());
473 ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
474 << ", but got axis = " << axis << ", and ndim = " << ndim;
475 if (axis < 0) {
476 axis += ndim;
477 }
478 ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
479
480 Array<PrimExpr> axis_sizes;
481 for (auto t : inputs) {
482 axis_sizes.push_back(t->shape[axis]);
483 }
484 arith::Analyzer analyzer;
485 PrimExpr join_size = axis_sizes[0];
486 for (size_t i = 1; i < axis_sizes.size(); ++i) {
487 join_size += axis_sizes[i];
488 }
489 join_size = analyzer.Simplify(join_size);
490 Array<PrimExpr> out_shape;
491 for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
492 out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
493 }
494
495 return compute(
496 out_shape,
497 [&](const Array<Var>& indices) {
498 auto ret = inputs[0](indices);
499 auto ind = indices[axis];
500 for (size_t i = 0; i < inputs.size() - 1; ++i) {
501 ind -= axis_sizes[i];
502
503 Array<PrimExpr> idx;
504 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
505 idx.push_back(indices[i]);
506 }
507 idx.push_back(ind);
508 for (size_t i = axis + 1; i < indices.size(); ++i) {
509 idx.push_back(indices[i]);
510 }
511
512 ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
513 }
514 return ret;
515 },
516 name, tag);
517}
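
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
//   te::Tensor b = te::placeholder({2, 5}, DataType::Float(32), "b");
//   te::Tensor c = concatenate({a, b}, /*axis=*/1);  // c->shape is {2, 8}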
518
519/*!
520 * \brief Join a sequence of tensors along a new axis.
521 *
522 * \param inputs The input tensors
523 * \param axis The axis along which the tensors will be stacked
524 * \param name The name of the operation
525 * \param tag The tag to mark the operation
526 *
527 * \return A Tensor whose op member is the stack operation
528 */
529inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
530 std::string tag = kInjective) {
531 int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "stack only accepts `axis` in [-ndim - 1, ndim]"
      << ", but got axis = " << axis << ", and ndim = " << ndim;
535 if (axis < 0) {
536 axis += ndim + 1;
537 }
538 ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
539
540 const int stack_size = static_cast<int>(inputs.size());
541 Array<PrimExpr> out_shape;
542 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
543 out_shape.push_back(stack_size);
544 for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
545 out_shape.push_back(inputs[0]->shape[i]);
546
547 return compute(
548 out_shape,
549 [&](const Array<Var>& indices) {
550 Array<PrimExpr> idx;
551 for (size_t i = 0; i < indices.size(); ++i)
552 if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
553 auto ind = indices[axis];
554 auto ret = inputs[0](idx);
555 for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
556 ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
557 }
558 return ret;
559 },
560 name, tag);
561}
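
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
//   te::Tensor b = te::placeholder({2, 3}, DataType::Float(32), "b");
//   te::Tensor s = stack({a, b}, /*axis=*/1);  // s->shape is {2, 2, 3}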
562
563/*!
564 * \brief Split a tensor into multiple sub-tensors
565 *
566 * \param x The input tensor
 * \param split_indices The indices to split the input at. These must be given in
 * ascending order.
569 * \param axis The axis to split along.
570 * \param name The name of the operation
571 * \param tag The tag to mark the operation
572 *
573 * \return A Tensor whose op member is the split operation
574 */
575inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
576 std::string name = "T_split", std::string tag = kInjective) {
577 if (axis < 0) {
578 axis += static_cast<int>(x->shape.size());
579 }
580 ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
581
582 auto src_axis_size = x->shape[axis];
583 std::vector<PrimExpr> begin_ids;
584 begin_ids.push_back(0);
585
586 for (auto idx : split_indices) {
587 auto idx_node = idx.as<IntImmNode>();
588 auto back_node = begin_ids.back().as<IntImmNode>();
589 if (idx_node && back_node) {
590 ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
591 }
592 begin_ids.push_back(idx);
593 }
594
595 Array<Array<PrimExpr>> out_shapes;
596 for (size_t i = 0; i < begin_ids.size(); ++i) {
597 PrimExpr out_axis_size;
598 if (i == begin_ids.size() - 1) {
599 out_axis_size = src_axis_size - begin_ids[i];
600 } else {
601 out_axis_size = begin_ids[i + 1] - begin_ids[i];
602 }
603
604 Array<PrimExpr> shape;
605 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
606 shape.push_back(x->shape[i]);
607 }
608 shape.push_back(out_axis_size);
609 for (size_t i = axis + 1; i < x->shape.size(); ++i) {
610 shape.push_back(x->shape[i]);
611 }
612
613 out_shapes.push_back(shape);
614 }
615
616 Array<Tensor> result;
617 for (size_t i = 0; i < begin_ids.size(); ++i) {
618 result.push_back(compute(
619 out_shapes[i],
620 [&](const Array<Var>& indices) {
621 auto begin = begin_ids[i];
622 Array<PrimExpr> real_indices;
623 for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
624 real_indices.push_back(indices[j]);
625 }
626 real_indices.push_back(indices[axis] + begin);
627 for (size_t j = axis + 1; j < indices.size(); ++j) {
628 real_indices.push_back(indices[j]);
629 }
630
631 return x(real_indices);
632 },
633 name, tag));
634 }
635
636 return result;
637}
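
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({8, 4}, DataType::Float(32), "x");
//   Array<te::Tensor> parts = split(x, {2, 5}, /*axis=*/0);
//   // parts[0], parts[1], parts[2] have shapes {2, 4}, {3, 4}, and {3, 4}.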
638
639/*!
 * \brief strided_slice of a tensor, where begin/end/stride can be a mix of static and dynamic values
641 *
642 * \param x The input tensor
643 * \param begin The indices to begin with in the slicing
644 * \param end Indices indicating end of the slice
645 * \param strides Specifies the stride values, it can be negative
646 * in that case, the input tensor will be reversed in that particular axis
647 * \param name The name of the operation
648 * \param tag The tag to mark the operation
649 *
650 * \return A Tensor whose op member is the dynamic_strided_slice operation
651 */
652inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
653 const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
654 std::string name = "T_dynamic_strided_slice",
655 std::string tag = kInjective) {
656 const size_t src_tensor_dim = x->shape.size();
657 ICHECK_LE(begin.size(), src_tensor_dim);
658 ICHECK_LE(end.size(), src_tensor_dim);
659 ICHECK_LE(strides.size(), src_tensor_dim);
660 ICHECK_EQ(begin.size(), end.size());
661 ICHECK_EQ(begin.size(), strides.size());
662
663 const size_t num_slice_axes = begin.size();
664 Array<PrimExpr> out_shape;
665
666 for (size_t i = 0; i < num_slice_axes; ++i) {
667 auto d = indexdiv(end[i] - begin[i], strides[i]);
668 if (d->IsInstance<tvm::IntImmNode>()) {
669 // Preserve static dimension if possible
670 out_shape.push_back(d);
671 } else {
672 out_shape.push_back(tvm::tir::Var("dim"));
673 }
674 }
675
676 for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
677 out_shape.push_back(x->shape[i]);
678 }
679
680 return te::compute(
681 out_shape,
682 [&](const Array<tvm::tir::Var>& indices) {
683 Array<PrimExpr> real_indices;
684 for (size_t i = 0; i < num_slice_axes; ++i) {
685 real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
686 }
687 // keep input dim
688 for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
689 real_indices.push_back(indices[i]);
690 }
691 return x(real_indices);
692 },
693 name, tag);
694}
695
696/*!
697 * \brief strided_slice of a tensor with dynamic begin/end/stride
698 *
699 * \param x The input tensor
700 * \param begin The indices to begin with in the slicing
701 * \param end Indices indicating end of the slice
702 * \param strides Specifies the stride values, it can be negative
703 * in that case, the input tensor will be reversed in that particular axis
704 * \param name The name of the operation
705 * \param tag The tag to mark the operation
706 *
707 * \return A Tensor whose op member is the dynamic_strided_slice operation
708 */
709inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
710 const te::Tensor& end, const te::Tensor& strides,
711 std::string name = "T_strided_slice_dynamic",
712 std::string tag = topi::kInjective) {
713 const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
714 ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
715 ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
716
717 Array<PrimExpr> begin_expr, end_expr, strides_expr;
718 for (int64_t i = 0; i < num_dynamic_axes; ++i) {
719 auto i64_ind = IntImm(DataType::Int(64), i);
720 begin_expr.push_back(begin(i64_ind));
721 end_expr.push_back(end(i64_ind));
722 strides_expr.push_back(strides(i64_ind));
723 }
724 return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
725}
726
727/*!
 * \brief Calculate the output shape of strided_slice; the entry point for the Relay type relation
729 *
730 * \param ishape The input tensor shape
731 * \param begin The indices to begin with in the slicing
732 * \param end Indices indicating end of the slice
733 * \param strides Specifies the stride values, it can be negative
734 * in that case, the input tensor will be reversed in that particular axis
735 * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
736 * strides, and axes argument must be equal
737 * \param slice_mode Specifies the slice mode
738 *
739 * \return The output shape of strided_slice using the arguments above
740 */
741inline Array<PrimExpr> StridedSliceOutputShape(
742 const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
743 const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
744 ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
745 std::vector<int64_t> begin_vec, end_vec, strides_vec;
746 std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
747 auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
748 begin[0]->dtype, slice_mode);
749 return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
750 begin_canonicalized, true);
751}
752
753/*!
754 * \brief strided_slice of a tensor
755 *
756 * \param x The input tensor
757 * \param begin The indices to begin with in the slicing
758 * \param end Indices indicating end of the slice
759 * \param strides Specifies the stride values, it can be negative
760 * in that case, the input tensor will be reversed in that particular axis
761 * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
762 * strides, and axes argument must be equal
763 * \param slice_mode Specifies the slice mode
764 * \param name The name of the operation
765 * \param tag The tag to mark the operation
766 *
 * \return A Tensor whose op member is the strided_slice operation
768 */
769inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
770 const Array<Integer>& end, const Array<Integer>& strides,
771 const Array<Integer>& axes, std::string slice_mode = "end",
772 std::string name = "T_strided_slice_with_axes",
773 std::string tag = kInjective) {
774 const size_t src_tensor_dim = x->shape.size();
775 ICHECK(axes.size() <= src_tensor_dim);
776 ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
777
778 std::vector<int64_t> begin_vec, end_vec, strides_vec;
779 std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
780
781 auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
782 begin[0]->dtype, slice_mode);
783 auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
784 slice_mode, begin_expr);
785
786 return te::compute(
787 out_shape,
788 [&](const Array<tir::Var>& indices) {
789 Array<PrimExpr> real_indices;
790 for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
791 for (size_t i = 0; i < axes.size(); ++i) {
792 auto stride = make_const(strides[i].dtype(), strides_vec[i]);
793 PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
794 real_indices.Set(axes[i].IntValue(), ind);
795 }
796 return x(real_indices);
797 },
798 name, tag);
799}
800
801/*!
802 * \brief strided_slice of a tensor
803 *
804 * \param x The input tensor
805 * \param begin The indices to begin with in the slicing
806 * \param end Indices indicating end of the slice
807 * \param strides Specifies the stride values, it can be negative
808 * in that case, the input tensor will be reversed in that particular axis
809 * \param slice_mode Specifies the slice mode
810 * \param name The name of the operation
811 * \param tag The tag to mark the operation
812 *
813 * \return A Tensor whose op member is the strided_slice operation
814 */
815inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
816 const Array<Integer>& strides, std::string slice_mode = "end",
817 std::string name = "T_strided_slice", std::string tag = kInjective) {
818 size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
819 Array<Integer> axes;
820 for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
821 Array<Integer> begin_full(begin);
822 Array<Integer> end_full(end);
823 Array<Integer> strides_full(strides);
824
825 const IntImm one = IntImm(DataType::Int(64), 1);
826 const IntImm zero = IntImm(DataType::Int(64), 0);
827 const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
828
829 for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
830 strides_full.push_back(one);
831 }
832 for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
833 begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
834 }
835 for (size_t i = end.size(); i < src_tensor_dim; ++i) {
836 end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
837 }
838
839 return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
840 tag);
841}
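
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({10, 10}, DataType::Float(32), "x");
//   te::Tensor y = strided_slice(x, /*begin=*/{1, 0}, /*end=*/{9, 10}, /*strides=*/{2, 3});
//   // y->shape is {4, 4}: rows 1, 3, 5, 7 and columns 0, 3, 6, 9.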
842
843/*!
844 * \brief Split a tensor into a number of sub-tensors
845 *
846 * \param x The input tensor
 * \param num_sections The number of sections to split the tensor into.
 * This must be an integer factor of the size of the axis being split.
849 * \param axis The axis to split along.
850 * \param name The name of the operation
851 * \param tag The tag to mark the operation
852 *
853 * \return A Tensor whose op member is the split operation
854 */
855inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
856 std::string name = "T_split_sections",
857 std::string tag = kInjective) {
858 if (axis < 0) {
859 axis += static_cast<int>(x->shape.size());
860 }
861 ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
862
863 auto src_axis_size = x->shape[axis];
864
865 ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
866
867 if (auto node = src_axis_size.as<IntImmNode>()) {
868 ICHECK_EQ(node->value % num_sections, 0)
869 << "num_sections must be an integer factor of the size of axis " << axis << " ("
870 << node->value << ")";
871 }
872
873 Array<PrimExpr> split_indices;
874 auto seg_size = indexdiv(src_axis_size, num_sections);
875 for (int i = 0; i < num_sections; ++i) {
876 // region at index 0 is added by split()
877 if (i != 0) {
878 split_indices.push_back(seg_size * i);
879 }
880 }
881
882 return split(x, split_indices, axis, name, tag);
883}
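
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({6, 4}, DataType::Float(32), "x");
//   Array<te::Tensor> parts = split_sections(x, /*num_sections=*/3, /*axis=*/0);
//   // Three tensors of shape {2, 4}; 3 must evenly divide the size of axis 0.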
884
/*!
 * \brief Take elements from a flattened input array when axis is None.
 *
 * \param a The source array.
 * \param indices The indices of the values to extract.
 * \param batch_dims The number of batch dimensions.
 * \param mode The mode for handling out-of-bound indices.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the take operation
 */
898inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
899 std::string mode = "clip", std::string name = "T_take",
900 std::string tag = kInjective) {
901 Array<PrimExpr> a_shape = a->shape;
902 Array<PrimExpr> out_shape = indices->shape;
903 PrimExpr a_size = 1;
904 for (size_t i = 0; i < a_shape.size(); ++i) {
905 a_size = a_size * a_shape[i];
906 }
907
908 if (mode == "clip") {
909 return compute(
910 out_shape,
911 [&](const Array<Var>& out_index) {
912 auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
913 return a(UnravelIndex(idx, a_shape));
914 },
915 name, tag);
916 } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure the input indices are in bounds.";
919 return compute(
920 out_shape,
921 [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
922 name, tag);
923 } else { // mode == "wrap"
924 return compute(
925 out_shape,
926 [&](const Array<Var>& out_index) {
927 auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
928 return a(UnravelIndex(idx, a_shape));
929 },
930 name, tag);
931 }
932}
933
934/*!
935 * \brief Mask the out-of-boundary elements of each sequence.
936 *
937 * \param data The source array.
938 * \param valid_length The real length of each sequence.
939 * \param mask_value The masking value.
940 * \param axis The axis of the temporal dimension of the sequence
941 * \param name The name of the operation.
942 * \param tag The tag to mark the operation.
943 *
944 * \return A Tensor whose op member is the sequence_mask operation
945 */
946inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
947 int axis, std::string name = "T_sequence_mask",
948 std::string tag = kInjective) {
949 ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
950 ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
951 auto length_dim = data->shape[axis];
952 auto batch_dim = data->shape[1 - axis];
953 Array<PrimExpr> out_shape = data->shape;
954 Tensor out = compute(
955 out_shape,
956 [&](const Array<Var>& out_index) {
957 Array<PrimExpr> len_index;
958 auto tid = out_index[axis];
959 auto bid = out_index[1 - axis];
960 len_index.push_back(bid);
961 PrimExpr ret =
962 tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
963 tvm::tir::make_const(data->dtype, mask_value), data(out_index));
964 return ret;
965 },
966 name, tag);
967 return out;
968}
969
970/*!
971 * \brief Take elements from an array along an axis.
972 *
973 * \param a The source array.
974 * \param indices The indices of the values to extract.
 * \param batch_dims The number of batch dimensions. The default is 0.
976 * \param axis The axis over which to select values. By default,
977 * the flattened input array is used.
978 * \param mode The mode for handling out of bound indices.
979 * \param name The name of the operation.
980 * \param tag The tag to mark the operation.
981 *
982 * \return A Tensor whose op member is the take operation
983 */
984inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
985 std::string mode = "clip", std::string name = "T_take",
986 std::string tag = kInjective) {
987 if (axis < 0) {
988 axis += static_cast<int>(a->shape.size());
989 }
990 ICHECK_GE(axis, 0) << "axis out of bounds";
991 ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
992 auto axis_dim = a->shape[axis];
993 int indices_len = static_cast<int>(indices->shape.size());
994
995 int batch_dims_ = batch_dims;
996 if (batch_dims_ != 0) {
997 ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
998 ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
999
1000 if (batch_dims_ < 0) {
1001 batch_dims_ = indices->shape.size() + batch_dims_;
1002 }
1003
1004 ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
1005 ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
1006 for (int i = 0; i < batch_dims_; ++i) {
1007 auto addr1 = a->shape[i];
1008 auto addr2 = indices->shape[i];
1009 auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
1010 auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
1011 ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
1012 }
1013 }
1014
1015 // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
1016 // a.shape[axis + 1:].
1017
1018 Array<PrimExpr> out_shape;
1019 for (int i = 0; i < batch_dims_; ++i) {
1020 out_shape.push_back(a->shape[i]);
1021 }
1022 for (int i = batch_dims_; i < axis; ++i) {
1023 out_shape.push_back(a->shape[i]);
1024 }
1025 for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
1026 out_shape.push_back(indices->shape[i]);
1027 }
1028 for (size_t i = axis + 1; i < a->shape.size(); ++i) {
1029 out_shape.push_back(a->shape[i]);
1030 }
1031
1032 if (mode == "clip") {
1033 if (batch_dims_ == 0) {
1034 return compute(
1035 out_shape,
1036 [&](const Array<Var>& out_index) {
1037 Array<PrimExpr> indices_position;
1038 for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1039 indices_position.push_back(out_index[j]);
1040 }
1041 Array<PrimExpr> real_indices;
1042 for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1043 real_indices.push_back(out_index[j]);
1044 }
1045 auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
1046 real_indices.push_back(idx);
1047 for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1048 real_indices.push_back(out_index[j]);
1049 }
1050 return a(real_indices);
1051 },
1052 name, tag);
1053 } else {
1054 return compute(
1055 out_shape,
1056 [&](const Array<Var>& out_index) {
1057 Array<PrimExpr> indices_position;
1058 for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
1059 indices_position.push_back(out_index[j]);
1060 }
1061 for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
1062 indices_position.push_back(out_index[j]);
1063 }
1064 Array<PrimExpr> real_indices;
1065 for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1066 real_indices.push_back(out_index[j]);
1067 }
1068 auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
1069 real_indices.push_back(idx);
1070 for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
1071 real_indices.push_back(out_index[j]);
1072 }
1073 return a(real_indices);
1074 },
1075 name, tag);
1076 }
1077 } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure the input indices are in bounds.";
1080 return compute(
1081 out_shape,
1082 [&](const Array<Var>& out_index) {
1083 Array<PrimExpr> indices_position;
1084 for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1085 indices_position.push_back(out_index[j]);
1086 }
1087 Array<PrimExpr> real_indices;
1088 for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1089 real_indices.push_back(out_index[j]);
1090 }
1091 real_indices.push_back(indices(indices_position));
1092 for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1093 real_indices.push_back(out_index[j]);
1094 }
1095 return a(real_indices);
1096 },
1097 name, tag);
1098 } else { // mode == "wrap"
1099 return compute(
1100 out_shape,
1101 [&](const Array<Var>& out_index) {
1102 Array<PrimExpr> indices_position;
1103 for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1104 indices_position.push_back(out_index[j]);
1105 }
1106 Array<PrimExpr> real_indices;
1107 for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1108 real_indices.push_back(out_index[j]);
1109 }
1110 auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
1111 real_indices.push_back(idx);
1112 for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1113 real_indices.push_back(out_index[j]);
1114 }
1115 return a(real_indices);
1116 },
1117 name, tag);
1118 }
1119}
1120
1121/*!
1122 * \brief Return the elements, either from x or y, depending on the condition.
1123 *
1124 * \param condition The condition array.
1125 * \param x First array to be selected.
1126 * \param y Second array to be selected.
1127 * \param name The name of the operation.
1128 * \param tag The tag to mark the operation.
1129 *
1130 * \return A Tensor selected from x or y depending on condition.
1131 */
1132inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
1133 std::string name = "T_where", std::string tag = kBroadcast) {
1134 ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
1135 << y->dtype;
1136 auto get_out_shape = [&]() {
1137 auto bh1 = detail::BroadcastShape(x->shape, y->shape);
1138 Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
1139 auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
1140 Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
1141 return common_shape2;
1142 };
1143
1144 auto oshape = get_out_shape();
1145
1146 auto c_bh = detail::BroadcastShape(condition->shape, oshape);
1147 auto x_bh = detail::BroadcastShape(x->shape, oshape);
1148 auto y_bh = detail::BroadcastShape(y->shape, oshape);
1149
1150 auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
1151 auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
1152 auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
1153 auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
1154 return tvm::tir::Select(c != 0, true_val, false_val);
1155 };
1156
1157 return compute(oshape, select, name, tag);
1158}
1159
1160/*!
1161 * \brief Creates an operation to repeat elements of an array
1162 *
1163 * \param x The input tensor
1164 * \param repeats The number of repetitions for each element
1165 * \param axis The axis along which to repeat values (allows
1166 * negative indices as offsets from the last dimension)
1167 * \param name The name of the operation
1168 * \param tag The tag to mark the operation
1169 *
1170 * \return A Tensor whose op member is the repeat operation
1171 */
1172inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
1173 std::string tag = kBroadcast) {
1174 int ndim = static_cast<int>(x->shape.size());
1175 ICHECK(-ndim - 1 <= axis && axis <= ndim)
1176 << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1177 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
1178 ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
1179 << ", but got repeats = " << repeats;
1180 if (axis < 0) {
1181 // Calculate offset from last dimension
1182 axis += ndim;
1183 }
1184 Array<PrimExpr> new_shape;
1185 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1186 new_shape.push_back(x->shape[i]);
1187 }
1188 new_shape.push_back(repeats * x->shape[axis]);
1189 for (size_t i = axis + 1; i < x->shape.size(); ++i) {
1190 new_shape.push_back(x->shape[i]);
1191 }
1192
1193 return compute(
1194 new_shape,
1195 [&](const Array<Var>& indices) {
1196 Array<PrimExpr> idx;
1197 for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1198 idx.push_back(indices[i]);
1199 }
1200 idx.push_back(indexdiv(indices[axis], repeats));
1201 for (size_t i = axis + 1; i < indices.size(); ++i) {
1202 idx.push_back(indices[i]);
1203 }
1204 return x(idx);
1205 },
1206 name, tag);
1207}
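
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
//   te::Tensor y = repeat(x, /*repeats=*/2, /*axis=*/1);
//   // y->shape is {2, 6}; each element along axis 1 appears twice in a row.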
1208
1209/*!
1210 * \brief Creates an operation to tile elements of an array
1211 *
1212 * \param x The input tensor
1213 * \param reps The number of times for repeating the tensor
1214 * \param name The name of the operation
1215 * \param tag The tag to mark the operation
1216 *
1217 * \return A Tensor whose op member is the tile operation
1218 */
1219inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
1220 std::string tag = kBroadcast) {
1221 size_t ndim = x->shape.size();
1222 size_t rdim = reps.size();
1223 size_t tdim = (ndim > rdim) ? ndim : rdim;
1224 Array<PrimExpr> data_shape;
1225 Array<PrimExpr> reps_shape;
1226 Array<PrimExpr> new_shape;
1227 if (ndim == rdim) {
1228 for (size_t i = 0; i < ndim; ++i) {
1229 data_shape.push_back(x->shape[i]);
1230 reps_shape.push_back(reps[i]);
1231 }
1232 } else if (ndim > rdim) {
1233 for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1234 for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
1235 for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1236 } else {
1237 for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
1238 for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1239 for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1240 }
1241 for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
1242
1243 if (is_empty_shape(new_shape)) {
1244 return compute(
1245 new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1246 } else {
1247 return compute(
1248 new_shape,
1249 [&](const Array<Var>& indices) {
1250 Array<PrimExpr> idx;
1251 if (ndim >= rdim) {
1252 for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
1253 } else {
1254 for (size_t i = 0; i < ndim; ++i)
1255 idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1256 }
1257 return x(idx);
1258 },
1259 name, tag);
1260 }
1261}
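
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
//   te::Tensor y = tile(x, {2, 1});     // y->shape is {4, 3}
//   te::Tensor z = tile(x, {2, 2, 2});  // reps longer than ndim prepend dims: {2, 4, 6}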
1262
1263/*!
1264 * \brief Creates an operation to tile elements of an array
1265 *
1266 * \param x The input tensor
1267 * \param new_shape The shape of the output after tiling
1268 * \param rdim The rank of the reps, provided by caller
1269 * \param name The name of the operation
1270 * \param tag The tag to mark the operation
1271 *
1272 * \return A Tensor whose op member is the tile operation
1273 */
1274inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
1275 std::string name = "T_tile", std::string tag = kBroadcast) {
1276 size_t ndim = x->shape.size();
1277 if (is_empty_shape(new_shape)) {
1278 return compute(
1279 new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1280 } else {
1281 return compute(
1282 new_shape,
1283 [&](const Array<Var>& indices) {
1284 Array<PrimExpr> idx;
1285 if (ndim >= rdim) {
1286 for (size_t i = 0; i < ndim; ++i) {
1287 idx.push_back(indexmod(indices[i], x->shape[i]));
1288 }
1289 } else {
1290 for (size_t i = 0; i < ndim; ++i) {
1291 idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1292 }
1293 }
1294 return x(idx);
1295 },
1296 name, tag);
1297 }
1298}
1299
1300/*!
1301 * \brief Gather values along given axis from given indices.
1302 *
1303 * \param data The input data to the operator.
1304 * \param axis The axis along which to index.
1305 * \param indices The indices of values to gather.
1306 * \param name The name of the operation.
1307 * \param tag The tag to mark the operation.
1308 *
1309 * \return A Tensor whose op member is the gather operation
1310 */
1311inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
1312 std::string name = "T_gather", std::string tag = kInjective) {
1313 size_t ndim_d = data->shape.size();
1314 size_t ndim_i = indices->shape.size();
1315 ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
1316 ICHECK_EQ(ndim_d, ndim_i);
1317 if (axis < 0) {
1318 axis += ndim_d;
1319 }
1320 ICHECK_GE(axis, 0);
1321 ICHECK_LT(axis, ndim_d);
1322 if (indices->shape[axis].as<IntImmNode>()) {
1323 size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
1324 ICHECK_GE(indices_dim_i, 1);
1325 }
1326 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
1327
1328 Array<PrimExpr> out_shape;
1329 for (size_t i = 0; i < ndim_i; ++i) {
1330 out_shape.push_back(indices->shape[i]);
1331 }
1332
1333 return compute(
1334 out_shape,
1335 [&](const Array<Var>& out_index) {
1336 Array<PrimExpr> indices_position;
1337 for (size_t i = 0; i < ndim_i; ++i) {
1338 indices_position.push_back(out_index[i]);
1339 }
1340 Array<PrimExpr> real_indices;
1341 for (size_t i = 0; i < ndim_i; ++i) {
1342 if (i == static_cast<size_t>(axis)) {
1343 real_indices.push_back(indices(indices_position));
1344 } else {
1345 real_indices.push_back(indices_position[i]);
1346 }
1347 }
1348 return data(real_indices);
1349 },
1350 name, tag);
1351}
1352
1353/*!
1354 * \brief Gather elements from a n-dimension array.
1355 *
1356 * \param data The source array.
1357 * \param indices The indices of the values to extract.
1358 * \param batch_dims The number of batch dimensions.
1359 * \param name The name of the operation.
1360 * \param tag The tag to mark the operation.
1361 *
1362 * \return A Tensor whose op member is the gather_nd operation
1363 */
1364inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
1365 std::string name = "T_gather_nd", std::string tag = kInjective) {
1366 size_t ndim_d = data->shape.size();
1367 size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of the indices tensor must not exceed "
                                  << "the number of dimensions of the data tensor";
1372 Array<PrimExpr> out_shape;
1373 for (size_t i = 1; i < ndim_i; ++i) {
1374 out_shape.push_back(indices->shape[i]);
1375 }
1376 for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
1377 out_shape.push_back(data->shape[i]);
1378 }
1379 return compute(
1380 out_shape,
1381 [&](const Array<Var>& out_index) {
1382 Array<PrimExpr> indices_position;
1383 indices_position.push_back(0);
1384 for (size_t i = 0; i < ndim_i - 1; ++i) {
1385 indices_position.push_back(out_index[i]);
1386 }
1387 Array<PrimExpr> real_indices;
1388 for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
1389 real_indices.push_back(out_index[i]);
1390 }
1391 for (size_t i = 0; i < indices_dim0; ++i) {
1392 indices_position.Set(0, make_const(DataType::Int(32), i));
1393 if (indices->dtype.is_int() || indices->dtype.is_uint()) {
1394 real_indices.push_back(indices(indices_position));
1395 } else {
1396 real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
1397 }
1398 }
1399 if (real_indices.size() == ndim_d) {
1400 return data(real_indices);
1401 }
1402 for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
1403 real_indices.push_back(out_index[i]);
1404 }
1405 return data(real_indices);
1406 },
1407 name, tag);
1408}
1409
1410/*!
1411 * \brief Creates an operation that calculates a matrix multiplication
 * (row-major notation):
 *   C(i, j) = sum_k A(i, k) * B(k, j), when neither operand is transposed;
 *   the usual transposed combinations, otherwise.
1415 *
1416 * \param A The matrix A
1417 * \param B The matrix B
1418 * \param trans_a Is A's layout transposed?
1419 * \param trans_b Is B's layout transposed?
1420 * \param name The name of the operation
1421 * \param tag The tag to mark the operation
1422 *
1423 * \return A Tensor whose op member is the matmul operation
1424 */
1425inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
1426 bool trans_a = false, bool trans_b = false,
1427 std::string name = "T_matmul", std::string tag = kMatMul) {
1428 tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1429 auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
1430 auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
1431 return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1432 };
1433 return tvm::te::compute(output_shape, l, name, tag);
1434}
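
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor A = te::placeholder({4, 8}, DataType::Float(32), "A");
//   te::Tensor B = te::placeholder({8, 5}, DataType::Float(32), "B");
//   te::Tensor C = matmul(A, B);  // C->shape is {4, 5}, reducing over the shared dim of 8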
1435
1436/*!
1437 * \brief A generalization of matrix multiplication to tensors.
1438 *
1439 * \param A The tensor A
1440 * \param B The tensor B
 * \param axes The number of dimensions to reduce over
1442 * \param name The name of the operation
1443 * \param tag The tag to mark the operation
1444 *
1445 * \return A Tensor computing the result
1446 */
1447inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
1448 std::string name = "T_tensordot", std::string tag = kMatMul) {
1449 ICHECK_GE(A->shape.size(), axes);
1450 ICHECK_GE(B->shape.size(), axes);
1451
1452 Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
1453 for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
1454
1455 Array<IterVar> iter_vars;
1456 for (int i = 0; i < axes; ++i)
1457 iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
1458
1459 auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
1460 Array<PrimExpr> A_indices(input_indices.begin(),
1461 input_indices.begin() + (A->shape.size() - axes));
1462 for (auto& v : iter_vars) A_indices.push_back(v);
1463
1464 Array<PrimExpr> B_indices;
1465 for (auto& v : iter_vars) B_indices.push_back(v);
1466
1467 auto it = input_indices.begin() + (A->shape.size() - axes);
1468 for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
1469
1470 // Some passes don't like reductions with empty axis, so avoid it here
1471 if (iter_vars.empty()) {
1472 return A(A_indices) * B(B_indices);
1473 } else {
1474 return sum(A(A_indices) * B(B_indices), iter_vars);
1475 }
1476 };
1477
1478 return compute(output_shape, func, name, tag);
1479}
1480
1481/*!
1482 * \brief A generalization of matrix multiplication to tensors.
1483 *
1484 * \param A The tensor A
1485 * \param B The tensor B
1486 * \param A_axes The indices of the dimensions of tensor A to reduce over
1487 * \param B_axes The indices of the dimensions of tensor B to reduce over
1488 * \param name The name of the operation
1489 * \param tag The tag to mark the operation
1490 *
1491 * \return A Tensor computing the result
1492 */
1493inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
1494 Array<PrimExpr> B_axes, std::string name = "T_tensordot",
1495 std::string tag = kMatMul) {
1496 ICHECK_EQ(A_axes.size(), B_axes.size());
1497
1498 auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
1499 auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
1500
1501 Array<PrimExpr> output_shape;
1502 for (unsigned i = 0; i < A->shape.size(); ++i)
1503 if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
1504 output_shape.push_back(A->shape[i]);
1505 for (unsigned i = 0; i < B->shape.size(); ++i)
1506 if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
1507 output_shape.push_back(B->shape[i]);
1508
1509 Array<IterVar> iter_vars;
1510 for (unsigned i = 0; i < B_axes_val.size(); ++i)
1511 iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
1512
1513 auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
1514 int idx_input = 0;
1515 Array<PrimExpr> A_indices;
1516 for (unsigned i = 0; i < A->shape.size(); ++i) {
1517 auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
1518 if (axes_pos == A_axes_val.end()) {
1519 A_indices.push_back(input_indices[idx_input++]);
1520 } else {
1521 A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
1522 }
1523 }
1524
1525 Array<PrimExpr> B_indices;
1526 for (unsigned i = 0; i < B->shape.size(); ++i) {
1527 auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
1528 if (axes_pos == B_axes_val.end()) {
1529 B_indices.push_back(input_indices[idx_input++]);
1530 } else {
1531 B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
1532 }
1533 }
1534 return sum(A(A_indices) * B(B_indices), iter_vars);
1535 };
1536 return compute(output_shape, func, name, tag);
1537}
1538
1539inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
1540 DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
1541 PrimExpr num_elem = tvm::cast(
1542 tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
1543 Array<PrimExpr> shape;
1544 return compute(
1545 {num_elem},
1546 [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
1547 tag);
1548}
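
// Illustrative usage sketch (assumed example, not part of the upstream header):
//   te::Tensor r = arange(0, 10, 2, DataType::Int(32));
//   // A 1-D tensor of 5 elements: 0, 2, 4, 6, 8.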

/*!
 * \brief Produce coordinate grids by broadcasting each input over the dimensions
 * defined by the other inputs.
 *
 * \param inputs The input tensors (each a scalar or 1-D tensor)
 * \param indexing The indexing mode, either "xy" (Cartesian) or "ij" (matrix)
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return An Array of Tensors, one grid per input, whose op members are the meshgrid operation
 */
inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
                              std::string name = "T_meshgrid", std::string tag = kInjective) {
  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
    out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
  }
  Array<Tensor> result;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.push_back(compute(
        out_shape,
        [&](const Array<Var>& indices) {
          const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
          auto ndim = inputs[i]->GetShape().size();
          Array<PrimExpr> real_indices = {};
          if (ndim > 0) {
            real_indices = {indices[src_index]};
          }
          return inputs[i](real_indices);
        },
        name, tag));
  }
  return result;
}
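// Usage sketch (illustrative only), assuming 1-D placeholders x (length 3) and y (length 4):
//   Array<Tensor> grids = meshgrid({x, y}, "ij");
//   // grids[0] and grids[1] both have shape [3, 4];
//   // grids[0](i, j) == x(i) and grids[1](i, j) == y(j).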

/*!
 * \brief Transform the layout according to \p src_layout and \p dst_layout
 * \param src the source input.
 * \param src_layout the source layout.
 * \param dst_layout the destination layout.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A tensor with shape in \p dst_layout
 */
inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string name = "T_layout_trans",
                               const std::string tag = kInjective) {
  Layout src_layout_struct(src_layout);
  Layout dst_layout_struct(dst_layout);

  if (src_layout_struct.Equals(dst_layout_struct)) {
    return src;
  }

  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
      << "cannot convert from/to undefined layout";

  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
  ICHECK(layout_converter.defined())
      << "cannot convert from " << src_layout << " to " << dst_layout;

  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
        PrimExpr in_range = PrimExpr(1) > PrimExpr(0);  // init with dtype=bool and value=true
        for (size_t i = 0; i < src.ndim(); ++i) {
          in_range = in_range && (src_indices[i] < src->shape[i]);
        }
        return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
      },
      name, tag);
}
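// Usage sketch (illustrative only): converting NCHW data to the blocked NCHW16c layout,
// assuming the channel dimension is divisible by the block size of 16.
//   Tensor x = placeholder({1, 64, 7, 7}, DataType::Float(32), "x");
//   Tensor y = layout_transform(x, "NCHW", "NCHW16c");  // y has shape [1, 4, 7, 7, 16]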

/*!
 * \brief Utility function for auto_scheduler_layout_transform: splits a layout string of
 * interleaved integer factors and axis names into the shape factors and the axis names.
 */
inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
                                        std::vector<std::string>* axes) {
  int32_t factor = 0;
  std::string axis = "";
  for (char c : std::string(layout)) {
    if (c >= 'A' && c <= 'z') {
      axis += c;
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      factor = factor * 10 + c - '0';
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}
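// Behavior sketch (hypothetical input string, traced from the parser above):
//   Array<PrimExpr> shape;
//   std::vector<std::string> axes;
//   parse_auto_scheduler_layout("32N64C", &shape, &axes);
//   // shape == {32, 64}, axes == {"N", "C"}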

/*!
 * \brief Transform the auto-scheduler generated layout according to
 * \p src_layout and \p dst_layout
 * \param src the source input.
 * \param src_layout the source layout.
 * \param dst_layout the destination layout.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A tensor with shape in \p dst_layout
 */
inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
                                              const String& dst_layout,
                                              const String name = "T_auto_scheduler_layout_trans",
                                              const String tag = kInjective) {
  Array<PrimExpr> src_shape;
  std::vector<std::string> src_axes;
  Array<PrimExpr> dst_shape;
  std::vector<std::string> dst_axes;

  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices;
        for (const std::string& src_axis : src_axes) {
          PrimExpr src_index = 0;
          CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
          for (size_t i = 0; i < dst_axes.size(); ++i) {
            if (dst_axes[i] == src_axis) {
              src_index = src_index * dst_shape[i] + dst_indices_expr[i];
            }
          }
          src_indices.push_back(src_index);
        }
        return src(src_indices);
      },
      name, tag);
}
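// Usage sketch (hypothetical layout strings, illustrating a plain transpose of a [2, 4] tensor):
//   Tensor x = placeholder({2, 4}, DataType::Float(32), "x");
//   Tensor y = auto_scheduler_layout_transform(x, "2N4C", "4C2N");
//   // y has shape [4, 2] and y(c, n) == x(n, c)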

/*!
 * \brief Transform the meta-schedule generated layout according to TIR's IndexMap
 * \param src the source input.
 * \param index_map The TIR IndexMap
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A tensor whose layout has been transformed according to \p index_map.
 * \note Example:
 *
 * For the indexing pattern below:
 *
 *   for i in range(32):
 *     for j in range(64):
 *       load A[
 *         i / 16 * 4 + j / 16,
 *         i % 16 * 16 + j % 16,
 *       ]
 *
 * The corresponding indexing pattern in TIR is:
 *
 *   A[i, j] => A'[i / 4, j / 16, i % 4, j % 16]
 *
 * which converts the pattern to:
 *
 *   for i in range(32):
 *     for j in range(64):
 *       load A'[
 *         i / 16 + j / 64,
 *         i % 16,
 *         j % 64 / 16,
 *         j % 16,
 *       ]
 *
 * In this case, the transformation pattern is:
 *   A'[a, b, c, d] = A[a * 4 + c, b * 16 + d]
 */
inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                             const String name = "T_meta_schedule_layout_trans",
                                             const String tag = kInjective) {
  Array<Range> iter_domain;
  iter_domain.reserve(src->shape.size());
  for (const PrimExpr& e : src->shape) {
    iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
  }
  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
  return compute(
      post_transform_shape,
      [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr {
        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}));
      },
      name, tag);
}
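// Usage sketch (illustrative only; assumes tir::IndexMap::FromFunc as declared in
// <tvm/tir/index_map.h>): packing a [128, 128] matrix into 4x16 tiles.
//   Tensor A = placeholder({128, 128}, DataType::Float(32), "A");
//   tir::IndexMap index_map = tir::IndexMap::FromFunc(2, [](Array<Var> v) -> Array<PrimExpr> {
//     return {floordiv(v[0], 4), floordiv(v[1], 16), floormod(v[0], 4), floormod(v[1], 16)};
//   });
//   Tensor B = meta_schedule_layout_transform(A, index_map);  // B has shape [32, 8, 4, 16]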

/*!
 * \brief Get the shape of input tensor.
 * \param src the input tensor.
 * \param dtype the type of the elements in the tensor.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A 1-D Tensor of length `src.ndim()` holding the shape of \p src.
 */
inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
                    const std::string tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_shape{ndim};
  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto idx = indices[0];
        PrimExpr ret = 0;
        for (int i = 0; i < ndim; ++i) {
          ret = tvm::if_then_else(idx == i, src->shape[i], ret);
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}
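// Usage sketch (illustrative only):
//   Tensor x = placeholder({2, 3, 4}, DataType::Float(32), "x");
//   Tensor s = shape(x, DataType::Int(32));  // 1-D tensor of length 3 evaluating to {2, 3, 4}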

/*!
 * \brief Get the number of elements of input tensor.
 * \param src the input tensor.
 * \param dtype the type of the elements in the tensor.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A 0-D Tensor holding the element count of \p src.
 */
inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                           const std::string& name = "ndarray_size",
                           const std::string& tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_ndarray_size = {};
  return compute(
      out_ndarray_size,
      [&](const Array<Var>& indices) {
        PrimExpr ret = 1;
        for (int i = 0; i < ndim; ++i) {
          ret *= src->shape[i];
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}
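// Usage sketch (illustrative only):
//   Tensor x = placeholder({2, 3, 4}, DataType::Float(32), "x");
//   Tensor n = ndarray_size(x, DataType::Int(64));  // 0-D tensor evaluating to 24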

/*!
 * \brief Returns a one-hot tensor where the locations represented by indices take value on_value,
 * other locations take value off_value.
 * \param indices locations to set to on_value.
 * \param on_value value that locations represented by indices take on.
 * \param off_value value that other locations take on.
 * \param depth depth of the one-hot dimension.
 * \param axis axis to fill.
 * \param dtype data type of the output tensor.
 * \param oshape shape of the output tensor.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return one-hot tensor.
 */
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                      int depth, int axis, const DataType& dtype,
                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                      const std::string name = "T_one_hot", const std::string tag = kInjective) {
  int true_axis = (axis == -1) ? indices->shape.size() : axis;
  if (oshape.size() == 0) {
    int ndim = indices->shape.size() + 1;
    int indices_index = 0;
    for (int i = 0; i < ndim; i++) {
      if (i == true_axis) {
        oshape.push_back(Integer(depth));
      } else {
        oshape.push_back(indices->shape[indices_index++]);
      }
    }
  }

  PrimExpr on_value_cast = cast(dtype, on_value);
  PrimExpr off_value_cast = cast(dtype, off_value);
  return compute(
      oshape,
      [&](const Array<Var>& iter_vars) {
        Array<Var> indices_indices;
        for (size_t i = 0; i < iter_vars.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }

          indices_indices.push_back(iter_vars[i]);
        }

        auto idx = iter_vars[true_axis];
        return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
      },
      name, tag);
}
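// Usage sketch (illustrative only), with index contents {0, 2}, depth 3, and axis -1:
//   Tensor idx = placeholder({2}, DataType::Int(32), "idx");
//   Tensor oh = one_hot(idx, 1, 0, /*depth=*/3, /*axis=*/-1, DataType::Float(32));
//   // oh has shape [2, 3]; for idx == {0, 2} it evaluates to {{1, 0, 0}, {0, 0, 1}}.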

/*!
 * \brief Convert a sparse representation (indices and values) into a dense tensor.
 * \param sparse_indices sparse_indices[i] contains the location at which sparse_values[i] will be
 * placed.
 * \param output_shape is the shape of the dense output tensor.
 * \param sparse_values is a 0-D or 1-D tensor. Values for each row of sparse_indices.
 * \param default_value is a 0-D tensor. Defaults to zero.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return Tensor of output_shape.
 */
inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
                              const Tensor& sparse_values, const PrimExpr& default_value,
                              const std::string name = "T_sparse_to_dense",
                              const std::string tag = kInjective) {
  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
  ICHECK_LE(sparse_indices->shape.size(), 2)
      << "sparse_indices tensor should be 0D, 1D, or 2D only";
  ICHECK_LE(sparse_values->shape.size(), 1) << "sparse_values tensor should be 0D or 1D only";

  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
  Array<PrimExpr> oshape;
  for (auto l : output_shape) {
    oshape.push_back(l);
  }
  return compute(
      oshape,
      [&](const Array<Var>& indices) {
        PrimExpr ret = default_value;
        if (0 == rank_sparse_indices) {
          ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
        } else if (1 == rank_sparse_indices) {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
          }
        } else {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            PrimExpr aggregate_condition;
            for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
              PrimExpr comparison = indices[k] == sparse_indices[j][k];
              aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
            }
            ret = if_then_else(aggregate_condition, sparse_values[j], ret);
          }
        }
        return ret;
      },
      name, tag);
}
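// Usage sketch (illustrative only), scattering two values into a length-5 vector:
//   Tensor idx = placeholder({2}, DataType::Int(32), "idx");    // e.g. contents {1, 3}
//   Tensor val = placeholder({2}, DataType::Float(32), "val");  // e.g. contents {7, 9}
//   Tensor d = sparse_to_dense(idx, {5}, val, 0.f);
//   // d evaluates to {0, 7, 0, 9, 0} for the example contents above.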

/*!
 * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
 * \param input input tensor.
 * \param diagonal values to be filled in the diagonals.
 * \param k1 lower limit (included) of the range of diagonals.
 * \param k2 upper limit (included) of the range of diagonals.
 * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded).
 * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded).
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return new tensor with given diagonal values.
 */
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
                              bool super_diag_right_align, bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  size_t ndim = input->shape.size() - 1;

  bool only_one_diagonal = k1 == k2;

  return compute(
      input->shape,
      [&](const Array<Var>& iter_vars) {
        auto get_diag = [&]() {
          Array<PrimExpr> diagonal_indices;
          PrimExpr k, offset = 0;
          for (size_t i = 0; i < ndim - 1; i++) {
            diagonal_indices.push_back(iter_vars[i]);
          }
          if (only_one_diagonal) {
            k = k1;
          } else {
            // Determining which diagonal/sub-diagonal/super-diagonal it is
            k = iter_vars[ndim] - iter_vars[ndim - 1];
            diagonal_indices.push_back(k2 - k);

            // Calculating the offset in diagonal tensor for this diagonal
            auto get_offset = [&](PrimExpr M, PrimExpr N) {
              // offset = max_diagonal_length - diagonal_length
              return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
            };
            offset = if_then_else(
                k >= 0,
                super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
                                       : 0,
                sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                                     : 0);
          }
          diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                     offset);
          return diagonal(diagonal_indices);
        };
        return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
                            if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
                                         get_diag(), input(iter_vars)),
                            input(iter_vars));
      },
      name, tag);
}
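// Usage sketch (illustrative only), replacing only the main diagonal (k1 == k2 == 0) of a
// square matrix:
//   Tensor m = placeholder({3, 3}, DataType::Float(32), "m");
//   Tensor d = placeholder({3}, DataType::Float(32), "d");
//   Tensor r = matrix_set_diag(m, d, 0, 0, true, true);
//   // r(i, j) == d(i) when i == j, and m(i, j) otherwise.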

/*!
 * \brief Numpy style advanced indexing with tensor.
 * \param data is input data.
 * \param indices is list of indexing tensors.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return Output tensor.
 */
inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
  Array<PrimExpr> oshape;
  Array<PrimExpr> broadcast_shape;
  Array<Tensor> bindices;

  broadcast_shape = indices[0]->shape;
  for (size_t i = 1; i < indices.size(); ++i) {
    auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
    broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
  }
  if (indices.size() == 1) {
    // quick path
    bindices = indices;
  } else {
    // Do broadcast for indices
    for (size_t i = 0; i < indices.size(); ++i) {
      bindices.push_back(broadcast_to(indices[i], broadcast_shape));
    }
  }

  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const Array<Var>& iter_var) {
        Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }

        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }

        return data(real_indices);
      },
      name, tag);
}
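// Usage sketch (illustrative only), gathering rows of a matrix with a 1-D index tensor:
//   Tensor data = placeholder({4, 5}, DataType::Float(32), "data");
//   Tensor rows = placeholder({3}, DataType::Int(32), "rows");
//   Tensor out = adv_index(data, {rows});  // out has shape [3, 5], out(i, j) == data(rows(i), j)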

}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_TRANSFORM_H_