1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/attrs/transform.h
22 * \brief Transform operators.
23 */
24#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
25#define TVM_RELAY_ATTRS_TRANSFORM_H_
26
27#include <tvm/ir/attrs.h>
28#include <tvm/relay/base.h>
29#include <tvm/relay/expr.h>
30#include <tvm/tir/index_map.h>
31
32#include <string>
33
34namespace tvm {
35namespace relay {
36
37/*! \brief Attributes used for the sliding_window operator */
38struct SlidingWindowAttrs : public tvm::AttrsNode<SlidingWindowAttrs> {
39 int axis;
40 Array<Integer> window_shape;
41 Array<Integer> strides;
42 TVM_DECLARE_ATTRS(SlidingWindowAttrs, "relay.attrs.SlidingWindowAttrs") {
43 TVM_ATTR_FIELD(axis).describe(
44 "What axis the sliding window begin forming over."
45 "Window will be slid over this axis and all following axes."
46 "The axis value determines the window shape (and thus, the"
47 "number of strides):"
48 "window shape and strides must both be of length"
49 "`data.ndim-axis`.");
50 TVM_ATTR_FIELD(window_shape)
51 .describe(
52 "The window shape to form over the input."
53 "Window shape must be of length `data.ndim-axis`.");
54 TVM_ATTR_FIELD(strides).describe(
55 "How to stride the window along each dimension."
56 "Strides must be of length `data.ndim-axis`.");
57 }
58}; // struct SlidingWindowAttrs
59
60/*! \brief data type cast */
61struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
62 DataType dtype;
63
64 TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
65 TVM_ATTR_FIELD(dtype).describe("Target data type");
66 }
67}; // struct CastAttrs.
68
69/*! \brief Attributes used in expand_dims operators */
70struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
71 int axis;
72 int num_newaxis;
73
74 TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
75 TVM_ATTR_FIELD(axis).describe(
76 "The axis at which the input array is expanded."
77 "Should lie in range `[-data.ndim - 1, data.ndim]`."
78 "If `axis < 0`, it is the first axis inserted;"
79 "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
80 TVM_ATTR_FIELD(num_newaxis)
81 .describe("Number of axes to be inserted. Should be >= 0.")
82 .set_lower_bound(0)
83 .set_default(1);
84 }
85}; // struct ExpandDimsAttrs
86
87/*! \brief Attributes used in dynamic expand_dims operators */
88struct DynExpandDimsAttrs : public tvm::AttrsNode<DynExpandDimsAttrs> {
89 int num_newaxis;
90
91 TVM_DECLARE_ATTRS(DynExpandDimsAttrs, "relay.attrs.DynExpandDimsAttrs") {
92 TVM_ATTR_FIELD(num_newaxis)
93 .describe("Number of axes to be inserted. Should be >= 0.")
94 .set_lower_bound(0)
95 .set_default(1);
96 }
97}; // struct ExpandDimsAttrs
98
99/*! \brief Attributes used in concatenate operators */
100struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
101 int axis;
102 TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
103 TVM_ATTR_FIELD(axis)
104 .describe(
105 "The axis at which the input arrays are concatenated."
106 "Should lie in range `[-ndim, ndim)`.")
107 .set_default(0);
108 }
109}; // struct ConcatenateAttrs
110
111/*! \brief Attributes used in transpose operators */
112struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
113 Array<Integer> axes;
114 TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
115 TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
116 }
117}; // struct TransposeAttrs
118
119/*! \brief Attributes used in reshape operators */
120struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
121 Array<Integer> newshape;
122 bool allowzero;
123 TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
124 TVM_ATTR_FIELD(newshape).describe(
125 "The new shape. Should be compatible with the original shape.");
126 TVM_ATTR_FIELD(allowzero).set_default(0).describe(
127 "Whether to honor the value of zero in newshape.");
128 }
129}; // struct ReshapeAttrs
130
131/*! \brief Attributes used in MXNet-style reshape_like operators */
132struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
133 int lhs_begin;
134 Integer lhs_end; // can be None
135 int rhs_begin;
136 Integer rhs_end; // can be None
137 TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") {
138 TVM_ATTR_FIELD(lhs_begin).set_default(0).describe(
139 "The axis of the input where reshaping should begin.");
140 TVM_ATTR_FIELD(lhs_end)
141 .set_default(NullValue<Integer>())
142 .describe("The axis of the input where reshaping should end, exclusive.");
143 TVM_ATTR_FIELD(rhs_begin).set_default(0).describe(
144 "The axis of the shape_like tensor to begin taking dimensions from.");
145 TVM_ATTR_FIELD(rhs_end)
146 .set_default(NullValue<Integer>())
147 .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive.");
148 }
149}; // struct ReshapeLikeAttrs
150
151struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
152 Integer axis;
153
154 TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
155 TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
156 }
157};
158
159struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
160 Integer axis;
161
162 TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") {
163 TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
164 }
165};
166
167struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
168 String mode;
169
170 TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
171 TVM_ATTR_FIELD(mode).describe(
172 "Accumulation mode of the scatter, either \"update\" or \"add\".");
173 }
174};
175
176struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
177 Integer axis;
178
179 TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
180 TVM_ATTR_FIELD(axis)
181 .set_default(NullValue<Integer>())
182 .describe("The axis over which to select values.");
183 }
184};
185
186struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
187 Integer batch_dims;
188 Optional<Integer> index_rank;
189
190 TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
191 TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
192 TVM_ATTR_FIELD(index_rank)
193 .set_default(NullValue<Integer>())
194 .describe(
195 "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
196 "indexting tuples is dynamic.");
197 }
198};
199
200struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
201 Integer batch_dims;
202 Integer axis;
203 tvm::String mode;
204
205 TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
206 TVM_ATTR_FIELD(batch_dims)
207 .set_default(0)
208 .describe("The batch_dims over which to select values.");
209 TVM_ATTR_FIELD(axis)
210 .set_default(NullValue<Integer>())
211 .describe("The axis over which to select values.");
212 TVM_ATTR_FIELD(mode).set_default("clip").describe(
213 "Specify how out-of-bound indices will behave."
214 "clip - clip to the range (default)"
215 "wrap - wrap around the indices"
216 "fast - no clip or wrap around (user must make sure indices are in-bound)");
217 }
218};
219
220/*! \brief Attributes that specify a tensor */
221struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
222 Optional<Array<Integer>> shape;
223 DataType dtype;
224
225 TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
226 TVM_ATTR_FIELD(shape).describe("Target shape.");
227 TVM_ATTR_FIELD(dtype).describe("Target data type.").set_default(NullValue<DataType>());
228 }
229}; // struct InitOpAttrs
230
231/*! \brief Attributes used in arange operators */
232struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
233 Expr start;
234 Expr stop;
235 Expr step;
236 DataType dtype;
237
238 TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
239 TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value.");
240 TVM_ATTR_FIELD(stop).describe("Stop of interval. The interval does not include this value.");
241 TVM_ATTR_FIELD(step).describe("Spacing between values.");
242 TVM_ATTR_FIELD(dtype).describe("Target data type.");
243 }
244}; // struct ArangeAttrs
245
246/*! \brief Attributes used in meshgrid operators */
247struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> {
248 std::string indexing;
249
250 TVM_DECLARE_ATTRS(MeshgridAttrs, "relay.attrs.MeshgridAttrs") {
251 TVM_ATTR_FIELD(indexing)
252 .describe(
253 "Indexing mode, either \"ij\" for matrix or \"xy\" for cartesian in which first two"
254 "dimensions are swapped.")
255 .set_default("ij");
256 }
257}; // struct MeshgridAttrs
258
259/*! \brief Attributes used in stack operators */
260struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
261 Integer axis;
262 TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") {
263 TVM_ATTR_FIELD(axis).set_default(0).describe(
264 "The axis in the result array along which the input arrays are stacked.");
265 }
266}; // struct StackAttrs
267
268/*! \brief Attributes used in repeat operators */
269struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
270 Integer repeats;
271 Integer axis;
272 TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
273 TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element.");
274 TVM_ATTR_FIELD(axis)
275 .set_default(NullValue<Integer>())
276 .describe(" The axis along which to repeat values.");
277 }
278}; // struct RepeatAttrs
279
280/*! \brief Attributes used in tile operators */
281struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
282 Array<Integer> reps;
283 TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
284 TVM_ATTR_FIELD(reps).describe(
285 "The number of times for repeating the tensor a."
286 "Each dim sizeof reps must be a positive integer.");
287 }
288}; // struct TileAttrs
289
290/*! \brief Attributes used in reverse operators */
291struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
292 Integer axis;
293 TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
294 TVM_ATTR_FIELD(axis)
295 .set_default(NullValue<Integer>())
296 .describe("The axis along which to reverse elements.");
297 }
298}; // struct ReverseAttrs
299
300/*! \brief Attributes used in reverse_sequence operators */
301struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
302 Integer seq_axis;
303 Integer batch_axis;
304
305 TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
306 TVM_ATTR_FIELD(seq_axis).set_default(1).describe(
307 "The seq axis along which to reverse elements.");
308 TVM_ATTR_FIELD(batch_axis)
309 .set_default(0)
310 .describe("The batch axis along which to slice the tensor.");
311 }
312}; // struct ReverseSequenceAttrs
313
314/*! \brief Attributes used in squeeze operators */
315struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
316 // use axis to make the name numpy compatible.
317 Array<Integer> axis;
318
319 TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
320 TVM_ATTR_FIELD(axis)
321 .describe(
322 "The axis to squeeze in the input tensor."
323 "If `axis = None`, all axis of dimension 1 get squeezed;"
324 "Else, the dimension in axes get squeezed."
325 "It is an error if an axis does not has dimension 1.")
326 .set_default(NullValue<Array<Integer>>());
327 }
328}; // struct SqueezeAttrs
329
330struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
331 ObjectRef indices_or_sections;
332 int axis;
333
334 TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
335 TVM_ATTR_FIELD(indices_or_sections)
336 .describe(
337 "Indices or sections to split into. Accepts an int or a tuple"
338 "If indices_or_sections is an integer, the input will be divided equally"
339 "along given axis. If such a split is not possible, an error is raised."
340 "If indices_or_sections is a tuple of sorted integers,"
341 "the entries indicate where along axis the array is split.");
342 TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted.");
343 }
344};
345
346/*! \brief Attributes for StridedSlice operator */
347struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
348 Optional<Array<Integer>> begin;
349 Optional<Array<Integer>> end;
350 Optional<Array<Integer>> strides;
351 tvm::String slice_mode;
352 Optional<Array<Integer>> axes;
353
354 TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
355 TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
356 TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
357 TVM_ATTR_FIELD(strides).describe(
358 "Stride values of the slice, a stride can be negative, which causes a reverse slice.");
359 TVM_ATTR_FIELD(slice_mode)
360 .set_default("end")
361 .describe(
362 "The slice mode [end, size]."
363 "end - The default slice mode, ending indices for the slice."
364 "size - The input strides will be ignored, input end in this mode indicates the size"
365 "of a slice starting at the location specified by begin. If end[i] is -1,"
366 "all remaining elements in that dimension are included in the slice");
367 TVM_ATTR_FIELD(axes).describe(
368 "Axes along which slicing is applied. When it is specified, the length of begin, end, "
369 "strides, and axes must be equal.");
370 }
371};
372
373struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
374 Array<Integer> axes;
375
376 TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
377 TVM_ATTR_FIELD(axes).describe(
378 "List of axes on which input data will be sliced according to the "
379 "corresponding size of the second input. By default will slice "
380 "on all axes. Negative axes mean counting in reverse.");
381 }
382};
383
384/*! \brief Attributes for Clip operator */
385struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
386 double a_min;
387 double a_max;
388
389 TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
390 TVM_ATTR_FIELD(a_min).describe("The minimum clip value.");
391 TVM_ATTR_FIELD(a_max).describe("The maximum clip value.");
392 }
393};
394
395/*! \brief Attributes for FixedPointMultiply operator */
396struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
397 int32_t multiplier;
398 int32_t shift;
399
400 TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
401 TVM_ATTR_FIELD(multiplier)
402 .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)");
403 TVM_ATTR_FIELD(shift).describe(
404 "Shift of a fixed floating point number described as multiplier*2^(shift)");
405 }
406};
407
408/*! \brief Attributes for per channel/per axes FixedPointMultiply operator */
409struct FixedPointMultiplyPerAxisAttrs : public tvm::AttrsNode<FixedPointMultiplyPerAxisAttrs> {
410 bool is_lshift_required;
411 bool is_rshift_required;
412 Array<Integer> axes;
413
414 TVM_DECLARE_ATTRS(FixedPointMultiplyPerAxisAttrs, "relay.attrs.FixedPointMultiplyPerAxisAttrs") {
415 TVM_ATTR_FIELD(is_lshift_required)
416 .describe("Whether left shift is required in fixed point multiplication.")
417 .set_default(false);
418 TVM_ATTR_FIELD(is_rshift_required)
419 .describe("Whether right shift is required in fixed point multiplication.")
420 .set_default(false);
421 TVM_ATTR_FIELD(axes).describe("List of axes on which input data was quantized.");
422 }
423};
424
425/*! \brief Attributes for LayoutTransform operator */
426struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
427 std::string src_layout;
428 std::string dst_layout;
429
430 TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
431 TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)");
432 TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. NCHW16c)");
433 }
434};
435
436/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
437struct AutoSchedulerLayoutTransformAttrs
438 : public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
439 std::string src_layout;
440 std::string dst_layout;
441
442 TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
443 "relay.attrs.AutoSchedulerLayoutTransformAttrs") {
444 TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
445 TVM_ATTR_FIELD(dst_layout)
446 .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
447 }
448};
449
450/*! \brief Attributes for MetaScheduleLayoutTransform operator */
451struct MetaScheduleLayoutTransformAttrs : public tvm::AttrsNode<MetaScheduleLayoutTransformAttrs> {
452 tir::IndexMap index_map;
453
454 TVM_DECLARE_ATTRS(MetaScheduleLayoutTransformAttrs,
455 "relay.attrs.MetaScheduleLayoutTransformAttrs") {
456 TVM_ATTR_FIELD(index_map).describe(
457 "The order of the extents, for example, "
458 "let extents = [2, 3, 4], reorder = [0, 2, 1], and the shape of buffer A is (4, 6)"
459 "then A[i, j] will be first rewritten to "
460 "A[(6 * i + j) / 12, (6 * i + j) / 4 % 3 , (6 * i + j) % 4] according to the `extents`,"
461 "and then reordered to A[(6 * i + j) / 12, (6 * i + j) % 4 , (6 * i + j) / 4 % 3]"
462 "according to `reorder`");
463 }
464};
465
466/*! \brief Attributes for ShapeOf operator */
467struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
468 DataType dtype;
469
470 TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
471 TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
472 }
473};
474
475struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
476 double mask_value;
477 int axis;
478
479 TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") {
480 TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value.");
481 TVM_ATTR_FIELD(axis).set_default(0).describe(
482 "The axis of the length dimension. Can only be 0 or 1.");
483 }
484}; // struct SequenceMaskAttrs.
485
486/*! \brief Attributes used in sparse_to_dense operator */
487struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
488 Array<Integer> output_shape;
489
490 TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") {
491 TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor");
492 }
493}; // struct SparseToDenseAttrs
494
495/*! \brief Attributes for ndarray_size operator */
496struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
497 DataType dtype;
498
499 TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
500 TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
501 }
502};
503
504/*! \brief Attributes used in one-hot operator */
505struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
506 int depth;
507 int axis;
508 DataType dtype;
509
510 TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
511 TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension.");
512 TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
513 TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>()).describe("Output data type.");
514 }
515}; // struct OneHotAttrs
516
517/*! \brief Attributes used in matrix_set_diag operator */
518struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
519 int k1;
520 int k2;
521 bool super_diag_right_align;
522 bool sub_diag_right_align;
523
524 TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
525 TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
526 TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
527 TVM_ATTR_FIELD(super_diag_right_align)
528 .set_default(true)
529 .describe("Bool, true iff super-diagonal is right aligned (left-padded).");
530 TVM_ATTR_FIELD(sub_diag_right_align)
531 .set_default(false)
532 .describe("Bool, true iff sub-diagonal is right aligned (left-padded).");
533 }
534}; // struct MatrixSetDiagAttrs
535
536/*! \brief Attributes used in cumsum and cumprod operator */
537struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
538 Integer axis;
539 DataType dtype;
540 Bool exclusive = Bool(false);
541 TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
542 TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>());
543 TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
544
545 // Default is 0 which is "false"
546 TVM_ATTR_FIELD(exclusive)
547 .describe("The first element is not included")
548 .set_default(Bool(false));
549 }
550}; // struct ScanopAttrs
551
552/*! \brief Attributes used in unique operator */
553struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
554 bool sorted;
555 bool return_counts;
556 TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") {
557 TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true);
558 TVM_ATTR_FIELD(return_counts)
559 .describe("Whether to return an additional tensor with counts of each unique elements")
560 .set_default(false);
561 }
562}; // struct UniqueAttrs
563
564/*! \brief Attributes used in einsum operator */
565struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
566 String equation;
567
568 TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
569 TVM_ATTR_FIELD(equation).describe("The einsum expression string");
570 }
571}; // struct EinsumAttrs
572
573/*! \brief Attributes used in stft operator */
574struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
575 int n_fft;
576 int hop_length;
577 int win_length;
578 bool normalized;
579 bool onesided;
580
581 TVM_DECLARE_ATTRS(StftAttrs, "relay.attrs.StftAttrs") {
582 TVM_ATTR_FIELD(n_fft).set_default(-1).describe("The size of Fourier transform");
583 TVM_ATTR_FIELD(hop_length)
584 .set_default(-1)
585 .describe("The distance between neighboring sliding window frames");
586 TVM_ATTR_FIELD(win_length).set_default(-1).describe("The size of window frame and STFT filter");
587 TVM_ATTR_FIELD(normalized)
588 .set_default(false)
589 .describe("Whether to return the normalized STFT results");
590 TVM_ATTR_FIELD(onesided).set_default(true).describe(
591 "Whether to return onesided result or fill with conjugate symmetry");
592 }
593}; // struct StftAttrs
594
595struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
596 bool upper;
597
598 TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
599 TVM_ATTR_FIELD(upper).set_default(true).describe(
600 "Whether to keep the upper or lower half of the diagonal.");
601 }
602}; // struct TriluAttrs
603
604} // namespace relay
605} // namespace tvm
606#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
607