/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_BASE_TYPE_H
#define GLOW_BASE_TYPE_H

#include "DimType.h"

#include "glow/Support/BFloat16.h"
#include "glow/Support/Compiler.h"
#include "glow/Support/Float16.h"
#include "glow/Support/Memory.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <utility>

namespace llvm {
class raw_ostream;
}

namespace glow {

// UINT8_MIN is not defined in the standard headers.
// Define it here so it can be used consistently with the other limit macros.
#define UINT8_MIN 0

struct Type;

using TypeRef = const Type *;

constexpr unsigned max_tensor_dimensions = 6;

/// This type is used to implement the Node and Instruction builder's
/// MemberType::Unsigned and MemberType::VectorUnsigned. Thus it should be used
/// when handling members of these classes, e.g. a convolution Node/Instr's
/// getGroup() (Unsigned) or getKernels() (VectorUnsigned).
using unsigned_t = uint32_t;

using float16_t = float16;
static_assert(sizeof(float16_t) == 2, "Half precision should be 16-bit");

using bfloat16_t = bfloat16;
static_assert(sizeof(bfloat16_t) == 2, "bfloat16 should be 16-bit");

using ShapeVector = llvm::SmallVector<dim_t, max_tensor_dimensions>;

struct ShapeNHWC {

  enum {
    DimN,
    DimH,
    DimW,
    DimC,
  };

  dim_t n; // Number of samples
  dim_t h; // Height
  dim_t w; // Width
  dim_t c; // Number of Channels

  template <typename T> explicit ShapeNHWC(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 4 && "Invalid shape");
    n = shape[DimN];
    h = shape[DimH];
    w = shape[DimW];
    c = shape[DimC];
  }

  ShapeNHWC(dim_t samples, dim_t height, dim_t width, dim_t channels)
      : n(samples), h(height), w(width), c(channels) {}

  bool equals(const ShapeNHWC &other) const {
    return n == other.n && h == other.h && w == other.w && c == other.c;
  }
};
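
// A minimal usage sketch (the dimension values below are hypothetical and
// only illustrate the field order):
//   const dim_t raw[] = {1, 224, 224, 3};
//   ShapeNHWC s{llvm::ArrayRef<dim_t>(raw)};
//   // s.n == 1, s.h == 224, s.w == 224, s.c == 3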

struct ShapeNTHWC {
  dim_t n; // Number of samples
  dim_t t; // Temporal frames
  dim_t h; // Height
  dim_t w; // Width
  dim_t c; // Number of Channels

  template <typename T> explicit ShapeNTHWC(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 5 && "Invalid shape");
    n = shape[0];
    t = shape[1];
    h = shape[2];
    w = shape[3];
    c = shape[4];
  }

  ShapeNTHWC(dim_t samples, dim_t temporal_frames, dim_t height, dim_t width,
             dim_t channels)
      : n(samples), t(temporal_frames), h(height), w(width), c(channels) {}

  bool equals(const ShapeNTHWC &other) const {
    return n == other.n && t == other.t && h == other.h && w == other.w &&
           c == other.c;
  }
};

struct ShapeNHWTC {
  dim_t n; // Number of samples
  dim_t h; // Height
  dim_t w; // Width
  dim_t t; // Temporal frames
  dim_t c; // Number of Channels

  template <typename T> explicit ShapeNHWTC(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 5 && "Invalid shape");
    n = shape[0];
    h = shape[1];
    w = shape[2];
    t = shape[3];
    c = shape[4];
  }

  ShapeNHWTC(size_t samples, size_t height, size_t width,
             size_t temporal_frames, size_t channels)
      : n(samples), h(height), w(width), t(temporal_frames), c(channels) {}

  bool equals(const ShapeNHWTC &other) const {
    return n == other.n && h == other.h && w == other.w && t == other.t &&
           c == other.c;
  }
};

struct ShapeNCHW {

  enum {
    DimN,
    DimC,
    DimH,
    DimW,
  };

  dim_t n; // Number of samples
  dim_t c; // Number of Channels
  dim_t h; // Height
  dim_t w; // Width

  explicit ShapeNCHW(llvm::ArrayRef<dim_t> shape) {
    assert(shape.size() == 4 && "Invalid shape");
    n = shape[DimN];
    c = shape[DimC];
    h = shape[DimH];
    w = shape[DimW];
  }

  ShapeNCHW(dim_t samples, dim_t channels, dim_t height, dim_t width)
      : n(samples), c(channels), h(height), w(width) {}

  bool equals(const ShapeNCHW &other) const {
    return n == other.n && h == other.h && w == other.w && c == other.c;
  }
};

struct ShapeNCTHW {
  dim_t n; // Number of samples
  dim_t c; // Number of Channels
  dim_t t; // Temporal frames
  dim_t h; // Height
  dim_t w; // Width

  explicit ShapeNCTHW(llvm::ArrayRef<dim_t> shape) {
    assert(shape.size() == 5 && "Invalid shape");
    n = shape[0];
    c = shape[1];
    t = shape[2];
    h = shape[3];
    w = shape[4];
  }

  ShapeNCTHW(dim_t samples, dim_t channels, dim_t temporal_frames, dim_t height,
             dim_t width)
      : n(samples), c(channels), t(temporal_frames), h(height), w(width) {}

  bool equals(const ShapeNCTHW &other) const {
    return n == other.n && t == other.t && h == other.h && w == other.w &&
           c == other.c;
  }
};

struct PaddingTLBR {
  dim_t top;
  dim_t left;
  dim_t bottom;
  dim_t right;

  template <typename T> explicit PaddingTLBR(llvm::ArrayRef<T> pads) {
    assert(pads.size() == 4 && "Invalid padding");
    top = pads[0];
    left = pads[1];
    bottom = pads[2];
    right = pads[3];
  }

  bool equalPadding() const {
    return top == left && top == bottom && top == right;
  }
};

struct PaddingTLNBRF {
  dim_t top;
  dim_t left;
  dim_t near;
  dim_t bottom;
  dim_t right;
  dim_t far;

  template <typename T> explicit PaddingTLNBRF(llvm::ArrayRef<T> pads) {
    assert(pads.size() == 6 && "Invalid padding");
    top = pads[0];
    left = pads[1];
    near = pads[2];
    bottom = pads[3];
    right = pads[4];
    far = pads[5];
  }

  bool equalPadding() const {
    return top == left && top == bottom && top == right && top == near &&
           top == far;
  }
};

struct PaddingNFTBLR {
  dim_t near;
  dim_t far;
  dim_t top;
  dim_t bottom;
  dim_t left;
  dim_t right;

  template <typename T> explicit PaddingNFTBLR(llvm::ArrayRef<T> pads) {
    assert(pads.size() == 6 && "Invalid padding");
    near = pads[0];
    far = pads[1];
    top = pads[2];
    bottom = pads[3];
    left = pads[4];
    right = pads[5];
  }

  bool equalPadding() const {
    return top == left && top == bottom && top == right && top == near &&
           top == far;
  }
};

struct ShapeHW {

  enum {
    DimH,
    DimW,
  };

  dim_t height;
  dim_t width;

  template <typename T> explicit ShapeHW(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 2 && "Invalid shape");
    height = shape[DimH];
    width = shape[DimW];
  }

  bool isSquare() const { return height == width; }
};

struct ShapeNHW {

  enum {
    DimN,
    DimH,
    DimW,
  };

  dim_t n; // Number of samples
  dim_t h; // Height
  dim_t w; // Width

  template <typename T> explicit ShapeNHW(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 3 && "Invalid shape");
    n = shape[DimN];
    h = shape[DimH];
    w = shape[DimW];
  }

  bool isSquare() const { return h == w; }
};

struct ShapeHWT {
  dim_t height;
  dim_t width;
  dim_t temporal_frames;

  template <typename T> explicit ShapeHWT(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 3 && "Invalid shape");
    height = shape[0];
    width = shape[1];
    temporal_frames = shape[2];
  }

  bool isCube() const { return height == width && height == temporal_frames; }
};

struct ShapeTHW {
  dim_t temporal_frames;
  dim_t height;
  dim_t width;

  template <typename T> explicit ShapeTHW(llvm::ArrayRef<T> shape) {
    assert(shape.size() == 3 && "Invalid shape");
    temporal_frames = shape[0];
    height = shape[1];
    width = shape[2];
  }

  bool isCube() const { return height == width && height == temporal_frames; }
};

/// Collapse a tensor shape into two sizes: the product of the first \p n
/// dimensions and the product of the remaining dimensions.
/// For example, ([7, 3, 4, 2], 1) -> [7, 24].
inline std::pair<dim_t, dim_t> flattenCdr(llvm::ArrayRef<dim_t> dims,
                                          unsigned_t n = 1) {
  assert(1 <= n && n <= dims.size());
  dim_t first = dims[0];
  for (unsigned_t i = 1; i < n; i++) {
    first *= dims[i];
  }
  dim_t rest = 1;
  for (unsigned_t i = n; i < dims.size(); i++) {
    rest *= dims[i];
  }

  return {first, rest};
}
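
// Worked example (illustrative values): for dims = {7, 3, 4, 2},
//   flattenCdr(dims, 1) == {7, 24}   (7 and 3 * 4 * 2)
//   flattenCdr(dims, 2) == {21, 8}   (7 * 3 and 4 * 2)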

/// Collapse a tensor shape into two sizes: the first is the product of all
/// dimensions except the axis dimension \p n, and the second is the size of
/// the axis dimension. For example, ([7, 3, 4, 2], 1) -> [56, 3].
inline std::pair<dim_t, dim_t> collapseShape(llvm::ArrayRef<dim_t> dims,
                                             unsigned_t n = 1) {
  assert(1 <= n && n <= dims.size());
  dim_t first = 1;
  dim_t second = 1;
  for (unsigned_t i = 0; i < dims.size(); i++) {
    if (i == n) {
      second = dims[i];
    } else {
      first *= dims[i];
    }
  }
  return {first, second};
}
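
// Worked example (illustrative values): for dims = {7, 3, 4, 2},
//   collapseShape(dims, 1) == {56, 3}   (7 * 4 * 2 and dims[1])
//   collapseShape(dims, 2) == {42, 4}   (7 * 3 * 2 and dims[2])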

inline bool operator==(const ShapeNHWC &LHS, const ShapeNHWC &RHS) {
  return LHS.equals(RHS);
}

inline bool operator==(const ShapeNCHW &LHS, const ShapeNCHW &RHS) {
  return LHS.equals(RHS);
}

inline bool operator==(const ShapeNHWTC &LHS, const ShapeNHWTC &RHS) {
  return LHS.equals(RHS);
}

inline bool operator==(const ShapeNTHWC &LHS, const ShapeNTHWC &RHS) {
  return LHS.equals(RHS);
}

inline bool operator==(const ShapeNCTHW &LHS, const ShapeNCTHW &RHS) {
  return LHS.equals(RHS);
}

/// An enum representing the type used by the elements of a tensor. The types
/// of Handles for these tensors should match the element kind. When adding a
/// new type, note that this enum definition must match the ElemKind definition
/// in Glow/lib/Backends/CPU/libjit/libjit.cpp.
enum class ElemKind : unsigned char {
  // 32-bit float type (float)
  FloatTy,
  // 16-bit float type (half, fp16)
  Float16Ty,
  // 16-bit float type (bfloat16)
  BFloat16Ty,
  // 64-bit float type (double)
  Float64Ty,
  // 8-bit quantized type (int8_t)
  Int8QTy,
  // unsigned 8-bit quantized type (uint8_t)
  UInt8QTy,
  // 16-bit quantized type (int16_t)
  Int16QTy,
  // 32-bit quantized type (int32_t)
  Int32QTy,
  // 8-bit index type (uint8_t)
  UInt8ITy,
  // 32-bit index type (int32_t)
  Int32ITy,
  // 64-bit index type (int64_t)
  Int64ITy,
  // 8-bit quantized type with fused scale/offset (uint8_t)
  UInt8FusedQTy,
  // 8-bit quantized type with fused FP16 scale/offset (uint8_t)
  UInt8FusedFP16QTy,
  // 4-bit quantized type with fused FP16 scale/offset (uint8_t; each byte
  // represents two 4-bit quantized values)
  UInt4FusedFP16QTy,
  // 4-bit quantized type with fused FP32 scale/offset (uint8_t; each byte
  // represents two 4-bit quantized values)
  UInt4FusedQTy,
  // Bool type (bool)
  BoolTy,
  // 64-bit quantized type (int64_t, partial support)
  Int64QTy,
};

/// \returns whether \p e is a quantized ElemKind.
inline bool isQuantizedElemKind(ElemKind e) {
  return e == ElemKind::Int8QTy || e == ElemKind::UInt8QTy ||
         e == ElemKind::Int16QTy || e == ElemKind::Int32QTy ||
         e == ElemKind::Int64QTy || e == ElemKind::UInt8FusedQTy ||
         e == ElemKind::UInt8FusedFP16QTy || e == ElemKind::UInt4FusedFP16QTy ||
         e == ElemKind::UInt4FusedQTy;
}

/// \returns whether \p e is a float ElemKind.
inline bool isFloatElemKind(ElemKind e) {
  return e == ElemKind::FloatTy || e == ElemKind::Float16Ty ||
         e == ElemKind::BFloat16Ty || e == ElemKind::Float64Ty;
}

/// \returns whether \p e is a non-quantized integer ElemKind.
inline bool isNonQuantizedIntElemKind(ElemKind e) {
  return e == ElemKind::Int32ITy || e == ElemKind::Int64ITy;
}

/// \returns whether \p e is a fused quantized ElemKind.
inline bool isFusedQuantizedElemKind(ElemKind e) {
  return e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt8FusedFP16QTy ||
         e == ElemKind::UInt4FusedFP16QTy || e == ElemKind::UInt4FusedQTy;
}

/// \returns the scale and offset ElemKind used by the fused ElemKind \p e.
inline ElemKind getScaleOffsetElemKindFromFused(ElemKind e) {
  assert(isFusedQuantizedElemKind(e) && "Must pass Fused ElemKind.");
  if (e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt4FusedQTy) {
    return ElemKind::FloatTy;
  }
  return ElemKind::Float16Ty;
}

/// \returns the floating point value range that covers a quantized type (min
/// first, max second) given \p scale, \p offset, and \p elementType.
std::pair<float, float> getQuantizedValueRange(float scale, int32_t offset,
                                               ElemKind elementType);
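
// Illustrative example, assuming the usual affine mapping
// "float value = scale * (quantized - offset)": for Int8QTy with scale = 0.5
// and offset = 10, the covered range is approximately
// [(-128 - 10) * 0.5, (127 - 10) * 0.5] = [-69.0, 58.5]. See the
// implementation for the exact boundary handling per element type.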

/// A class that represents a type of a tensor.
struct Type final {
  /// Contains the dimensions (sizes) of the tensor. Ex: [sx, sy, sz, ...].
  dim_t sizes_[max_tensor_dimensions] = {
      0,
  };
  /// Contains the strides for each dimension (in elements). The order should
  /// be the same as in sizes_. In more detail, suppose that the tensor is laid
  /// out flat in memory, and some dimensions are aligned. strides_[i] is the
  /// number of elements that need to be skipped in order to reach the next
  /// plane in the i-th dimension. For example, if the tensor has dimensions
  /// [3, 5, 10] and alignments [3, 32, 1], the strides will be [162, 32, 1].
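  /// (Worked through under the assumption of 1-byte elements such as Int8QTy:
  /// the innermost stride is 1, the middle stride is 10 * 1 rounded up to a
  /// multiple of 32, i.e. 32, and the outermost stride is 5 * 32 = 160 rounded
  /// up to a multiple of 3, i.e. 162.)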
  dim_t strides_[max_tensor_dimensions] = {
      0,
  };

  /// Contains the number of dimensions used by the tensor.
  unsigned char numSizes_{0};

  /// On quantized tensors, this represents the scale of the values.
  float scale_{0};
  /// On quantized tensors, this represents the offset of the values.
  int32_t offset_{0};

  /// Specifies the element type of the tensor.
  ElemKind elementType_{ElemKind::Int64ITy};

  /// Initialize a new quantized type with \p scale and \p offset.
  Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale, int32_t offset)
      : scale_(scale), offset_(offset), elementType_(elemTy) {
    assert(isQuantizedType() && "Only quantized types have a scale and offset");
    ShapeVector alignments(dims.size(), 1);
    initDims(dims, llvm::makeArrayRef(alignments));
  }

  /// Initialize a new non-quantized type.
  Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) : elementType_(elemTy) {
    assert(!isQuantizedType() &&
           "Can't initialize quantized types without scale and offset");
    ShapeVector alignments(dims.size(), 1);
    initDims(dims, llvm::makeArrayRef(alignments));
  }

  /// Initialize a new quantized type with \p scale and \p offset.
  Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
       llvm::ArrayRef<dim_t> alignments, float scale, int32_t offset)
      : scale_(scale), offset_(offset), elementType_(elemTy) {
    assert(isQuantizedType() && "Only quantized types have a scale and offset");
    initDims(dims, alignments);
  }

  /// Initialize a new non-quantized type.
  Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
       llvm::ArrayRef<dim_t> alignments)
      : elementType_(elemTy) {
    assert(!isQuantizedType() &&
           "Can't initialize quantized types without scale and offset");
    initDims(dims, alignments);
  }

  /// Reshape existing type. This method takes care of quantized types.
  static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims) {
    if (T.isQuantizedType()) {
      return Type(T.getElementType(), dims, T.getScale(), T.getOffset());
    } else {
      return Type(T.getElementType(), dims);
    }
  }

  /// Reshape existing type and change alignments.
  static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims,
                       llvm::ArrayRef<dim_t> alignments) {
    if (T.isQuantizedType()) {
      return Type(T.getElementType(), dims, alignments, T.getScale(),
                  T.getOffset());
    } else {
      return Type(T.getElementType(), dims, alignments);
    }
  }

  /// Reshape existing type by taking shapes and strides of \p shapeType.
  static Type newShape(const Type &T, TypeRef shapeType) {
    Type ty;
    if (T.isQuantizedType()) {
      ty = Type(T.getElementType(), shapeType->dims(), T.getScale(),
                T.getOffset());
    } else {
      ty = Type(T.getElementType(), shapeType->dims());
    }
    // Copy the stride information.
    std::copy(&shapeType->strides_[0], &shapeType->strides_[ty.numSizes_],
              ty.strides_);
    return ty;
  }

  /// Reshape existing type \p T by taking its shape and using the provided \p
  /// strides.
  static Type newStrides(const Type &T, llvm::ArrayRef<dim_t> strides) {
    assert(strides.size() == T.strides().size());
    Type ty = T;
    // Copy the stride information.
    std::copy(&strides[0], &strides[0] + ty.numSizes_, ty.strides_);
    return ty;
  }

  /// Create a copy of \p T with the same shape but with new quantization
  /// parameters \p scale and \p offset.
  static Type newQuantparams(const Type &T, float scale, int32_t offset) {
    assert(T.isQuantizedType() &&
           "Can't modify scale and offset of non quantized types");
    return Type(T.getElementType(), T.dims(), scale, offset);
  }
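
  // Illustrative usage sketch (hypothetical shapes and quantization
  // parameters, not taken from any real model):
  //   Type fTy(ElemKind::FloatTy, {4, 8});
  //   Type qTy(ElemKind::Int8QTy, {4, 8}, /* scale */ 0.1f, /* offset */ 0);
  //   Type reshaped = Type::newShape(qTy, {8, 4}); // keeps scale/offset
  //   Type requant = Type::newQuantparams(qTy, /* scale */ 0.2f, -5);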

  /// \returns true if a type has standard strides and no special alignment
  /// requirements.
  bool hasStandardStrides() const {
    if (numSizes_ > 0) {
      // Stride of the last dimension is always 1.
      assert(strides_[numSizes_ - 1] == 1 &&
             "Last dimension must always be aligned.");
    }
    for (int i = numSizes_ - 2; i >= 0; i--) {
      // Each stride (except the last one) is derived from the next inner
      // dimension. For standard strides the following should hold:
      //   strides_[i] == sizes_[i + 1] * strides_[i + 1]
      if (strides_[i] != sizes_[i + 1] * strides_[i + 1]) {
        return false;
      }
    }
    return true;
  }
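
  // For example, a type with dims {3, 5, 10} and no special alignment has the
  // standard strides {50, 10, 1}, so hasStandardStrides() returns true.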

  /// An empty type.
  Type() = default;

  /// \returns true if \p other is the same type.
  bool isEqual(TypeRef other) const { return isEqual(*other); }

  /// \returns the scale of a quantized type.
  float getScale() const {
    assert(isQuantizedType() && "Can't get the scale of a non-quantized type");
    return scale_;
  }

  /// \returns the offset of a quantized type.
  int32_t getOffset() const {
    assert(isQuantizedType() && "Can't get the offset of a non-quantized type");
    return offset_;
  }

  /// \returns the floating point value range that covers a quantized type (min
  /// first, max second).
  std::pair<float, float> getQuantizedValueRange() const {
    return ::glow::getQuantizedValueRange(scale_, offset_, elementType_);
  }

  /// \returns the number of values associated to the quantized type (e.g. 256
  /// for Int8QTy).
  size_t getQuantizedValueCount() const {
    assert(getElementSize() < sizeof(size_t) &&
           "Cannot retrieve quantized value count with size_t!");
    double numBits = getElementSize() * 8;
    return static_cast<size_t>(std::exp2(numBits));
  }

  /// \returns the floating point value step associated to the quantized type.
  /// The quantization step is actually equal to the quantization scale.
  float getQuantizedValueStep() const {
    assert(isQuantizedType() &&
           "Can't get the quantized value step of a non-quantized type");
    return scale_;
  }

  /// \returns true if \p other is the same type. If \p allowDifferentShape is
  /// true, shapes (and therefore strides) are not considered as part of the
  /// comparison. If \p allowDifferentStrides is true, only strides are
  /// ignored. If \p allowDifferentScaleOffset is true, scale and offset are
  /// not considered as part of the comparison.
  bool isEqual(const Type &other, bool allowDifferentShape = false,
               bool allowDifferentStrides = false,
               bool allowDifferentScaleOffset = false) const {
    // Element type must be the same.
    if (elementType_ != other.elementType_) {
      return false;
    }
    // Must have the same number of sizes.
    if (numSizes_ != other.numSizes_) {
      return false;
    }
    // Sizes must be the same.
    if (!allowDifferentShape) {
      for (size_t i = 0; i < numSizes_; i++) {
        if (sizes_[i] != other.sizes_[i]) {
          return false;
        }
      }
      if (!allowDifferentStrides) {
        // Strides must be the same.
        for (size_t i = 0; i < numSizes_; i++) {
          if (strides_[i] != other.strides_[i]) {
            return false;
          }
        }
      }
    }

    // Compare the scale and offset of integers. Fused types use dummy
    // scale/offset, so can ignore them.
    if (isQuantizedType() && !isFusedQuantizedType() &&
        !allowDifferentScaleOffset) {
      if (scale_ != other.scale_ || offset_ != other.offset_) {
        return false;
      }
    }

    return true;
  }
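
  // Illustrative comparison sketch (hypothetical types):
  //   Type a(ElemKind::FloatTy, {2, 3});
  //   Type b(ElemKind::FloatTy, {3, 2});
  //   a.isEqual(b);                                  // false: shapes differ
  //   a.isEqual(b, /* allowDifferentShape */ true);  // true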

  /// \returns a hash value for this Type. Hashes for Ty1 and Ty2 are equal if
  /// Ty1.isEqual(Ty2).
  llvm::hash_code equals_hash() const {
    return llvm::hash_combine(
        elementType_, dims(),
        // Hashing floats is tricky; fall back to std::hash.
        std::hash<float>{}(scale_), offset_);
  }

  ElemKind getElementType() const { return elementType_; }

  /// \returns the shape of the tensor.
  llvm::ArrayRef<dim_t> dims() const { return {sizes_, numSizes_}; }

  /// \returns the strides of the tensor.
  llvm::ArrayRef<dim_t> strides() const { return {strides_, numSizes_}; }

  /// \returns the number of elements in the tensor.
  dim_t size() const {
    dim_t s = 1;
    for (unsigned char i = 0; i < numSizes_; i++) {
      s *= dim_t(sizes_[i]);
    }

    return s;
  }

  /// \returns the number of elements in a slice of the tensor, i.e. the size
  /// of the slice starting at \p startDim. For example, a tensor with the
  /// shape [10, 10, 3] and startDim 1 would have the slice size 30, because
  /// this is the size of the slice [10, 3] that starts at index 1.
  dim_t getSliceSize(unsigned char startDim) const {
    assert(startDim <= numSizes_ && "Invalid start dim");
    dim_t s = 1;
    for (unsigned char i = startDim; i < numSizes_; i++) {
      s *= dim_t(sizes_[i]);
    }
    return s;
  }

  /// \returns true if the templated parameter \p ElemTy matches this type.
  template <class ElemTy> bool isType() const {
    return isType<ElemTy>(elementType_);
  }

  /// \returns true if the templated parameter \p ElemTy matches the type
  /// that's specified by the parameter \p Ty.
  template <class ElemTy> static bool isType(ElemKind Ty) {
    switch (Ty) {
    case ElemKind::FloatTy:
      return std::is_same<ElemTy, float>::value;
    case ElemKind::Float16Ty:
      return std::is_same<ElemTy, float16_t>::value;
    case ElemKind::BFloat16Ty:
      return std::is_same<ElemTy, bfloat16_t>::value;
    case ElemKind::Float64Ty:
      return std::is_same<ElemTy, double>::value;
    case ElemKind::Int8QTy:
      return std::is_same<ElemTy, int8_t>::value;
    case ElemKind::UInt8QTy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::Int16QTy:
      return std::is_same<ElemTy, int16_t>::value;
    case ElemKind::Int32QTy:
      return std::is_same<ElemTy, int32_t>::value;
    case ElemKind::Int64QTy:
      return std::is_same<ElemTy, int64_t>::value;
    case ElemKind::UInt8ITy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::Int32ITy:
      return std::is_same<ElemTy, int32_t>::value;
    case ElemKind::Int64ITy:
      return std::is_same<ElemTy, int64_t>::value;
    case ElemKind::UInt8FusedQTy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::UInt8FusedFP16QTy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::UInt4FusedFP16QTy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::UInt4FusedQTy:
      return std::is_same<ElemTy, uint8_t>::value;
    case ElemKind::BoolTy:
      return std::is_same<ElemTy, bool>::value;
    }
    LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
    return false; // Get rid of compilation warnings.
  }
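
  // E.g. Type::isType<float>(ElemKind::FloatTy) is true, while
  // Type::isType<int32_t>(ElemKind::FloatTy) is false.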

  /// \returns true if the type of this Tensor is one of the quantized types.
  bool isQuantizedType() const { return isQuantizedElemKind(elementType_); }

  /// \returns true if the type of this Tensor is one of the fused quantized
  /// types.
  bool isFusedQuantizedType() const {
    return isFusedQuantizedElemKind(elementType_);
  }

  /// \returns true if the type of this Tensor is one of the floating point
  /// types.
  bool isFPType() const { return isFloatElemKind(getElementType()); }

  /// \returns the size of the type element.
  unsigned getElementSize() const { return getElementSize(elementType_); }

  /// \returns the size in bytes for this Tensor.
  size_t getSizeInBytes() const {
    size_t s = getElementSize();
    for (unsigned char i = 0; i < numSizes_; i++) {
      // If any dimension is 0 then the entire size is 0, so return early.
      if (sizes_[i] == 0) {
        return 0;
      }
      s = std::max<dim_t>(s,
                          size_t(sizes_[i]) * getElementSize() * strides_[i]);
    }
    return s;
  }
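
  // For example (standard strides assumed): a FloatTy type with dims {2, 3}
  // occupies 2 * 3 * sizeof(float) = 24 bytes.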

  /// \returns the actual number of elements in the tensor taking striding into
  /// account. Since size() does not take striding into account, size() is
  /// always <= actualSize().
  size_t actualSize() const { return getSizeInBytes() / getElementSize(); }

  /// \returns the size of the element \p Ty.
  static unsigned getElementSize(ElemKind Ty) {
    switch (Ty) {
    case ElemKind::FloatTy:
      return sizeof(float);
    case ElemKind::Float16Ty:
      return sizeof(float16_t);
    case ElemKind::BFloat16Ty:
      return sizeof(bfloat16_t);
    case ElemKind::Float64Ty:
      return sizeof(double);
    case ElemKind::Int8QTy:
      return sizeof(int8_t);
    case ElemKind::UInt8QTy:
      return sizeof(uint8_t);
    case ElemKind::Int16QTy:
      return sizeof(int16_t);
    case ElemKind::Int32QTy:
      return sizeof(int32_t);
    case ElemKind::Int64QTy:
      return sizeof(int64_t);
    case ElemKind::UInt8ITy:
      return sizeof(uint8_t);
    case ElemKind::Int32ITy:
      return sizeof(int32_t);
    case ElemKind::Int64ITy:
      return sizeof(int64_t);
    case ElemKind::UInt8FusedQTy:
      return sizeof(uint8_t);
    case ElemKind::UInt8FusedFP16QTy:
      return sizeof(uint8_t);
    case ElemKind::UInt4FusedFP16QTy:
      return sizeof(uint8_t);
    case ElemKind::UInt4FusedQTy:
      return sizeof(uint8_t);
    case ElemKind::BoolTy:
      return sizeof(bool);
    }
    LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
  }

  /// \returns the textual name of the element.
  llvm::StringRef getElementName() const {
    return getElementName(elementType_);
  }

  /// \returns the textual name of the element \p Ty.
  static llvm::StringRef getElementName(ElemKind Ty) {
    static const char *names[] = {
        "float",   "float16",  "bfloat16",     "float64",      "i8",
        "ui8",     "i16",      "i32",          "uindex8",      "index32",
        "index64", "ui8fused", "ui8fusedfp16", "ui4fusedfp16", "ui4fused",
        "bool",    "i64",
    };
    return names[(int)Ty];
  }

  /// Given a string \p str containing the name of an ElemKind from
  /// Type::getElementName, \returns the corresponding ElemKind. If no mapping
  /// is found, logs a (debug-)fatal error and returns FloatTy.
  static ElemKind getElementKindFromName(llvm::StringRef str) {
    if (str == Type::getElementName(ElemKind::FloatTy)) {
      return ElemKind::FloatTy;
    } else if (str == Type::getElementName(ElemKind::Float16Ty)) {
      return ElemKind::Float16Ty;
    } else if (str == Type::getElementName(ElemKind::Float64Ty)) {
      return ElemKind::Float64Ty;
    } else if (str == Type::getElementName(ElemKind::BFloat16Ty)) {
      return ElemKind::BFloat16Ty;
    } else if (str == Type::getElementName(ElemKind::Int8QTy)) {
      return ElemKind::Int8QTy;
    } else if (str == Type::getElementName(ElemKind::UInt8QTy)) {
      return ElemKind::UInt8QTy;
    } else if (str == Type::getElementName(ElemKind::Int16QTy)) {
      return ElemKind::Int16QTy;
    } else if (str == Type::getElementName(ElemKind::Int32QTy)) {
      return ElemKind::Int32QTy;
    } else if (str == Type::getElementName(ElemKind::UInt8ITy)) {
      return ElemKind::UInt8ITy;
    } else if (str == Type::getElementName(ElemKind::Int32ITy)) {
      return ElemKind::Int32ITy;
    } else if (str == Type::getElementName(ElemKind::Int64ITy)) {
      return ElemKind::Int64ITy;
    } else if (str == Type::getElementName(ElemKind::UInt8FusedQTy)) {
      return ElemKind::UInt8FusedQTy;
    } else if (str == Type::getElementName(ElemKind::UInt8FusedFP16QTy)) {
      return ElemKind::UInt8FusedFP16QTy;
    } else if (str == Type::getElementName(ElemKind::UInt4FusedFP16QTy)) {
      return ElemKind::UInt4FusedFP16QTy;
    } else if (str == Type::getElementName(ElemKind::UInt4FusedQTy)) {
      return ElemKind::UInt4FusedQTy;
    } else if (str == Type::getElementName(ElemKind::BoolTy)) {
      return ElemKind::BoolTy;
    } else if (str == Type::getElementName(ElemKind::Int64QTy)) {
      return ElemKind::Int64QTy;
    } else {
      LOG(DFATAL) << "Invalid ElemKind string: " << str.str();
      return ElemKind::FloatTy;
    }
  }
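
  // The names round-trip with getElementName, e.g. (illustrative):
  //   Type::getElementKindFromName("i8") == ElemKind::Int8QTy
  //   Type::getElementKindFromName(Type::getElementName(ElemKind::BoolTy)) ==
  //       ElemKind::BoolTy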

  /// Dump a textual representation of the Type into the provided output
  /// stream.
  void dump(llvm::raw_ostream &out) const;

  /// Dump a textual representation of the Type into the default output stream.
  void dump() const;

  /// Dump a textual representation of the Type to std::string.
  std::string toString() const;

  /// Load a Type object from a textual representation \p str. This method is
  /// paired with \ref toString and should be used together with it.
  static Type fromString(llvm::StringRef str);

private:
  /// Set up the internals of the type that store the dimensions. This method
  /// is used by the constructors.
  /// \param dims of the tensor (in elements).
  /// \param alignments of the tensor (in bytes).
  void initDims(llvm::ArrayRef<dim_t> dims, llvm::ArrayRef<dim_t> alignments) {
    assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
    assert(dims.size() == alignments.size() &&
           "The number of dimensions and alignments should be the same");
    // Update the tensor strides and sizes based on the given dims and
    // alignments. Sizes are simply copied from dims, and strides are computed
    // as partial products of dims, making sure that each dimension is aligned
    // as required.
    numSizes_ = dims.size();
    if (numSizes_ > 0) {
      // Stride of the last dimension is always 1.
      assert(alignments[numSizes_ - 1] == 1 &&
             "Last dimension must always be aligned.");
      strides_[numSizes_ - 1] = 1;
      sizes_[numSizes_ - 1] = dims[numSizes_ - 1];
    }
    for (int i = numSizes_ - 2; i >= 0; i--) {
      dim_t alignment = alignments[i];
      if (alignment != 1) {
        assert(alignment % getElementSize() == 0 &&
               "Alignment should be a multiple of element size");
        alignment /= getElementSize();
      }
      // Each stride (except the last one) is derived from the stride and size
      // of the next inner dimension.
      strides_[i] = alignedSize(dims[i + 1] * strides_[i + 1], alignment);
      sizes_[i] = dims[i];
    }
  }

  void initDims(llvm::ArrayRef<dim_t> dims) {
    assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
    // Update the tensor sizes.
    for (size_t i = 0, e = dims.size(); i < e; i++) {
      sizes_[i] = dims[i];
    }
    numSizes_ = dims.size();
  }
};

inline bool operator==(const Type &LHS, const Type &RHS) {
  return LHS.isEqual(RHS);
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Type &type);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TypeRef &type);

} // namespace glow

#endif // GLOW_BASE_TYPE_H