1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_BASE_TYPE_H |
17 | #define GLOW_BASE_TYPE_H |
18 | |
19 | #include "DimType.h" |
20 | |
21 | #include "glow/Support/BFloat16.h" |
22 | #include "glow/Support/Compiler.h" |
23 | #include "glow/Support/Float16.h" |
24 | #include "glow/Support/Memory.h" |
25 | |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | #include "llvm/ADT/StringRef.h" |
28 | |
29 | #include <glog/logging.h> |
30 | |
31 | #include <cstddef> |
32 | #include <cstdint> |
33 | #include <type_traits> |
34 | #include <utility> |
35 | |
36 | namespace llvm { |
37 | class raw_ostream; |
38 | } |
39 | |
40 | namespace glow { |
41 | |
// UINT8_MIN is not defined in standard headers.
// Define it here for using these definitions consistently. The definition is
// guarded so that this header does not redefine the macro when another header
// has already provided it.
#ifndef UINT8_MIN
#define UINT8_MIN 0
#endif
45 | |
/// Forward declaration; the full definition appears later in this header.
struct Type;

/// Pointer to an immutable Type.
using TypeRef = const Type *;

/// Maximum number of dimensions a tensor type can describe. It sizes the
/// fixed arrays inside Type and the inline capacity of ShapeVector.
constexpr unsigned max_tensor_dimensions = 6;

/// This type is used to implement the Node and Instruction builder's
/// MemberType::Unsigned and MemberType::VectorUnsigned. Thus it should be used
/// when handling members of these classes, e.g. a convolution Node/Instr's
/// getGroup() (Unsigned), or getKernels() (UnsignedVector).
using unsigned_t = uint32_t;

/// Half precision (16-bit) floating point element type.
using float16_t = float16;
static_assert(sizeof(float16_t) == 2, "Half precision should be 16-bit");

/// bfloat16 (16-bit) floating point element type.
using bfloat16_t = bfloat16;
static_assert(sizeof(bfloat16_t) == 2, "bfloat16 should be 16-bit");

/// Vector of dimensions with inline storage for max_tensor_dimensions entries.
using ShapeVector = llvm::SmallVector<dim_t, max_tensor_dimensions>;
65 | |
66 | struct ShapeNHWC { |
67 | |
68 | enum { |
69 | DimN, |
70 | DimH, |
71 | DimW, |
72 | DimC, |
73 | }; |
74 | |
75 | dim_t n; // Number of samples |
76 | dim_t h; // Height |
77 | dim_t w; // Width |
78 | dim_t c; // Number of Channels |
79 | |
80 | template <typename T> explicit ShapeNHWC(llvm::ArrayRef<T> shape) { |
81 | assert(shape.size() == 4 && "Invalid shape" ); |
82 | n = shape[DimN]; |
83 | h = shape[DimH]; |
84 | w = shape[DimW]; |
85 | c = shape[DimC]; |
86 | } |
87 | |
88 | ShapeNHWC(dim_t samples, dim_t height, dim_t width, dim_t channels) |
89 | : n(samples), h(height), w(width), c(channels) {} |
90 | |
91 | bool equals(const ShapeNHWC &other) const { |
92 | return n == other.n && h == other.h && w == other.w && c == other.c; |
93 | } |
94 | }; |
95 | |
96 | struct ShapeNTHWC { |
97 | dim_t n; // Number of samples |
98 | dim_t t; // Temporal frames |
99 | dim_t h; // Height |
100 | dim_t w; // Width |
101 | dim_t c; // Number of Channels |
102 | |
103 | template <typename T> explicit ShapeNTHWC(llvm::ArrayRef<T> shape) { |
104 | assert(shape.size() == 5 && "Invalid shape" ); |
105 | n = shape[0]; |
106 | t = shape[1]; |
107 | h = shape[2]; |
108 | w = shape[3]; |
109 | c = shape[4]; |
110 | } |
111 | |
112 | ShapeNTHWC(dim_t samples, dim_t temporal_frames, dim_t height, dim_t width, |
113 | dim_t channels) |
114 | : n(samples), t(temporal_frames), h(height), w(width), c(channels) {} |
115 | |
116 | bool equals(const ShapeNTHWC &other) const { |
117 | return n == other.n && t == other.t && h == other.h && w == other.w && |
118 | c == other.c; |
119 | } |
120 | }; |
121 | |
122 | struct ShapeNHWTC { |
123 | dim_t n; // Number of samples |
124 | dim_t h; // Height |
125 | dim_t w; // Width |
126 | dim_t t; // Temporal_frames |
127 | dim_t c; // Number of Channels |
128 | |
129 | template <typename T> explicit ShapeNHWTC(llvm::ArrayRef<T> shape) { |
130 | assert(shape.size() == 5 && "Invalid shape" ); |
131 | n = shape[0]; |
132 | h = shape[1]; |
133 | w = shape[2]; |
134 | t = shape[3]; |
135 | c = shape[4]; |
136 | } |
137 | |
138 | ShapeNHWTC(size_t samples, size_t height, size_t width, |
139 | size_t temporal_frames, size_t channels) |
140 | : n(samples), h(height), w(width), t(temporal_frames), c(channels) {} |
141 | |
142 | bool equals(const ShapeNHWTC &other) const { |
143 | return n == other.n && h == other.h && w == other.w && t == other.t && |
144 | c == other.c; |
145 | } |
146 | }; |
147 | |
148 | struct ShapeNCHW { |
149 | |
150 | enum { |
151 | DimN, |
152 | DimC, |
153 | DimH, |
154 | DimW, |
155 | }; |
156 | |
157 | dim_t n; // Number of samples |
158 | dim_t c; // Number of Channels |
159 | dim_t h; // Height |
160 | dim_t w; // Width |
161 | |
162 | explicit ShapeNCHW(llvm::ArrayRef<dim_t> shape) { |
163 | assert(shape.size() == 4 && "Invalid shape" ); |
164 | n = shape[DimN]; |
165 | c = shape[DimC]; |
166 | h = shape[DimH]; |
167 | w = shape[DimW]; |
168 | } |
169 | |
170 | ShapeNCHW(dim_t samples, dim_t channels, dim_t height, dim_t width) |
171 | : n(samples), c(channels), h(height), w(width) {} |
172 | |
173 | bool equals(const ShapeNCHW &other) const { |
174 | return n == other.n && h == other.h && w == other.w && c == other.c; |
175 | } |
176 | }; |
177 | |
178 | struct ShapeNCTHW { |
179 | dim_t n; // Number of samples |
180 | dim_t c; // Number of Channels |
181 | dim_t t; // Temporal frames |
182 | dim_t h; // Height |
183 | dim_t w; // Width |
184 | |
185 | explicit ShapeNCTHW(llvm::ArrayRef<dim_t> shape) { |
186 | assert(shape.size() == 5 && "Invalid shape" ); |
187 | n = shape[0]; |
188 | c = shape[1]; |
189 | t = shape[2]; |
190 | h = shape[3]; |
191 | w = shape[4]; |
192 | } |
193 | |
194 | ShapeNCTHW(dim_t samples, dim_t channels, dim_t temporal_frames, dim_t height, |
195 | dim_t width) |
196 | : n(samples), c(channels), t(temporal_frames), h(height), w(width) {} |
197 | |
198 | bool equals(const ShapeNCTHW &other) const { |
199 | return n == other.n && t == other.t && h == other.h && w == other.w && |
200 | c == other.c; |
201 | } |
202 | }; |
203 | |
204 | struct PaddingTLBR { |
205 | dim_t top; |
206 | dim_t left; |
207 | dim_t bottom; |
208 | dim_t right; |
209 | |
210 | template <typename T> explicit PaddingTLBR(llvm::ArrayRef<T> pads) { |
211 | assert(pads.size() == 4 && "Invalid padding" ); |
212 | top = pads[0]; |
213 | left = pads[1]; |
214 | bottom = pads[2]; |
215 | right = pads[3]; |
216 | } |
217 | |
218 | bool equalPadding() const { |
219 | return top == left && top == bottom && top == right; |
220 | } |
221 | }; |
222 | |
223 | struct PaddingTLNBRF { |
224 | dim_t top; |
225 | dim_t left; |
226 | dim_t near; |
227 | dim_t bottom; |
228 | dim_t right; |
229 | dim_t far; |
230 | |
231 | template <typename T> explicit PaddingTLNBRF(llvm::ArrayRef<T> pads) { |
232 | assert(pads.size() == 6 && "Invalid padding" ); |
233 | top = pads[0]; |
234 | left = pads[1]; |
235 | near = pads[2]; |
236 | bottom = pads[3]; |
237 | right = pads[4]; |
238 | far = pads[5]; |
239 | } |
240 | |
241 | bool equalPadding() const { |
242 | return top == left && top == bottom && top == right && top == near && |
243 | top == far; |
244 | } |
245 | }; |
246 | |
247 | struct PaddingNFTBLR { |
248 | dim_t near; |
249 | dim_t far; |
250 | dim_t top; |
251 | dim_t bottom; |
252 | dim_t left; |
253 | dim_t right; |
254 | |
255 | template <typename T> explicit PaddingNFTBLR(llvm::ArrayRef<T> pads) { |
256 | assert(pads.size() == 6 && "Invalid padding" ); |
257 | near = pads[0]; |
258 | far = pads[1]; |
259 | top = pads[2]; |
260 | bottom = pads[3]; |
261 | left = pads[4]; |
262 | right = pads[5]; |
263 | } |
264 | |
265 | bool equalPadding() const { |
266 | return top == left && top == bottom && top == right && top == near && |
267 | top == far; |
268 | } |
269 | }; |
270 | |
271 | struct ShapeHW { |
272 | |
273 | enum { |
274 | DimH, |
275 | DimW, |
276 | }; |
277 | |
278 | dim_t height; |
279 | dim_t width; |
280 | |
281 | template <typename T> explicit ShapeHW(llvm::ArrayRef<T> shape) { |
282 | assert(shape.size() == 2 && "Invalid shape" ); |
283 | height = shape[DimH]; |
284 | width = shape[DimW]; |
285 | } |
286 | |
287 | bool isSquare() const { return height == width; } |
288 | }; |
289 | |
290 | struct ShapeNHW { |
291 | |
292 | enum { |
293 | DimN, |
294 | DimH, |
295 | DimW, |
296 | }; |
297 | |
298 | dim_t n; // Number of samples |
299 | dim_t h; // Height |
300 | dim_t w; // Width |
301 | |
302 | template <typename T> explicit ShapeNHW(llvm::ArrayRef<T> shape) { |
303 | assert(shape.size() == 3 && "Invalid shape" ); |
304 | n = shape[DimN]; |
305 | h = shape[DimH]; |
306 | w = shape[DimW]; |
307 | } |
308 | |
309 | bool isSquare() const { return h == w; } |
310 | }; |
311 | |
312 | struct ShapeHWT { |
313 | dim_t height; |
314 | dim_t width; |
315 | dim_t temporal_frames; |
316 | |
317 | template <typename T> explicit ShapeHWT(llvm::ArrayRef<T> shape) { |
318 | assert(shape.size() == 3 && "Invalid shape" ); |
319 | height = shape[0]; |
320 | width = shape[1]; |
321 | temporal_frames = shape[2]; |
322 | } |
323 | |
324 | bool isCube() const { return height == width && height == temporal_frames; } |
325 | }; |
326 | |
327 | struct ShapeTHW { |
328 | dim_t temporal_frames; |
329 | dim_t height; |
330 | dim_t width; |
331 | |
332 | template <typename T> explicit ShapeTHW(llvm::ArrayRef<T> shape) { |
333 | assert(shape.size() == 3 && "Invalid shape" ); |
334 | temporal_frames = shape[0]; |
335 | height = shape[1]; |
336 | width = shape[2]; |
337 | } |
338 | |
339 | bool isCube() const { return height == width && height == temporal_frames; } |
340 | }; |
341 | |
342 | /// Collapse a tensor shape into two sizes: the first n dimensions and the size |
343 | /// of the rest of the dimensions. For example, ([7, 3, 4, 2], 1) -> [7, 24] |
344 | inline std::pair<dim_t, dim_t> flattenCdr(llvm::ArrayRef<dim_t> dims, |
345 | unsigned_t n = 1) { |
346 | assert(1 <= n && n <= dims.size()); |
347 | size_t first = dims[0]; |
348 | for (unsigned_t i = 1; i < n; i++) { |
349 | first *= dims[i]; |
350 | } |
351 | size_t rest = 1; |
352 | for (unsigned_t i = n; i < dims.size(); i++) { |
353 | rest *= dims[i]; |
354 | } |
355 | |
356 | return {first, rest}; |
357 | } |
358 | |
359 | /// Collapse a tensor shape into two sizes: the first will be |
360 | /// size of n without the axis dimension, and second will be |
361 | /// size of the axis dimension. For example, ([7, 3, 4, 2], 1) -> [56, 3] |
362 | inline std::pair<dim_t, dim_t> collapseShape(llvm::ArrayRef<dim_t> dims, |
363 | unsigned_t n = 1) { |
364 | assert(1 <= n && n <= dims.size()); |
365 | size_t first = 1; |
366 | size_t second = 1; |
367 | for (unsigned_t i = 0; i < dims.size(); i++) { |
368 | if (i == n) { |
369 | second = dims[i]; |
370 | } else { |
371 | first *= dims[i]; |
372 | } |
373 | } |
374 | return {first, second}; |
375 | } |
376 | |
/// \returns true if \p LHS and \p RHS have the same dimensions; forwards to
/// ShapeNHWC::equals.
inline bool operator==(const ShapeNHWC &LHS, const ShapeNHWC &RHS) {
  return LHS.equals(RHS);
}

/// \returns true if \p LHS and \p RHS have the same dimensions; forwards to
/// ShapeNCHW::equals.
inline bool operator==(const ShapeNCHW &LHS, const ShapeNCHW &RHS) {
  return LHS.equals(RHS);
}

/// \returns true if \p LHS and \p RHS have the same dimensions; forwards to
/// ShapeNHWTC::equals.
inline bool operator==(const ShapeNHWTC &LHS, const ShapeNHWTC &RHS) {
  return LHS.equals(RHS);
}

/// \returns true if \p LHS and \p RHS have the same dimensions; forwards to
/// ShapeNTHWC::equals.
inline bool operator==(const ShapeNTHWC &LHS, const ShapeNTHWC &RHS) {
  return LHS.equals(RHS);
}

/// \returns true if \p LHS and \p RHS have the same dimensions; forwards to
/// ShapeNCTHW::equals.
inline bool operator==(const ShapeNCTHW &LHS, const ShapeNCTHW &RHS) {
  return LHS.equals(RHS);
}
396 | |
/// An enum representing the type used by the elements of a tensor. The types of
/// Handles for these tensors should match the element kind.
/// When adding a new type, note that this enum definition must match the
/// ElemKind definition in Glow/lib/Backends/CPU/libjit/libjit.cpp. The
/// enumerator values are also used as indices into the name table in
/// Type::getElementName, so the order of the two must stay in sync.
enum class ElemKind : unsigned char {
  // 32-bit float type (float)
  FloatTy,
  // 16-bit float type (half, fp16)
  Float16Ty,
  // 16-bit float type (bfloat16)
  BFloat16Ty,
  // 64-bit float type (double)
  Float64Ty,
  // 8-bit quantized type (int8_t)
  Int8QTy,
  // unsigned 8-bit quantized type (uint8_t)
  UInt8QTy,
  // 16-bit quantized type (int16_t)
  Int16QTy,
  // 32-bit quantized type (int32_t)
  Int32QTy,
  // 8-bit index type (uint8_t)
  UInt8ITy,
  // 32-bit index type (int32_t)
  Int32ITy,
  // 64-bit index type (int64_t)
  Int64ITy,
  // 8-bit quantized type with fused scale/offset (uint8_t)
  UInt8FusedQTy,
  // 8-bit quantized type with fused FP16 scale/offset (uint8_t)
  UInt8FusedFP16QTy,
  // 4-bit quantized type with fused FP16 scale/offset (uint8_t, each byte
  // represents 2 4-bit quantized data)
  UInt4FusedFP16QTy,
  // 4-bit quantized type with fused FP32 scale/offset (uint8_t, each byte
  // represents 2 4-bit quantized data)
  UInt4FusedQTy,
  // Bool type (bool)
  BoolTy,
  // 64-bit quantized type (int64_t, partial support)
  Int64QTy,
};
439 | |
440 | /// \returns whether \p e is a quantized ElemKind. |
441 | inline bool isQuantizedElemKind(ElemKind e) { |
442 | return e == ElemKind::Int8QTy || e == ElemKind::UInt8QTy || |
443 | e == ElemKind::Int16QTy || e == ElemKind::Int32QTy || |
444 | e == ElemKind::Int64QTy || e == ElemKind::UInt8FusedQTy || |
445 | e == ElemKind::UInt8FusedFP16QTy || e == ElemKind::UInt4FusedFP16QTy || |
446 | e == ElemKind::UInt4FusedQTy; |
447 | } |
448 | |
449 | /// \returns whether \p e is a float ElemKind. |
450 | inline bool isFloatElemKind(ElemKind e) { |
451 | return e == ElemKind::FloatTy || e == ElemKind::Float16Ty || |
452 | e == ElemKind::BFloat16Ty || e == ElemKind::Float64Ty; |
453 | } |
454 | |
455 | /// \returns whether \p e is a non-quantized integer ElemKind. |
456 | inline bool isNonQuantizedIntElemKind(ElemKind e) { |
457 | return e == ElemKind::Int32ITy || e == ElemKind::Int64ITy; |
458 | } |
459 | |
460 | /// \returns whether \p e is a fused quantized ElemKind. |
461 | inline bool isFusedQuantizedElemKind(ElemKind e) { |
462 | return e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt8FusedFP16QTy || |
463 | e == ElemKind::UInt4FusedFP16QTy || e == ElemKind::UInt4FusedQTy; |
464 | } |
465 | |
466 | /// \returns the scale and offset ElemKind used by the fused ElemKind \p e. |
467 | inline ElemKind getScaleOffsetElemKindFromFused(ElemKind e) { |
468 | assert(isFusedQuantizedElemKind(e) && "Must pass Fused ElemKind." ); |
469 | if (e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt4FusedQTy) { |
470 | return ElemKind::FloatTy; |
471 | } |
472 | return ElemKind::Float16Ty; |
473 | } |
474 | |
/// \returns the floating point value range that covers a quantized type (min
/// first, max second) given \p scale, \p offset, and \p elementType.
/// Declared here only; the definition is not part of this header.
std::pair<float, float> getQuantizedValueRange(float scale, int32_t offset,
                                               ElemKind elementType);
479 | |
480 | /// A class that represents a type of a tensor. |
481 | struct Type final { |
/// Contains the dimensions (sizes) of the tensor. Ex: [sx, sy, sz, ...].
/// Only the first numSizes_ entries are exposed through dims(); the array is
/// zero-initialized.
dim_t sizes_[max_tensor_dimensions] = {
    0,
};
/// Contains the strides for each dimension (in elements). The order should be
/// the same as in sizes_. In more details, suppose that the tensor is laid
/// out flat in memory, and some dimensions are aligned. strides_[i] is the
/// number of elements that needs to be skipped in order to reach the next
/// plane in the i-th dimension. For example, if the tensor has dimensions
/// [3, 5, 10] and alignments [3, 32, 1], the strides will be [162, 32, 1].
/// Only the first numSizes_ entries are exposed through strides(); the array
/// is zero-initialized.
dim_t strides_[max_tensor_dimensions] = {
    0,
};

/// Contains the number of dimensions used by the tensor.
unsigned char numSizes_{0};

/// On quantized tensors, this represents the scale of the values.
/// Defaults to 0.
float scale_{0};
/// On quantized tensors, this represents the offset of the values.
/// Defaults to 0.
int32_t offset_{0};

/// Specifies the element type of the tensor. Defaults to Int64ITy.
ElemKind elementType_{ElemKind::Int64ITy};
506 | |
507 | /// Initialize a new quantized type with \p scale and \p offset. |
508 | Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale, int32_t offset) |
509 | : scale_(scale), offset_(offset), elementType_(elemTy) { |
510 | assert(isQuantizedType() && "Only quantized types have a scale and offset" ); |
511 | ShapeVector alignments(dims.size(), 1); |
512 | initDims(dims, llvm::makeArrayRef(alignments)); |
513 | } |
514 | |
515 | /// Initialize a new non-quantized type. |
516 | Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) : elementType_(elemTy) { |
517 | assert(!isQuantizedType() && |
518 | "Can't initialize quantized types without scale and offset" ); |
519 | ShapeVector alignments(dims.size(), 1); |
520 | initDims(dims, llvm::makeArrayRef(alignments)); |
521 | } |
522 | |
523 | /// Initialize a new quantized type with \p scale and \p offset. |
524 | Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, |
525 | llvm::ArrayRef<dim_t> alignments, float scale, int32_t offset) |
526 | : scale_(scale), offset_(offset), elementType_(elemTy) { |
527 | assert(isQuantizedType() && "Only quantized types have a scale and offset" ); |
528 | initDims(dims, alignments); |
529 | } |
530 | |
531 | /// Initialize a new non-quantized type. |
532 | Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, |
533 | llvm::ArrayRef<dim_t> alignments) |
534 | : elementType_(elemTy) { |
535 | assert(!isQuantizedType() && |
536 | "Can't initialize quantized types without scale and offset" ); |
537 | initDims(dims, alignments); |
538 | } |
539 | |
540 | /// Reshape existing type. This method takes care of quantized types. |
541 | static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims) { |
542 | if (T.isQuantizedType()) { |
543 | return Type(T.getElementType(), dims, T.getScale(), T.getOffset()); |
544 | } else { |
545 | return Type(T.getElementType(), dims); |
546 | } |
547 | } |
548 | |
549 | /// Reshape existing type and change alignments. |
550 | static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims, |
551 | llvm::ArrayRef<dim_t> alignments) { |
552 | if (T.isQuantizedType()) { |
553 | return Type(T.getElementType(), dims, alignments, T.getScale(), |
554 | T.getOffset()); |
555 | } else { |
556 | return Type(T.getElementType(), dims, alignments); |
557 | } |
558 | } |
559 | |
560 | /// Reshape existing type by taking shapes and strides of \p shapeType. |
561 | static Type newShape(const Type &T, TypeRef shapeType) { |
562 | Type ty; |
563 | if (T.isQuantizedType()) { |
564 | ty = Type(T.getElementType(), shapeType->dims(), T.getScale(), |
565 | T.getOffset()); |
566 | } else { |
567 | ty = Type(T.getElementType(), shapeType->dims()); |
568 | } |
569 | // Copy the stride information. |
570 | std::copy(&shapeType->strides_[0], &shapeType->strides_[ty.numSizes_], |
571 | ty.strides_); |
572 | return ty; |
573 | } |
574 | |
575 | /// Reshape existing type \p T by taking shapes and using the provided \p |
576 | /// strides. |
577 | static Type newStrides(const Type &T, llvm::ArrayRef<dim_t> strides) { |
578 | assert(strides.size() == T.strides().size()); |
579 | Type ty = T; |
580 | // Copy the stride information. |
581 | std::copy(&strides[0], &strides[0] + ty.numSizes_, ty.strides_); |
582 | return ty; |
583 | } |
584 | |
585 | /// Reshape existing type. This method takes care of quantized types. |
586 | static Type newQuantparams(const Type &T, float scale, int32_t offset) { |
587 | assert(T.isQuantizedType() && |
588 | "Can't modify scale and offset of non quantized types" ); |
589 | return Type(T.getElementType(), T.dims(), scale, offset); |
590 | } |
591 | |
592 | /// \returns true if a type has standard strides and no special alignment |
593 | /// requirements. |
594 | bool hasStandardStrides() const { |
595 | if (numSizes_ > 0) { |
596 | // Stride of the last dimension is always 1. |
597 | assert(strides_[numSizes_ - 1] == 1 && |
598 | "Last dimension must always be aligned." ); |
599 | } |
600 | for (int i = numSizes_ - 2; i >= 0; i--) { |
601 | // All the strides (except for last one) depend on the previous dimension. |
602 | // For standard strides the following should be true: |
603 | // strides_[i] == sizes_[i + 1] * strides_[i + 1] |
604 | if (strides_[i] != sizes_[i + 1] * strides_[i + 1]) { |
605 | return false; |
606 | } |
607 | } |
608 | return true; |
609 | } |
610 | |
/// An empty type: every member keeps its in-class default (no dimensions,
/// zero scale and offset, element kind Int64ITy).
Type() = default;
613 | |
614 | /// \returns true if \p other is the same type. |
615 | bool isEqual(TypeRef other) const { return isEqual(*other); } |
616 | |
617 | /// \returns the scale of a quantized type. |
618 | float getScale() const { |
619 | assert(isQuantizedType() && "Can't get the scale of a non-quantized type" ); |
620 | return scale_; |
621 | } |
622 | |
623 | /// \returns the offset of a quantized type. |
624 | int32_t getOffset() const { |
625 | assert(isQuantizedType() && "Can't get the offset of a non-quantized type" ); |
626 | return offset_; |
627 | } |
628 | |
/// \returns the floating point value range that covers a quantized type (min
/// first, max second). Forwards this type's scale, offset and element kind to
/// the free function glow::getQuantizedValueRange.
std::pair<float, float> getQuantizedValueRange() const {
  return ::glow::getQuantizedValueRange(scale_, offset_, elementType_);
}
634 | |
635 | /// \returns the number of values associated to the quantized type (e.g. 256 |
636 | /// for Int8QTy). |
637 | size_t getQuantizedValueCount() const { |
638 | assert(getElementSize() < sizeof(size_t) && |
639 | "Cannot retrieve quantized value count with size_t!" ); |
640 | double numBits = getElementSize() * 8; |
641 | return static_cast<size_t>(std::exp2(numBits)); |
642 | } |
643 | |
644 | /// \returns the floating point value step associated to the quantized type. |
645 | /// The quantization step is actually equal to the quantization scale. |
646 | float getQuantizedValueStep() const { |
647 | assert(isQuantizedType() && |
648 | "Can't get the quantized value step of a non-quantized type" ); |
649 | return scale_; |
650 | } |
651 | |
652 | /// \returns true if \p other is the same type. If \p allowDifferentShape then |
653 | /// shapes will not be considered as part of the equal comparison. If \p |
654 | /// allowDifferentScaleOffset is true, scale and offset will not be considered |
655 | /// as part of the equal comparison. |
656 | bool isEqual(const Type &other, bool allowDifferentShape = false, |
657 | bool allowDifferentStrides = false, |
658 | bool allowDifferentScaleOffset = false) const { |
659 | // Element type must be the same. |
660 | if (elementType_ != other.elementType_) { |
661 | return false; |
662 | } |
663 | // Must have the same number of sizes. |
664 | if (numSizes_ != other.numSizes_) { |
665 | return false; |
666 | } |
667 | // Sizes must be the same. |
668 | if (!allowDifferentShape) { |
669 | for (size_t i = 0; i < numSizes_; i++) { |
670 | if (sizes_[i] != other.sizes_[i]) { |
671 | return false; |
672 | } |
673 | } |
674 | if (!allowDifferentStrides) { |
675 | // Strides must be the same. |
676 | for (size_t i = 0; i < numSizes_; i++) { |
677 | if (strides_[i] != other.strides_[i]) { |
678 | return false; |
679 | } |
680 | } |
681 | } |
682 | } |
683 | |
684 | // Compare the scale and offset of integers. Fused types use dummy |
685 | // scale/offset, so can ignore them. |
686 | if (isQuantizedType() && !isFusedQuantizedType() && |
687 | !allowDifferentScaleOffset) { |
688 | if (scale_ != other.scale_ || offset_ != other.offset_) { |
689 | return false; |
690 | } |
691 | } |
692 | |
693 | return true; |
694 | } |
695 | |
696 | /// \returns a hash value for this Type. Hashes for Ty1 and Ty2 are equal if |
697 | /// Ty1.isEqual(Ty2). |
698 | llvm::hash_code equals_hash() const { |
699 | return llvm::hash_combine( |
700 | elementType_, dims(), |
701 | // hashing floats is tricky, fall back to std::hash |
702 | std::hash<float>{}(scale_), offset_); |
703 | } |
704 | |
/// \returns the element kind of the tensor.
ElemKind getElementType() const { return elementType_; }

/// \returns the shape of the tensor as a view over the first numSizes_
/// entries of sizes_. The view refers to storage inside this object.
llvm::ArrayRef<dim_t> dims() const { return {sizes_, numSizes_}; }

/// \returns the strides of the tensor as a view over the first numSizes_
/// entries of strides_. The view refers to storage inside this object.
llvm::ArrayRef<dim_t> strides() const { return {strides_, numSizes_}; }
712 | |
713 | /// \returns the number of elements in the tensor. |
714 | dim_t size() const { |
715 | dim_t s = 1; |
716 | for (unsigned char i = 0; i < numSizes_; i++) { |
717 | s *= dim_t(sizes_[i]); |
718 | } |
719 | |
720 | return s; |
721 | } |
722 | |
723 | /// \returns the number of elements in a slice in the tensor. Calculate the |
724 | /// size of the slice starting at \p startDim. For example, the tensor with |
725 | /// the shape [10, 10, 3] and startDim 1 would have the size 30, because this |
726 | /// is the size of the slice [10, 3] that starts at index 1. |
727 | dim_t getSliceSize(unsigned char startDim) const { |
728 | assert(startDim <= numSizes_ && "Invalid start dim" ); |
729 | dim_t s = 1; |
730 | for (unsigned char i = startDim; i < numSizes_; i++) { |
731 | s *= dim_t(sizes_[i]); |
732 | } |
733 | return s; |
734 | } |
735 | |
736 | /// \returns true if the templated parameter \p ElemTy matches this type. |
737 | template <class ElemTy> bool isType() const { |
738 | return isType<ElemTy>(elementType_); |
739 | } |
740 | |
741 | /// \returns true if the templated parameter \p ElemTy matches the type that's |
742 | /// specified by the parameter \p Ty. |
743 | template <class ElemTy> static bool isType(ElemKind Ty) { |
744 | switch (Ty) { |
745 | case ElemKind::FloatTy: |
746 | return std::is_same<ElemTy, float>::value; |
747 | case ElemKind::Float16Ty: |
748 | return std::is_same<ElemTy, float16_t>::value; |
749 | case ElemKind::BFloat16Ty: |
750 | return std::is_same<ElemTy, bfloat16_t>::value; |
751 | case ElemKind::Float64Ty: |
752 | return std::is_same<ElemTy, double>::value; |
753 | case ElemKind::Int8QTy: |
754 | return std::is_same<ElemTy, int8_t>::value; |
755 | case ElemKind::UInt8QTy: |
756 | return std::is_same<ElemTy, uint8_t>::value; |
757 | case ElemKind::Int16QTy: |
758 | return std::is_same<ElemTy, int16_t>::value; |
759 | case ElemKind::Int32QTy: |
760 | return std::is_same<ElemTy, int32_t>::value; |
761 | case ElemKind::Int64QTy: |
762 | return std::is_same<ElemTy, int64_t>::value; |
763 | case ElemKind::UInt8ITy: |
764 | return std::is_same<ElemTy, uint8_t>::value; |
765 | case ElemKind::Int32ITy: |
766 | return std::is_same<ElemTy, int32_t>::value; |
767 | case ElemKind::Int64ITy: |
768 | return std::is_same<ElemTy, int64_t>::value; |
769 | case ElemKind::UInt8FusedQTy: |
770 | return std::is_same<ElemTy, uint8_t>::value; |
771 | case ElemKind::UInt8FusedFP16QTy: |
772 | return std::is_same<ElemTy, uint8_t>::value; |
773 | case ElemKind::UInt4FusedFP16QTy: |
774 | return std::is_same<ElemTy, uint8_t>::value; |
775 | case ElemKind::UInt4FusedQTy: |
776 | return std::is_same<ElemTy, uint8_t>::value; |
777 | case ElemKind::BoolTy: |
778 | return std::is_same<ElemTy, bool>::value; |
779 | } |
780 | LOG(FATAL) << "Invalid type: " << getElementName(Ty).str(); |
781 | return false; // Get rid of compilation warnings. |
782 | } |
783 | |
/// \returns true if the type of this Tensor is one of the quantized types
/// (as reported by isQuantizedElemKind).
bool isQuantizedType() const { return isQuantizedElemKind(elementType_); }

/// \returns true if the type of this Tensor is one of the fused quantized
/// types (as reported by isFusedQuantizedElemKind).
bool isFusedQuantizedType() const {
  return isFusedQuantizedElemKind(elementType_);
}

/// \returns true if the type of this Tensor is one of the floating point
/// types (as reported by isFloatElemKind).
bool isFPType() const { return isFloatElemKind(getElementType()); }

/// \return the size in bytes of one element of this type.
unsigned getElementSize() const { return getElementSize(elementType_); }
799 | |
800 | /// \returns the size in bytes for this Tensor. |
801 | size_t getSizeInBytes() const { |
802 | size_t s = getElementSize(); |
803 | for (unsigned char i = 0; i < numSizes_; i++) { |
804 | // If any dimensions are 0 then the entire size is 0, so early return. |
805 | if (sizes_[i] == 0) { |
806 | return 0; |
807 | } |
808 | s = std::max<dim_t>(s, |
809 | size_t(sizes_[i]) * getElementSize() * strides_[i]); |
810 | } |
811 | return s; |
812 | } |
813 | |
814 | /// \returns the actual number of elements in the tensor taking striding into |
815 | /// account. Since size() does not take striding into account, size() is |
816 | /// always <= actualSize(). |
817 | size_t actualSize() const { return getSizeInBytes() / getElementSize(); } |
818 | |
819 | /// \return the size of the element \p Ty. |
820 | static unsigned getElementSize(ElemKind Ty) { |
821 | switch (Ty) { |
822 | case ElemKind::FloatTy: |
823 | return sizeof(float); |
824 | case ElemKind::Float16Ty: |
825 | return sizeof(float16_t); |
826 | case ElemKind::BFloat16Ty: |
827 | return sizeof(bfloat16_t); |
828 | case ElemKind::Float64Ty: |
829 | return sizeof(double); |
830 | case ElemKind::Int8QTy: |
831 | return sizeof(int8_t); |
832 | case ElemKind::UInt8QTy: |
833 | return sizeof(uint8_t); |
834 | case ElemKind::Int16QTy: |
835 | return sizeof(int16_t); |
836 | case ElemKind::Int32QTy: |
837 | return sizeof(int32_t); |
838 | case ElemKind::Int64QTy: |
839 | return sizeof(int64_t); |
840 | case ElemKind::UInt8ITy: |
841 | return sizeof(uint8_t); |
842 | case ElemKind::Int32ITy: |
843 | return sizeof(int32_t); |
844 | case ElemKind::Int64ITy: |
845 | return sizeof(int64_t); |
846 | case ElemKind::UInt8FusedQTy: |
847 | return sizeof(uint8_t); |
848 | case ElemKind::UInt8FusedFP16QTy: |
849 | return sizeof(uint8_t); |
850 | case ElemKind::UInt4FusedFP16QTy: |
851 | return sizeof(uint8_t); |
852 | case ElemKind::UInt4FusedQTy: |
853 | return sizeof(uint8_t); |
854 | case ElemKind::BoolTy: |
855 | return sizeof(bool); |
856 | } |
857 | LOG(FATAL) << "Invalid type: " << getElementName(Ty).str(); |
858 | } |
859 | |
860 | /// \return the textual name of the element. |
861 | llvm::StringRef getElementName() const { |
862 | return getElementName(elementType_); |
863 | } |
864 | |
865 | /// \return the textual name of the element \p Ty. |
866 | static llvm::StringRef getElementName(ElemKind Ty) { |
867 | static const char *names[] = { |
868 | "float" , "float16" , "bfloat16" , "float64" , "i8" , |
869 | "ui8" , "i16" , "i32" , "uindex8" , "index32" , |
870 | "index64" , "ui8fused" , "ui8fusedfp16" , "ui4fusedfp16" , "ui4fused" , |
871 | "bool" , "i64" , |
872 | }; |
873 | return names[(int)Ty]; |
874 | } |
875 | |
876 | /// Given a string \p str containing the name of an ElemKind from |
877 | /// Type::getElementName, returns the corresponding ElemKind or Error if a |
878 | /// mapping couldn't be found. |
879 | static ElemKind getElementKindFromName(llvm::StringRef str) { |
880 | if (str == Type::getElementName(ElemKind::FloatTy)) { |
881 | return ElemKind::FloatTy; |
882 | } else if (str == Type::getElementName(ElemKind::Float16Ty)) { |
883 | return ElemKind::Float16Ty; |
884 | } else if (str == Type::getElementName(ElemKind::Float64Ty)) { |
885 | return ElemKind::Float64Ty; |
886 | } else if (str == Type::getElementName(ElemKind::BFloat16Ty)) { |
887 | return ElemKind::BFloat16Ty; |
888 | } else if (str == Type::getElementName(ElemKind::Int8QTy)) { |
889 | return ElemKind::Int8QTy; |
890 | } else if (str == Type::getElementName(ElemKind::UInt8QTy)) { |
891 | return ElemKind::UInt8QTy; |
892 | } else if (str == Type::getElementName(ElemKind::Int16QTy)) { |
893 | return ElemKind::Int16QTy; |
894 | } else if (str == Type::getElementName(ElemKind::Int32QTy)) { |
895 | return ElemKind::Int32QTy; |
896 | } else if (str == Type::getElementName(ElemKind::UInt8ITy)) { |
897 | return ElemKind::UInt8ITy; |
898 | } else if (str == Type::getElementName(ElemKind::Int32ITy)) { |
899 | return ElemKind::Int32ITy; |
900 | } else if (str == Type::getElementName(ElemKind::Int64ITy)) { |
901 | return ElemKind::Int64ITy; |
902 | } else if (str == Type::getElementName(ElemKind::UInt8FusedQTy)) { |
903 | return ElemKind::UInt8FusedQTy; |
904 | } else if (str == Type::getElementName(ElemKind::UInt8FusedFP16QTy)) { |
905 | return ElemKind::UInt8FusedFP16QTy; |
906 | } else if (str == Type::getElementName(ElemKind::UInt4FusedFP16QTy)) { |
907 | return ElemKind::UInt4FusedFP16QTy; |
908 | } else if (str == Type::getElementName(ElemKind::UInt4FusedQTy)) { |
909 | return ElemKind::UInt4FusedQTy; |
910 | } else if (str == Type::getElementName(ElemKind::BoolTy)) { |
911 | return ElemKind::BoolTy; |
912 | } else if (str == Type::getElementName(ElemKind::Int64QTy)) { |
913 | return ElemKind::Int64QTy; |
914 | } else { |
915 | LOG(DFATAL) << "Invalid ElemKind string: " << str.str(); |
916 | return ElemKind::FloatTy; |
917 | } |
918 | } |
919 | |
/// Dump a textual representation of the Type into provided output stream
/// \p out. Only declared here; the definition lives out of line.
void dump(llvm::raw_ostream &out) const;

/// Dump a textual representation of the Type into default output stream.
/// NOTE(review): which stream is the "default" is decided by the out-of-line
/// definition and cannot be seen from this declaration -- confirm there.
void dump() const;

/// Dump a textual representation of the Type to std::string.
/// \return the textual representation; \ref fromString is the inverse.
std::string toString() const;

/// Load a Type object from a textual representation \p str. This method is
/// paired and should be used together with \ref toString.
/// NOTE(review): the behavior on input that \ref toString did not produce is
/// not visible from this declaration -- check the definition before relying
/// on it.
static Type fromString(llvm::StringRef str);
932 | |
933 | private: |
934 | /// Setup the internals of type that store the dimensions. This method is |
935 | /// used by the constructor. |
936 | /// \param dims of the tensor (in elements). |
937 | /// \param alignments of the tensor (in bytes). |
938 | void initDims(llvm::ArrayRef<dim_t> dims, llvm::ArrayRef<dim_t> alignments) { |
939 | assert(dims.size() <= max_tensor_dimensions && "Too many dimensions." ); |
940 | assert(dims.size() == alignments.size() && |
941 | "The number of dimensions and alignments should be the same" ); |
942 | // Update the tensor strides and sizes based on given dims and alignments. |
943 | // Sizes are simply assigned to dims. And strides are computed as partial |
944 | // product of dims, making sure that each dimension is aligned as required. |
945 | numSizes_ = dims.size(); |
946 | if (numSizes_ > 0) { |
947 | // Stride of the last dimension is always 1. |
948 | assert(alignments[numSizes_ - 1] == 1 && |
949 | "Last dimension must always be aligned." ); |
950 | strides_[numSizes_ - 1] = 1; |
951 | sizes_[numSizes_ - 1] = dims[numSizes_ - 1]; |
952 | } |
953 | for (int i = numSizes_ - 2; i >= 0; i--) { |
954 | dim_t alignment = alignments[i]; |
955 | if (alignment != 1) { |
956 | assert(alignment % getElementSize() == 0 && |
957 | "Alignment should be a multiple of element size" ); |
958 | alignment /= getElementSize(); |
959 | } |
960 | // All the strides (except for last one) depend on the previous dimension. |
961 | strides_[i] = alignedSize(dims[i + 1] * strides_[i + 1], alignment); |
962 | sizes_[i] = dims[i]; |
963 | } |
964 | } |
965 | |
966 | void initDims(llvm::ArrayRef<dim_t> dims) { |
967 | assert(dims.size() <= max_tensor_dimensions && "Too many dimensions." ); |
968 | // Update the tensor sizes. |
969 | for (size_t i = 0, e = dims.size(); i < e; i++) { |
970 | sizes_[i] = dims[i]; |
971 | } |
972 | numSizes_ = dims.size(); |
973 | } |
974 | }; |
975 | |
976 | inline bool operator==(const Type &LHS, const Type &RHS) { |
977 | return LHS.isEqual(RHS); |
978 | } |
979 | |
/// Stream-insertion overloads taking a Type by reference and a TypeRef
/// (const Type *). Both return an llvm::raw_ostream reference and are only
/// declared here.
/// NOTE(review): presumably these return \p os and print the same text as
/// Type::dump -- confirm against the out-of-line definitions.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Type &type);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TypeRef &type);
982 | |
983 | } // namespace glow |
984 | |
985 | #endif // GLOW_BASE_TYPE_H |
986 | |