1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/attrs/transform.h |
22 | * \brief Transform operators. |
23 | */ |
24 | #ifndef TVM_RELAY_ATTRS_TRANSFORM_H_ |
25 | #define TVM_RELAY_ATTRS_TRANSFORM_H_ |
26 | |
27 | #include <tvm/ir/attrs.h> |
28 | #include <tvm/relay/base.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/tir/index_map.h> |
31 | |
32 | #include <string> |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
37 | /*! \brief Attributes used for the sliding_window operator */ |
38 | struct SlidingWindowAttrs : public tvm::AttrsNode<SlidingWindowAttrs> { |
39 | int axis; |
40 | Array<Integer> window_shape; |
41 | Array<Integer> strides; |
42 | TVM_DECLARE_ATTRS(SlidingWindowAttrs, "relay.attrs.SlidingWindowAttrs" ) { |
43 | TVM_ATTR_FIELD(axis).describe( |
44 | "What axis the sliding window begin forming over." |
45 | "Window will be slid over this axis and all following axes." |
46 | "The axis value determines the window shape (and thus, the" |
47 | "number of strides):" |
48 | "window shape and strides must both be of length" |
49 | "`data.ndim-axis`." ); |
50 | TVM_ATTR_FIELD(window_shape) |
51 | .describe( |
52 | "The window shape to form over the input." |
53 | "Window shape must be of length `data.ndim-axis`." ); |
54 | TVM_ATTR_FIELD(strides).describe( |
55 | "How to stride the window along each dimension." |
56 | "Strides must be of length `data.ndim-axis`." ); |
57 | } |
58 | }; // struct SlidingWindowAttrs |
59 | |
60 | /*! \brief data type cast */ |
61 | struct CastAttrs : public tvm::AttrsNode<CastAttrs> { |
62 | DataType dtype; |
63 | |
64 | TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs" ) { |
65 | TVM_ATTR_FIELD(dtype).describe("Target data type" ); |
66 | } |
67 | }; // struct CastAttrs. |
68 | |
69 | /*! \brief Attributes used in expand_dims operators */ |
70 | struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> { |
71 | int axis; |
72 | int num_newaxis; |
73 | |
74 | TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs" ) { |
75 | TVM_ATTR_FIELD(axis).describe( |
76 | "The axis at which the input array is expanded." |
77 | "Should lie in range `[-data.ndim - 1, data.ndim]`." |
78 | "If `axis < 0`, it is the first axis inserted;" |
79 | "If `axis >= 0`, it is the last axis inserted in Python's negative indexing." ); |
80 | TVM_ATTR_FIELD(num_newaxis) |
81 | .describe("Number of axes to be inserted. Should be >= 0." ) |
82 | .set_lower_bound(0) |
83 | .set_default(1); |
84 | } |
85 | }; // struct ExpandDimsAttrs |
86 | |
87 | /*! \brief Attributes used in dynamic expand_dims operators */ |
88 | struct DynExpandDimsAttrs : public tvm::AttrsNode<DynExpandDimsAttrs> { |
89 | int num_newaxis; |
90 | |
91 | TVM_DECLARE_ATTRS(DynExpandDimsAttrs, "relay.attrs.DynExpandDimsAttrs" ) { |
92 | TVM_ATTR_FIELD(num_newaxis) |
93 | .describe("Number of axes to be inserted. Should be >= 0." ) |
94 | .set_lower_bound(0) |
95 | .set_default(1); |
96 | } |
97 | }; // struct ExpandDimsAttrs |
98 | |
99 | /*! \brief Attributes used in concatenate operators */ |
100 | struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> { |
101 | int axis; |
102 | TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs" ) { |
103 | TVM_ATTR_FIELD(axis) |
104 | .describe( |
105 | "The axis at which the input arrays are concatenated." |
106 | "Should lie in range `[-ndim, ndim)`." ) |
107 | .set_default(0); |
108 | } |
109 | }; // struct ConcatenateAttrs |
110 | |
111 | /*! \brief Attributes used in transpose operators */ |
112 | struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> { |
113 | Array<Integer> axes; |
114 | TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs" ) { |
115 | TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified." ); |
116 | } |
117 | }; // struct TransposeAttrs |
118 | |
119 | /*! \brief Attributes used in reshape operators */ |
120 | struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { |
121 | Array<Integer> newshape; |
122 | bool allowzero; |
123 | TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs" ) { |
124 | TVM_ATTR_FIELD(newshape).describe( |
125 | "The new shape. Should be compatible with the original shape." ); |
126 | TVM_ATTR_FIELD(allowzero).set_default(0).describe( |
127 | "Whether to honor the value of zero in newshape." ); |
128 | } |
129 | }; // struct ReshapeAttrs |
130 | |
131 | /*! \brief Attributes used in MXNet-style reshape_like operators */ |
132 | struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> { |
133 | int lhs_begin; |
134 | Integer lhs_end; // can be None |
135 | int rhs_begin; |
136 | Integer rhs_end; // can be None |
137 | TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs" ) { |
138 | TVM_ATTR_FIELD(lhs_begin).set_default(0).describe( |
139 | "The axis of the input where reshaping should begin." ); |
140 | TVM_ATTR_FIELD(lhs_end) |
141 | .set_default(NullValue<Integer>()) |
142 | .describe("The axis of the input where reshaping should end, exclusive." ); |
143 | TVM_ATTR_FIELD(rhs_begin).set_default(0).describe( |
144 | "The axis of the shape_like tensor to begin taking dimensions from." ); |
145 | TVM_ATTR_FIELD(rhs_end) |
146 | .set_default(NullValue<Integer>()) |
147 | .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive." ); |
148 | } |
149 | }; // struct ReshapeLikeAttrs |
150 | |
151 | struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> { |
152 | Integer axis; |
153 | |
154 | TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs" ) { |
155 | TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values." ); |
156 | } |
157 | }; |
158 | |
159 | struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> { |
160 | Integer axis; |
161 | |
162 | TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs" ) { |
163 | TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values." ); |
164 | } |
165 | }; |
166 | |
167 | struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> { |
168 | String mode; |
169 | |
170 | TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs" ) { |
171 | TVM_ATTR_FIELD(mode).describe( |
172 | "Accumulation mode of the scatter, either \"update\" or \"add\"." ); |
173 | } |
174 | }; |
175 | |
176 | struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> { |
177 | Integer axis; |
178 | |
179 | TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs" ) { |
180 | TVM_ATTR_FIELD(axis) |
181 | .set_default(NullValue<Integer>()) |
182 | .describe("The axis over which to select values." ); |
183 | } |
184 | }; |
185 | |
186 | struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> { |
187 | Integer batch_dims; |
188 | Optional<Integer> index_rank; |
189 | |
190 | TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs" ) { |
191 | TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions." ); |
192 | TVM_ATTR_FIELD(index_rank) |
193 | .set_default(NullValue<Integer>()) |
194 | .describe( |
195 | "The size of an indexing tuple, which is a fixed value. Only needed when the number of " |
196 | "indexting tuples is dynamic." ); |
197 | } |
198 | }; |
199 | |
200 | struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { |
201 | Integer batch_dims; |
202 | Integer axis; |
203 | tvm::String mode; |
204 | |
205 | TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs" ) { |
206 | TVM_ATTR_FIELD(batch_dims) |
207 | .set_default(0) |
208 | .describe("The batch_dims over which to select values." ); |
209 | TVM_ATTR_FIELD(axis) |
210 | .set_default(NullValue<Integer>()) |
211 | .describe("The axis over which to select values." ); |
212 | TVM_ATTR_FIELD(mode).set_default("clip" ).describe( |
213 | "Specify how out-of-bound indices will behave." |
214 | "clip - clip to the range (default)" |
215 | "wrap - wrap around the indices" |
216 | "fast - no clip or wrap around (user must make sure indices are in-bound)" ); |
217 | } |
218 | }; |
219 | |
220 | /*! \brief Attributes that specify a tensor */ |
221 | struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { |
222 | Optional<Array<Integer>> shape; |
223 | DataType dtype; |
224 | |
225 | TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs" ) { |
226 | TVM_ATTR_FIELD(shape).describe("Target shape." ); |
227 | TVM_ATTR_FIELD(dtype).describe("Target data type." ).set_default(NullValue<DataType>()); |
228 | } |
229 | }; // struct InitOpAttrs |
230 | |
231 | /*! \brief Attributes used in arange operators */ |
232 | struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> { |
233 | Expr start; |
234 | Expr stop; |
235 | Expr step; |
236 | DataType dtype; |
237 | |
238 | TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs" ) { |
239 | TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value." ); |
240 | TVM_ATTR_FIELD(stop).describe("Stop of interval. The interval does not include this value." ); |
241 | TVM_ATTR_FIELD(step).describe("Spacing between values." ); |
242 | TVM_ATTR_FIELD(dtype).describe("Target data type." ); |
243 | } |
244 | }; // struct ArangeAttrs |
245 | |
246 | /*! \brief Attributes used in meshgrid operators */ |
247 | struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> { |
248 | std::string indexing; |
249 | |
250 | TVM_DECLARE_ATTRS(MeshgridAttrs, "relay.attrs.MeshgridAttrs" ) { |
251 | TVM_ATTR_FIELD(indexing) |
252 | .describe( |
253 | "Indexing mode, either \"ij\" for matrix or \"xy\" for cartesian in which first two" |
254 | "dimensions are swapped." ) |
255 | .set_default("ij" ); |
256 | } |
257 | }; // struct MeshgridAttrs |
258 | |
259 | /*! \brief Attributes used in stack operators */ |
260 | struct StackAttrs : public tvm::AttrsNode<StackAttrs> { |
261 | Integer axis; |
262 | TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs" ) { |
263 | TVM_ATTR_FIELD(axis).set_default(0).describe( |
264 | "The axis in the result array along which the input arrays are stacked." ); |
265 | } |
266 | }; // struct StackAttrs |
267 | |
268 | /*! \brief Attributes used in repeat operators */ |
269 | struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> { |
270 | Integer repeats; |
271 | Integer axis; |
272 | TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs" ) { |
273 | TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element." ); |
274 | TVM_ATTR_FIELD(axis) |
275 | .set_default(NullValue<Integer>()) |
276 | .describe(" The axis along which to repeat values." ); |
277 | } |
278 | }; // struct RepeatAttrs |
279 | |
280 | /*! \brief Attributes used in tile operators */ |
281 | struct TileAttrs : public tvm::AttrsNode<TileAttrs> { |
282 | Array<Integer> reps; |
283 | TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs" ) { |
284 | TVM_ATTR_FIELD(reps).describe( |
285 | "The number of times for repeating the tensor a." |
286 | "Each dim sizeof reps must be a positive integer." ); |
287 | } |
288 | }; // struct TileAttrs |
289 | |
290 | /*! \brief Attributes used in reverse operators */ |
291 | struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> { |
292 | Integer axis; |
293 | TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs" ) { |
294 | TVM_ATTR_FIELD(axis) |
295 | .set_default(NullValue<Integer>()) |
296 | .describe("The axis along which to reverse elements." ); |
297 | } |
298 | }; // struct ReverseAttrs |
299 | |
300 | /*! \brief Attributes used in reverse_sequence operators */ |
301 | struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> { |
302 | Integer seq_axis; |
303 | Integer batch_axis; |
304 | |
305 | TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs" ) { |
306 | TVM_ATTR_FIELD(seq_axis).set_default(1).describe( |
307 | "The seq axis along which to reverse elements." ); |
308 | TVM_ATTR_FIELD(batch_axis) |
309 | .set_default(0) |
310 | .describe("The batch axis along which to slice the tensor." ); |
311 | } |
312 | }; // struct ReverseSequenceAttrs |
313 | |
314 | /*! \brief Attributes used in squeeze operators */ |
315 | struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { |
316 | // use axis to make the name numpy compatible. |
317 | Array<Integer> axis; |
318 | |
319 | TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs" ) { |
320 | TVM_ATTR_FIELD(axis) |
321 | .describe( |
322 | "The axis to squeeze in the input tensor." |
323 | "If `axis = None`, all axis of dimension 1 get squeezed;" |
324 | "Else, the dimension in axes get squeezed." |
325 | "It is an error if an axis does not has dimension 1." ) |
326 | .set_default(NullValue<Array<Integer>>()); |
327 | } |
328 | }; // struct SqueezeAttrs |
329 | |
330 | struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> { |
331 | ObjectRef indices_or_sections; |
332 | int axis; |
333 | |
334 | TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs" ) { |
335 | TVM_ATTR_FIELD(indices_or_sections) |
336 | .describe( |
337 | "Indices or sections to split into. Accepts an int or a tuple" |
338 | "If indices_or_sections is an integer, the input will be divided equally" |
339 | "along given axis. If such a split is not possible, an error is raised." |
340 | "If indices_or_sections is a tuple of sorted integers," |
341 | "the entries indicate where along axis the array is split." ); |
342 | TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted." ); |
343 | } |
344 | }; |
345 | |
346 | /*! \brief Attributes for StridedSlice operator */ |
347 | struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { |
348 | Optional<Array<Integer>> begin; |
349 | Optional<Array<Integer>> end; |
350 | Optional<Array<Integer>> strides; |
351 | tvm::String slice_mode; |
352 | Optional<Array<Integer>> axes; |
353 | |
354 | TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs" ) { |
355 | TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive" ); |
356 | TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive" ); |
357 | TVM_ATTR_FIELD(strides).describe( |
358 | "Stride values of the slice, a stride can be negative, which causes a reverse slice." ); |
359 | TVM_ATTR_FIELD(slice_mode) |
360 | .set_default("end" ) |
361 | .describe( |
362 | "The slice mode [end, size]." |
363 | "end - The default slice mode, ending indices for the slice." |
364 | "size - The input strides will be ignored, input end in this mode indicates the size" |
365 | "of a slice starting at the location specified by begin. If end[i] is -1," |
366 | "all remaining elements in that dimension are included in the slice" ); |
367 | TVM_ATTR_FIELD(axes).describe( |
368 | "Axes along which slicing is applied. When it is specified, the length of begin, end, " |
369 | "strides, and axes must be equal." ); |
370 | } |
371 | }; |
372 | |
373 | struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> { |
374 | Array<Integer> axes; |
375 | |
376 | TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs" ) { |
377 | TVM_ATTR_FIELD(axes).describe( |
378 | "List of axes on which input data will be sliced according to the " |
379 | "corresponding size of the second input. By default will slice " |
380 | "on all axes. Negative axes mean counting in reverse." ); |
381 | } |
382 | }; |
383 | |
384 | /*! \brief Attributes for Clip operator */ |
385 | struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { |
386 | double a_min; |
387 | double a_max; |
388 | |
389 | TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs" ) { |
390 | TVM_ATTR_FIELD(a_min).describe("The minimum clip value." ); |
391 | TVM_ATTR_FIELD(a_max).describe("The maximum clip value." ); |
392 | } |
393 | }; |
394 | |
395 | /*! \brief Attributes for FixedPointMultiply operator */ |
396 | struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> { |
397 | int32_t multiplier; |
398 | int32_t shift; |
399 | |
400 | TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs" ) { |
401 | TVM_ATTR_FIELD(multiplier) |
402 | .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)" ); |
403 | TVM_ATTR_FIELD(shift).describe( |
404 | "Shift of a fixed floating point number described as multiplier*2^(shift)" ); |
405 | } |
406 | }; |
407 | |
408 | /*! \brief Attributes for per channel/per axes FixedPointMultiply operator */ |
409 | struct FixedPointMultiplyPerAxisAttrs : public tvm::AttrsNode<FixedPointMultiplyPerAxisAttrs> { |
410 | bool is_lshift_required; |
411 | bool is_rshift_required; |
412 | Array<Integer> axes; |
413 | |
414 | TVM_DECLARE_ATTRS(FixedPointMultiplyPerAxisAttrs, "relay.attrs.FixedPointMultiplyPerAxisAttrs" ) { |
415 | TVM_ATTR_FIELD(is_lshift_required) |
416 | .describe("Whether left shift is required in fixed point multiplication." ) |
417 | .set_default(false); |
418 | TVM_ATTR_FIELD(is_rshift_required) |
419 | .describe("Whether right shift is required in fixed point multiplication." ) |
420 | .set_default(false); |
421 | TVM_ATTR_FIELD(axes).describe("List of axes on which input data was quantized." ); |
422 | } |
423 | }; |
424 | |
425 | /*! \brief Attributes for LayoutTransform operator */ |
426 | struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { |
427 | std::string src_layout; |
428 | std::string dst_layout; |
429 | |
430 | TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs" ) { |
431 | TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)" ); |
432 | TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. NCHW16c)" ); |
433 | } |
434 | }; |
435 | |
436 | /*! \brief Attributes for AutoSchedulerLayoutTransform operator */ |
437 | struct AutoSchedulerLayoutTransformAttrs |
438 | : public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> { |
439 | std::string src_layout; |
440 | std::string dst_layout; |
441 | |
442 | TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs, |
443 | "relay.attrs.AutoSchedulerLayoutTransformAttrs" ) { |
444 | TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)" ); |
445 | TVM_ATTR_FIELD(dst_layout) |
446 | .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)" ); |
447 | } |
448 | }; |
449 | |
450 | /*! \brief Attributes for MetaScheduleLayoutTransform operator */ |
451 | struct MetaScheduleLayoutTransformAttrs : public tvm::AttrsNode<MetaScheduleLayoutTransformAttrs> { |
452 | tir::IndexMap index_map; |
453 | |
454 | TVM_DECLARE_ATTRS(MetaScheduleLayoutTransformAttrs, |
455 | "relay.attrs.MetaScheduleLayoutTransformAttrs" ) { |
456 | TVM_ATTR_FIELD(index_map).describe( |
457 | "The order of the extents, for example, " |
458 | "let extents = [2, 3, 4], reorder = [0, 2, 1], and the shape of buffer A is (4, 6)" |
459 | "then A[i, j] will be first rewritten to " |
460 | "A[(6 * i + j) / 12, (6 * i + j) / 4 % 3 , (6 * i + j) % 4] according to the `extents`," |
461 | "and then reordered to A[(6 * i + j) / 12, (6 * i + j) % 4 , (6 * i + j) / 4 % 3]" |
462 | "according to `reorder`" ); |
463 | } |
464 | }; |
465 | |
466 | /*! \brief Attributes for ShapeOf operator */ |
467 | struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> { |
468 | DataType dtype; |
469 | |
470 | TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs" ) { |
471 | TVM_ATTR_FIELD(dtype).describe("Target data type" ).set_default(NullValue<DataType>()); |
472 | } |
473 | }; |
474 | |
475 | struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> { |
476 | double mask_value; |
477 | int axis; |
478 | |
479 | TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs" ) { |
480 | TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value." ); |
481 | TVM_ATTR_FIELD(axis).set_default(0).describe( |
482 | "The axis of the length dimension. Can only be 0 or 1." ); |
483 | } |
484 | }; // struct SequenceMaskAttrs. |
485 | |
486 | /*! \brief Attributes used in sparse_to_dense operator */ |
487 | struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> { |
488 | Array<Integer> output_shape; |
489 | |
490 | TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs" ) { |
491 | TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor" ); |
492 | } |
493 | }; // struct SparseToDenseAttrs |
494 | |
495 | /*! \brief Attributes for ndarray_size operator */ |
496 | struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> { |
497 | DataType dtype; |
498 | |
499 | TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs" ) { |
500 | TVM_ATTR_FIELD(dtype).describe("Target data type" ).set_default(NullValue<DataType>()); |
501 | } |
502 | }; |
503 | |
504 | /*! \brief Attributes used in one-hot operator */ |
505 | struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> { |
506 | int depth; |
507 | int axis; |
508 | DataType dtype; |
509 | |
510 | TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs" ) { |
511 | TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension." ); |
512 | TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill." ); |
513 | TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>()).describe("Output data type." ); |
514 | } |
515 | }; // struct OneHotAttrs |
516 | |
517 | /*! \brief Attributes used in matrix_set_diag operator */ |
518 | struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> { |
519 | int k1; |
520 | int k2; |
521 | bool super_diag_right_align; |
522 | bool sub_diag_right_align; |
523 | |
524 | TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs" ) { |
525 | TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals." ); |
526 | TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals." ); |
527 | TVM_ATTR_FIELD(super_diag_right_align) |
528 | .set_default(true) |
529 | .describe("Bool, true iff super-diagonal is right aligned (left-padded)." ); |
530 | TVM_ATTR_FIELD(sub_diag_right_align) |
531 | .set_default(false) |
532 | .describe("Bool, true iff sub-diagonal is right aligned (left-padded)." ); |
533 | } |
534 | }; // struct MatrixSetDiagAttrs |
535 | |
536 | /*! \brief Attributes used in cumsum and cumprod operator */ |
537 | struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> { |
538 | Integer axis; |
539 | DataType dtype; |
540 | Bool exclusive = Bool(false); |
541 | TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs" ) { |
542 | TVM_ATTR_FIELD(axis).describe("The axis to operate over" ).set_default(NullValue<Integer>()); |
543 | TVM_ATTR_FIELD(dtype).describe("Output data type" ).set_default(NullValue<DataType>()); |
544 | |
545 | // Default is 0 which is "false" |
546 | TVM_ATTR_FIELD(exclusive) |
547 | .describe("The first element is not included" ) |
548 | .set_default(Bool(false)); |
549 | } |
550 | }; // struct ScanopAttrs |
551 | |
552 | /*! \brief Attributes used in unique operator */ |
553 | struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> { |
554 | bool sorted; |
555 | bool return_counts; |
556 | TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs" ) { |
557 | TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted" ).set_default(true); |
558 | TVM_ATTR_FIELD(return_counts) |
559 | .describe("Whether to return an additional tensor with counts of each unique elements" ) |
560 | .set_default(false); |
561 | } |
562 | }; // struct UniqueAttrs |
563 | |
564 | /*! \brief Attributes used in einsum operator */ |
565 | struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> { |
566 | String equation; |
567 | |
568 | TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs" ) { |
569 | TVM_ATTR_FIELD(equation).describe("The einsum expression string" ); |
570 | } |
571 | }; // struct EinsumAttrs |
572 | |
573 | /*! \brief Attributes used in stft operator */ |
574 | struct StftAttrs : public tvm::AttrsNode<StftAttrs> { |
575 | int n_fft; |
576 | int hop_length; |
577 | int win_length; |
578 | bool normalized; |
579 | bool onesided; |
580 | |
581 | TVM_DECLARE_ATTRS(StftAttrs, "relay.attrs.StftAttrs" ) { |
582 | TVM_ATTR_FIELD(n_fft).set_default(-1).describe("The size of Fourier transform" ); |
583 | TVM_ATTR_FIELD(hop_length) |
584 | .set_default(-1) |
585 | .describe("The distance between neighboring sliding window frames" ); |
586 | TVM_ATTR_FIELD(win_length).set_default(-1).describe("The size of window frame and STFT filter" ); |
587 | TVM_ATTR_FIELD(normalized) |
588 | .set_default(false) |
589 | .describe("Whether to return the normalized STFT results" ); |
590 | TVM_ATTR_FIELD(onesided).set_default(true).describe( |
591 | "Whether to return onesided result or fill with conjugate symmetry" ); |
592 | } |
593 | }; // struct StftAttrs |
594 | |
595 | struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> { |
596 | bool upper; |
597 | |
598 | TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs" ) { |
599 | TVM_ATTR_FIELD(upper).set_default(true).describe( |
600 | "Whether to keep the upper or lower half of the diagonal." ); |
601 | } |
602 | }; // struct TriluAttrs |
603 | |
604 | } // namespace relay |
605 | } // namespace tvm |
606 | #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ |
607 | |