/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Base/Tensor.h"

#include "glow/Base/Type.h"

#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/raw_ostream.h"
#include <glog/logging.h>

using namespace glow;

namespace {

/// Maps \p input to a single ASCII character for use in tensor visualization;
/// denser glyphs represent larger magnitudes.
template <class ElemTy> static char valueToChar(ElemTy input) {
  char ch = ' ';
  const double val = input;
  if (val > 0.2) {
    ch = '.';
  }
  if (val > 0.4) {
    ch = ',';
  }
  if (val > 0.6) {
    ch = ':';
  }
  if (val > 0.8) {
    ch = 'o';
  }
  if (val > 1.0) {
    ch = 'O';
  }
  if (val > 1.5) {
    ch = '0';
  }
  if (val > 2.0) {
    ch = '@';
  }
  if (val < -0.1) {
    ch = '-';
  }
  if (val < -0.2) {
    ch = '~';
  }
  if (val < -0.4) {
    ch = '=';
  }
  if (val < -1.0) {
    ch = '#';
  }
  return ch;
}

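// A few sample mappings for the ramp above (illustrative only; the inputs
// are hypothetical):
//
//   valueToChar(0.5f)  == ','   // 0.4 < val <= 0.6
//   valueToChar(1.2f)  == 'O'   // 1.0 < val <= 1.5
//   valueToChar(-0.3f) == '~'   // -0.4 <= val < -0.2
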
/// Prints \p shape to \p os in the form "shape: ( d0 d1 ... )".
static void dumpShape(llvm::ArrayRef<dim_t> shape, llvm::raw_ostream &os) {
  os << "shape: ( ";
  for (auto &d : shape) {
    os << d << " ";
  }
  os << ")";
}

template <class ElemTy>
static void dumpGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os,
                            unsigned maxNumElem) {
  auto shape = handle.dims();
  size_t numDims = shape.size();
  auto &Ty = handle.getType();

  constexpr unsigned numDigsFP = 5;
  const unsigned numDigs = std::is_integral<ElemTy>::value ? 0 : numDigsFP;

  // Check for 0-dimensional tensor.
  if (!numDims) {
    os << "[ Scalar containing: ";
    llvm::write_double(os, handle.raw(0), llvm::FloatStyle::Fixed, numDigs);
    os << " ]\n";
    return;
  }

  const size_t numRealElems = handle.getRealNumElements();

  // Output shape.
  dumpShape(shape, os);
  if (numRealElems < handle.size()) {
    os << " ; partial num elements: " << numRealElems;
  }
  os << "\n";

  // Output ElemKind.
  os << "elemkind: " << Ty.getElementName() << "\n";

  // Check for tensor of size 0.
  if (handle.getUnpaddedSizeInBytes() == 0) {
    os << "[ tensor has no elements ]\n";
    return;
  }

  ElemTy mx = handle.raw(0);
  ElemTy mn = handle.raw(0);
  double avg = 0.0;

  for (auto elem : handle) {
    mx = std::max(mx, elem);
    mn = std::min(mn, elem);
    avg += (double)elem;
  }
  avg /= numRealElems;

  // Check for zero tensor.
  if (mn == ElemTy(.0) && mx == ElemTy(.0)) {
    os << "[ Zero tensor ]\n";
    return;
  }

  // Output max and min.
  os << "max: ";
  llvm::write_double(os, mx, llvm::FloatStyle::Fixed, numDigs);
  os << " min: ";
  llvm::write_double(os, mn, llvm::FloatStyle::Fixed, numDigs);
  os << " avg: ";
  llvm::write_double(os, avg, llvm::FloatStyle::Fixed, numDigsFP);
  os << "\n";

  os << "[";

  for (size_t i = 0, e = std::min<size_t>(maxNumElem, numRealElems); i < e;
       i++) {

    // Print one open brace at the beginning of every row, slice, and tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      if (i % Ty.getSliceSize(j + 1) == 0) {
        // This iteration of the outer loop is a new row, slice, or tensor.
        os << "[";
      }
    }

    // Print the value at the current index.
    llvm::write_double(os, handle.raw(i), llvm::FloatStyle::Fixed, numDigs);

    // Print one closing brace at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        os << "]";
      }
    }

    os << ", ";

    // Print one newline at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        // The next iteration of the outer loop will be a new row, slice, or
        // tensor.
        os << "\n";
      }
    }
  }

  if (numRealElems > maxNumElem) {
    os << "...";
  }

  os << "]\n";

  os.flush();
}

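// Roughly, dumpGenericImpl renders a 2x2 float tensor holding the
// (hypothetical) values {1, 2, 3, 4} as:
//
//   shape: ( 2 2 )
//   elemkind: float
//   max: 4.00000 min: 1.00000 avg: 2.50000
//   [[1.00000, 2.00000],
//   [3.00000, 4.00000],
//   ]
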
/// Prints the 2D or 3D tensor \p handle to \p os as ASCII art, one character
/// per element (see valueToChar).
template <class ElemTy>
static void dumpAsciiGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os) {
  auto d = handle.dims();

  if (d.size() == 2) {
    for (dim_t x = 0; x < d[0]; x++) {
      for (dim_t y = 0; y < d[1]; y++) {
        auto val = handle.at({x, y});
        os << valueToChar(val);
      }
      os << "\n";
    }
  } else if (d.size() == 3) {
    // Print monochrome (one-color-channel) tensors:
    if (d[2] == 1) {
      for (dim_t x = 0; x < d[0]; x++) {
        for (dim_t y = 0; y < d[1]; y++) {
          auto val = handle.at({x, y, 0});
          os << valueToChar(val);
        }
        os << "\n";
      }
    } else {
      for (dim_t z = 0; z < d[2]; z++) {
        os << "\n";
        for (dim_t x = 0; x < d[0]; x++) {
          for (dim_t y = 0; y < d[1]; y++) {
            auto val = handle.at({x, y, z});
            os << valueToChar(val);
          }
          os << "\n";
        }
      }
    }

  } else {
    llvm_unreachable("Invalid tensor size; ASCII dump supports 2D and 3D.");
  }

  os.flush();
}

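// For instance, a 2x4 float tensor holding the (hypothetical) values
//   { 0.0, 0.5, 1.2, 2.5, -0.15, -0.3, -0.5, -1.5 }
// is rendered as two rows of glyphs (quotes added here to show the leading
// space):
//
//   " ,O@"
//   "-~=#"
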
/// This is a slow generic transpose. It recurses over one dimension per call;
/// once the last dimension is reached, it copies a single element.
template <class ElemTy>
static void
transposeGenericImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                     dim_t *srcCoor, dim_t *destCoor,
                     llvm::ArrayRef<unsigned_t> shuffle, unsigned depth = 0) {
  if (depth == shuffle.size()) {
    auto srcIdx = llvm::ArrayRef<dim_t>(srcCoor, depth);
    auto destIdx = llvm::ArrayRef<dim_t>(destCoor, depth);
    dest.at(destIdx) = src.at(srcIdx);
    return;
  }

  // Iterate over one dimension and continue recursively to the next dim.
  for (dim_t x = 0, e = dest.dims()[depth]; x < e; x++) {
    unsigned_t swizzledDepth = shuffle[depth];
    srcCoor[swizzledDepth] = x;
    destCoor[depth] = x;
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle, depth + 1);
  }
}

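// Concretely, for shuffle = {1, 0} (a plain 2D transpose) the recursion
// visits every destination coordinate {x, y} and performs
//   dest.at({x, y}) = src.at({y, x});
// i.e. srcCoor[shuffle[depth]] mirrors destCoor[depth] at every depth.
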
/// Faster function for transposing a tensor for important/common tensor
/// shapes. If a transpose successfully occurs, the function \returns true;
/// otherwise it \returns false, meaning no transpose occurred and some other
/// transpose function (e.g. transposeGenericImpl) must be called. \p src is
/// the tensor to transpose into \p dest, and \p shuffle defines how to
/// transpose.
template <class ElemTy>
static bool tryTransposeFastImpl(const Handle<ElemTy> &src,
                                 Handle<ElemTy> &dest,
                                 llvm::ArrayRef<unsigned_t> shuffle) {
  const dim_t numDims = dest.dims().size();
  dim_t srcCoorArr[max_tensor_dimensions];
  dim_t destCoorArr[max_tensor_dimensions] = {0};
  auto srcCoor = llvm::ArrayRef<dim_t>(srcCoorArr, numDims);
  auto destCoor = llvm::ArrayRef<dim_t>(destCoorArr, numDims);

/// This defines a single depth of the for loop used to iterate over the
/// source and destination tensors for transposing.
#define TRANSPOSE_LOOP_LEVEL(DEPTH_)                                           \
  for (srcCoorArr[shuffle[DEPTH_]] = 0, destCoorArr[DEPTH_] = 0;               \
       destCoorArr[DEPTH_] < dest.dims()[DEPTH_];                              \
       srcCoorArr[shuffle[DEPTH_]]++, destCoorArr[DEPTH_]++)

  switch (numDims) {
  case 2:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(0) { dest.at(destCoor) = src.at(srcCoor); }
    }
    return true;
  case 4:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(2) {
        TRANSPOSE_LOOP_LEVEL(0) {
          TRANSPOSE_LOOP_LEVEL(3) { dest.at(destCoor) = src.at(srcCoor); }
        }
      }
    }
    return true;
  }
#undef TRANSPOSE_LOOP_LEVEL
  return false;
}

/// Transposes \p src into \p dest according to \p shuffle, taking the fast
/// path when the shape is supported and falling back to the slow generic
/// implementation otherwise.
template <class ElemTy>
static void transposeSelectImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                                llvm::ArrayRef<unsigned_t> shuffle) {
  bool transposeOccurred = tryTransposeFastImpl(src, dest, shuffle);
  if (!transposeOccurred) {
    dim_t srcCoor[max_tensor_dimensions];
    dim_t destCoor[max_tensor_dimensions];
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle);
  }
}

/// \returns true if \p tensor is tiled along \p axis with period \p size,
/// i.e. every element equals the element at the same coordinates with the
/// axis index wrapped modulo \p size. When \p fractional is false, the
/// dimension size must additionally be divisible by the tile size.
template <class ElemTy>
static bool isTiledImpl(const Tensor *tensor, unsigned_t axis, dim_t size,
                        bool fractional) {
  assert(axis < tensor->dims().size() && "Axis parameter invalid!");
  assert(size <= tensor->dims()[axis] && "Size parameter invalid!");
  assert(size >= 1 && "Size parameter invalid!");

  // When the tile size matches the dimension size we return true, since a
  // tensor can be considered a tiled version of itself.
  if (size == tensor->dims()[axis]) {
    return true;
  }

  // If fractional tiling verification is disabled and the dimension size
  // is NOT divisible by the tile size then we return false.
  if (!fractional && ((tensor->dims()[axis] % size) != 0)) {
    return false;
  }

  static_assert(max_tensor_dimensions == 6,
                "Implementation assumes max_tensor_dimensions = 6.");

  // Get a tensor view with the maximum number of dimensions.
  auto dimsMax = expandDimsToMax(tensor->dims());
  Tensor tensorMax = tensor->getUnowned(dimsMax);
  auto tensorH = tensorMax.getHandle<ElemTy>();
  for (dim_t idx0 = 0; idx0 < dimsMax[0]; ++idx0) {
    for (dim_t idx1 = 0; idx1 < dimsMax[1]; ++idx1) {
      for (dim_t idx2 = 0; idx2 < dimsMax[2]; ++idx2) {
        for (dim_t idx3 = 0; idx3 < dimsMax[3]; ++idx3) {
          for (dim_t idx4 = 0; idx4 < dimsMax[4]; ++idx4) {
            for (dim_t idx5 = 0; idx5 < dimsMax[5]; ++idx5) {
              std::vector<dim_t> idx = {idx0, idx1, idx2, idx3, idx4, idx5};
              std::vector<dim_t> idxWrapped = idx;
              idxWrapped[axis] = (idx[axis] % size);
              double delta = tensorH.at(idx) - tensorH.at(idxWrapped);
              // Since any comparison with NaN returns false, we use a negated
              // condition so that this function correctly returns false when
              // delta is NaN.
              if (!(delta == 0.0)) {
                return false;
              }
            }
          }
        }
      }
    }
  }
  return true;
}

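// Example: a 1D tensor holding {1, 2, 1, 2} is tiled along axis 0 with size
// 2 (every element equals the one at index % 2), while {1, 2, 1, 3} is not.
// With fractional = true, {1, 2, 1} also counts as tiled with size 2: the
// trailing partial tile still matches the leading elements, whereas with
// fractional = false it is rejected because 3 is not divisible by 2.
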
/// \returns a tensor with UInt8FusedQTy from \p T, whose type should be
/// UInt4FusedFP16QTy, UInt4FusedQTy, or UInt8FusedFP16QTy.
template <class scaleOffsetTy = float16_t>
static Tensor convertToUInt8FusedQTy(const Tensor *T) {
  const ElemKind origKind = T->getElementType();
  // Supports UInt4FusedFP16QTy/UInt8FusedFP16QTy/UInt4FusedQTy -> UInt8FusedQTy
  DCHECK((origKind == ElemKind::UInt4FusedFP16QTy ||
          origKind == ElemKind::UInt4FusedQTy ||
          origKind == ElemKind::UInt8FusedFP16QTy) &&
         T->dims().size() == 2)
      << "UInt4FusedFP16QTy, UInt4FusedQTy, or UInt8FusedFP16QTy tensors must "
         "be 2-dimensional.";
  bool is4Bit = (origKind == ElemKind::UInt4FusedFP16QTy ||
                 origKind == ElemKind::UInt4FusedQTy);
  const dim_t dataCol = T->dims()[1] - 2 * sizeof(scaleOffsetTy);
  const dim_t numTotalRows = T->dims()[0];
  const dim_t numTotalColumns = dataCol * (is4Bit ? 2 : 1) + 2 * sizeof(float);
  Tensor tmp(ElemKind::UInt8FusedQTy, {numTotalRows, numTotalColumns}, 1.0, 0);
  auto srcH = T->getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t row = 0; row < T->dims()[0]; row++) {
    // Copy scale and offset from src to dst.
    scaleOffsetTy scale, offset;
    std::tie(scale, offset) =
        srcH.getFusedScaleOffsetFromRow<scaleOffsetTy>(row);
    dstH.setFusedScaleOffsetInRow<float>(row, static_cast<float>(scale),
                                         static_cast<float>(offset));
    for (dim_t column = 0; column < dataCol; column++) {
      if (is4Bit) {
        auto src = srcH.at({row, column});
        // Even columns in the new data take the LSB 4 bits of the src data.
        dstH.at({row, column * 2}) = src & 0x0F;
        // Odd columns in the new data take the MSB 4 bits of the src data.
        dstH.at({row, column * 2 + 1}) = (src >> 4) & 0x0F;
      } else {
        dstH.at({row, column}) = srcH.at({row, column});
      }
    }
  }
  return tmp;
}

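// Row layout sketch for the fused formats handled above; each row stores its
// own quantization parameters inline after the payload:
//
//   UInt8FusedQTy:     [ uint8 data ...         | float scale | float offset ]
//   UInt8FusedFP16QTy: [ uint8 data ...         | fp16 scale  | fp16 offset  ]
//   UInt4Fused*QTy:    [ packed 4-bit pairs ... | scale       | offset       ]
//
// For the 4-bit formats every payload byte holds two elements: the low
// nibble (src & 0x0F) becomes the even output column and the high nibble
// ((src >> 4) & 0x0F) the odd one, which is why the converted tensor has
// dataCol * 2 payload columns.
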
/// \returns a tensor with UInt4FusedQTy from \p T, whose type should be
/// UInt4FusedFP16QTy.
static Tensor convertToUInt4FusedQTy(const Tensor *T) {
  const ElemKind origKind = T->getElementType();
  // Supports UInt4FusedFP16QTy -> UInt4FusedQTy.
  DCHECK(origKind == ElemKind::UInt4FusedFP16QTy && T->dims().size() == 2)
      << "UInt4FusedFP16QTy tensors must be 2-dimensional.";
  const dim_t dataCol = T->dims()[1] - 2 * sizeof(float16_t);
  const dim_t numTotalRows = T->dims()[0];
  const dim_t numTotalColumns = dataCol + 2 * sizeof(float);
  Tensor tmp(ElemKind::UInt4FusedQTy, {numTotalRows, numTotalColumns}, 1.0, 0);
  auto srcH = T->getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t row = 0; row < T->dims()[0]; row++) {
    // Copy scale and offset from src to dst.
    float16_t scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float16_t>(row);
    dstH.setFusedScaleOffsetInRow<float>(row, static_cast<float>(scale),
                                         static_cast<float>(offset));
    for (dim_t column = 0; column < dataCol; column++) {
      dstH.at({row, column}) = srcH.at({row, column});
    }
  }
  return tmp;
}
} // namespace

void glow::dumpAsciiImpl(const Tensor *T, llvm::raw_ostream &os) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpAsciiGenericImpl(T->getHandle<float>(), os);
  case ElemKind::Float16Ty:
    return dumpAsciiGenericImpl(T->getHandle<float16_t>(), os);
  case ElemKind::BFloat16Ty:
    return dumpAsciiGenericImpl(T->getHandle<bfloat16_t>(), os);
  case ElemKind::Float64Ty:
    return dumpAsciiGenericImpl(T->getHandle<double>(), os);
  case ElemKind::Int8QTy:
    return dumpAsciiGenericImpl(T->getHandle<int8_t>(), os);
  case ElemKind::UInt8QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::Int16QTy:
    return dumpAsciiGenericImpl(T->getHandle<int16_t>(), os);
  case ElemKind::Int32QTy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int64QTy:
    return dumpAsciiGenericImpl(T->getHandle<int64_t>(), os);
  case ElemKind::UInt8ITy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::Int32ITy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int64ITy:
    return dumpAsciiGenericImpl(T->getHandle<int64_t>(), os);
  case ElemKind::UInt8FusedQTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt4FusedQTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::BoolTy:
    return dumpAsciiGenericImpl(T->getHandle<bool>(), os);
  }
}

void glow::dumpAsciiImpl(const Tensor *T) { dumpAsciiImpl(T, llvm::outs()); }

void glow::dumpImpl(const Tensor *T, llvm::raw_ostream &os,
                    unsigned maxNumElem) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpGenericImpl(T->getHandle<float>(), os, maxNumElem);
  case ElemKind::Float16Ty:
    return dumpGenericImpl(T->getHandle<float16_t>(), os, maxNumElem);
  case ElemKind::BFloat16Ty:
    return dumpGenericImpl(T->getHandle<bfloat16_t>(), os, maxNumElem);
  case ElemKind::Float64Ty:
    return dumpGenericImpl(T->getHandle<double>(), os, maxNumElem);
  case ElemKind::Int8QTy:
    return dumpGenericImpl(T->getHandle<int8_t>(), os, maxNumElem);
  case ElemKind::UInt8QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::Int16QTy:
    return dumpGenericImpl(T->getHandle<int16_t>(), os, maxNumElem);
  case ElemKind::Int32QTy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int64QTy:
    return dumpGenericImpl(T->getHandle<int64_t>(), os, maxNumElem);
  case ElemKind::UInt8ITy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::Int32ITy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int64ITy:
    return dumpGenericImpl(T->getHandle<int64_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedQTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt4FusedQTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::BoolTy:
    return dumpGenericImpl(T->getHandle<bool>(), os, maxNumElem);
  }
}

void glow::dumpImpl(const Tensor *T, unsigned maxNumElem) {
  dumpImpl(T, llvm::outs(), maxNumElem);
}

void glow::dumpImpl(const Tensor *T) { dumpImpl(T, llvm::outs()); }

// Dump functions.
void Tensor::dump(llvm::raw_ostream &os) const { dumpImpl(this, os); }

void Tensor::dump() const { dumpImpl(this, llvm::outs()); }

std::string Tensor::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os);
  return os.str();
}

std::string Tensor::getShapeToString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpShape(dims(), os);
  return os.str();
}

void Tensor::dump(llvm::raw_ostream &os, unsigned maxNumElem) const {
  dumpImpl(this, os, maxNumElem);
}

void Tensor::dump(unsigned maxNumElem) const {
  dumpImpl(this, llvm::outs(), maxNumElem);
}

/// Dump a textual representation of a specific number of elements in the
/// Tensor to std::string.
std::string Tensor::toString(unsigned maxNumElem) const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os, maxNumElem);
  return os.str();
}

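// A minimal usage sketch (hypothetical tensor; Glow unit tests initialize
// handles the same way):
//
//   Tensor T(ElemKind::FloatTy, {2, 3});
//   T.getHandle<float>() = {1, 2, 3, 4, 5, 6};
//   T.getShapeToString();          // "shape: ( 2 3 )"
//   llvm::outs() << T.toString(4); // prints the first 4 elements, then "..."
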
void glow::genericTranspose(const Tensor *src, Tensor *dest,
                            llvm::ArrayRef<unsigned_t> shuffle) {
  DCHECK(src->dims().size() == shuffle.size())
      << "Invalid dimensions " << src->dims().size()
      << " != " << shuffle.size();

  dim_t newSizes[max_tensor_dimensions];

  // Generate the swizzled dimensions.
  auto origDims = src->dims();
  for (unsigned i = 0; i < origDims.size(); i++) {
    newSizes[i] = origDims[shuffle[i]];
  }

  // Resize the tensor to the transposed shape.
  auto destType = Type::newShape(src->getType(), {newSizes, origDims.size()});
  // genericTranspose doesn't know how to set non-trivial strides and
  // alignments, and it cannot infer the correct ones since they can be
  // backend-specific. Therefore, set the type to destType only if the caller
  // has not already set it properly. Reset must be called either way to
  // allocate memory for the tensor.
  if (dest->dims() != destType.dims()) {
    dest->reset(destType);
  } else {
    dest->reset(dest->getType());
  }

  // Fill with 0 for padding bytes.
  if (src->actualSize() != dest->actualSize()) {
    dest->zero();
  }

  switch (src->getElementType()) {
  case ElemKind::FloatTy: {
    auto srcH = src->getHandle<float>();
    auto destH = dest->getHandle<float>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Float16Ty: {
    auto srcH = src->getHandle<float16_t>();
    auto destH = dest->getHandle<float16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::BFloat16Ty: {
    auto srcH = src->getHandle<bfloat16_t>();
    auto destH = dest->getHandle<bfloat16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Float64Ty: {
    auto srcH = src->getHandle<double>();
    auto destH = dest->getHandle<double>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int8QTy: {
    auto srcH = src->getHandle<int8_t>();
    auto destH = dest->getHandle<int8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8QTy: {
    auto srcH = src->getHandle<uint8_t>();
    auto destH = dest->getHandle<uint8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int16QTy: {
    auto srcH = src->getHandle<int16_t>();
    auto destH = dest->getHandle<int16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32QTy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int64QTy: {
    auto srcH = src->getHandle<int64_t>();
    auto destH = dest->getHandle<int64_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8ITy: {
    auto srcH = src->getHandle<uint8_t>();
    auto destH = dest->getHandle<uint8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32ITy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int64ITy: {
    auto srcH = src->getHandle<int64_t>();
    auto destH = dest->getHandle<int64_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8FusedQTy: {
    llvm_unreachable("Transposing UInt8FusedQTy is unsupported.");
  }
  case ElemKind::UInt8FusedFP16QTy: {
    llvm_unreachable("Transposing UInt8FusedFP16QTy is unsupported.");
  }
  case ElemKind::UInt4FusedFP16QTy: {
    llvm_unreachable("Transposing UInt4FusedFP16QTy is unsupported.");
  }
  case ElemKind::UInt4FusedQTy: {
    llvm_unreachable("Transposing UInt4FusedQTy is unsupported.");
  }
  case ElemKind::BoolTy: {
    auto srcH = src->getHandle<bool>();
    auto destH = dest->getHandle<bool>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  }
}
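
// A typical use is converting between tensor layouts, e.g. NCHW -> NHWC
// (hypothetical shapes):
//
//   Tensor nchw(ElemKind::FloatTy, {1, 3, 224, 224});
//   Tensor nhwc;
//   genericTranspose(&nchw, &nhwc, {0, 2, 3, 1}); // nhwc dims: {1, 224, 224, 3}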

ShapeVector glow::expandDimsToMax(llvm::ArrayRef<dim_t> currDims) {
  ShapeVector newDims(currDims.begin(), currDims.end());
  for (size_t i = newDims.size(); i < max_tensor_dimensions; i++) {
    newDims.push_back(1);
  }
  return newDims;
}
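
// E.g. with max_tensor_dimensions == 6, dims {3, 4} expand to
// {3, 4, 1, 1, 1, 1}; padding with singleton dimensions leaves the element
// count and linear order unchanged.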

ShapeVector glow::reduceDims(llvm::ArrayRef<dim_t> dims,
                             llvm::ArrayRef<unsigned_t> axes, bool keepDims) {
  ShapeVector newDims;
  for (unsigned_t dim = 0, end = dims.size(); dim < end; ++dim) {
    auto it = std::find(axes.begin(), axes.end(), dim);
    bool dimReduced = (it != axes.end());
    if (dimReduced) {
      if (keepDims) {
        newDims.push_back(1);
      } else {
        continue;
      }
    } else {
      newDims.push_back(dims[dim]);
    }
  }
  return newDims;
}
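
// E.g. reducing dims {2, 3, 4} over axes {1} yields {2, 4} with
// keepDims == false, and {2, 1, 4} with keepDims == true.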

std::vector<unsigned_t>
glow::getInverseTranspose(llvm::ArrayRef<unsigned_t> shuffle) {
  std::vector<unsigned_t> unshuffle;
  unshuffle.reserve(shuffle.size());
  // For each index, find where it ended up in the shuffle.
  for (size_t i = 0; i < shuffle.size(); ++i) {
    for (size_t j = 0; j < shuffle.size(); ++j) {
      if (shuffle[j] == i) {
        unshuffle.push_back(j);
        break;
      }
    }
  }
  return unshuffle;
}
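
// E.g. the inverse of the NCHW -> NHWC shuffle {0, 2, 3, 1} is {0, 3, 1, 2};
// applying one after the other restores the original order.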

void Tensor::init(InitKind init, float val, PseudoRNG &PRNG) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  switch (init) {
  case InitKind::Zero:
    zero();
    break;

  case InitKind::Broadcast: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().clear(val);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().clear(float16_t(val));
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().clear(bfloat16_t(val));
      break;
    }
    case ElemKind::Float64Ty: {
      getHandle<double>().clear(val);
      break;
    }
    case ElemKind::Int8QTy: {
      getHandle<int8_t>().clear(val);
      break;
    }
    case ElemKind::UInt8QTy: {
      getHandle<uint8_t>().clear(val);
      break;
    }
    case ElemKind::Int16QTy: {
      getHandle<int16_t>().clear(val);
      break;
    }
    case ElemKind::Int32QTy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int64QTy: {
      getHandle<int64_t>().clear(val);
      break;
    }
    case ElemKind::UInt8ITy: {
      getHandle<uint8_t>().clear(val);
      break;
    }
    case ElemKind::Int32ITy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int64ITy: {
      getHandle<int64_t>().clear(val);
      break;
    }

#define FUSED_CASE(ELEM_KIND, DATA_TYPE)                                       \
  case ElemKind::ELEM_KIND: {                                                  \
    DCHECK(dims().size() == 2)                                                 \
        << "Fused tensor must be 2-dimensional but instead has "               \
        << dims().size() << " dimensions.";                                    \
    DCHECK(dims()[1] > 2 * sizeof(DATA_TYPE))                                  \
        << "Fused tensor must have space for scale/offset, but only has "      \
        << dims()[1] << " columns.";                                           \
    auto H = getHandle<uint8_t>();                                             \
    for (dim_t i = 0; i < dims()[0]; i++) {                                    \
      for (dim_t j = 0, f = dims()[1] - 2 * sizeof(DATA_TYPE); j < f; j++) {   \
        H.at({i, j}) = val;                                                    \
      }                                                                        \
    }                                                                          \
    break;                                                                     \
  }
    FUSED_CASE(UInt8FusedQTy, float);
    FUSED_CASE(UInt4FusedQTy, float);
    FUSED_CASE(UInt8FusedFP16QTy, float16_t);
    FUSED_CASE(UInt4FusedFP16QTy, float16_t);
#undef FUSED_CASE

    case ElemKind::BoolTy: {
      getHandle<bool>().clear(val);
      break;
    }
    }
    break;
  }

  case InitKind::Xavier: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().initXavier(val, PRNG);
      break;
    }
    default: {
      llvm_unreachable("Undefined to Xavier-initialize non-Float Tensors.");
    }
    }
    break;
  }
  }
}

void Tensor::convertToType(ElemKind newTy) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  *this = this->getCopyConvertedToType(newTy);
}

Tensor Tensor::getCopyConvertedToType(ElemKind newKind) const {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  const ElemKind origKind = getElementType();
  DCHECK((origKind == ElemKind::FloatTy && newKind == ElemKind::Float16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::BFloat16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Float16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::BFloat16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::UInt8FusedQTy &&
          newKind == ElemKind::UInt8FusedFP16QTy) ||
         (origKind == ElemKind::UInt8FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt4FusedQTy) ||
         (origKind == ElemKind::UInt4FusedQTy &&
          newKind == ElemKind::UInt8FusedQTy))
      << "Conversion from " << Type::getElementName(origKind).str() << " to "
      << Type::getElementName(newKind).str() << " is not yet implemented";

  if (!isQuantizedElemKind(newKind)) {
    Tensor tmp(newKind, dims());
    switch (newKind) {
    case ElemKind::Float16Ty:
      tmp.copyWithCast<float16_t, float>(this);
      break;
    case ElemKind::BFloat16Ty:
      tmp.copyWithCast<bfloat16_t, float>(this);
      break;

    case ElemKind::FloatTy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<float, int32_t>(this);
      } else if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<float, int64_t>(this);
      } else if (getElementType() == ElemKind::Float16Ty) {
        tmp.copyWithCast<float, float16_t>(this);
      } else if (getElementType() == ElemKind::BFloat16Ty) {
        tmp.copyWithCast<float, bfloat16_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyRawFrom(this);
      } else {
        llvm_unreachable("Invalid conversion to FLOAT.");
      }
      break;

    case ElemKind::Int32ITy:
      if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<int32_t, int64_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyWithCast<int32_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT32.");
      }
      break;
    case ElemKind::Int64ITy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<int64_t, int32_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        // FloatTy -> Int64ITy is permitted by the DCHECK above.
        tmp.copyWithCast<int64_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT64.");
      }
      break;

    default:
      llvm_unreachable("Type not supported");
    }
    return tmp;
  }

  // Handle Fused conversion.
  if ((origKind == ElemKind::UInt8FusedFP16QTy ||
       origKind == ElemKind::UInt4FusedFP16QTy) &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float16_t>(this);
  }
  if (origKind == ElemKind::UInt4FusedQTy &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float>(this);
  }
  if (origKind == ElemKind::UInt4FusedFP16QTy &&
      newKind == ElemKind::UInt4FusedQTy) {
    return convertToUInt4FusedQTy(this);
  }

  // Supports UInt8FusedQTy -> UInt8FusedFP16QTy.
  DCHECK(origKind == ElemKind::UInt8FusedQTy && dims().size() == 2)
      << "UInt8FusedQTy tensors must be 2-dimensional.";
  Tensor tmp(newKind,
             {dims()[0], dims()[1] - 2 * ((dim_t)sizeof(float) -
                                          (dim_t)sizeof(float16_t))},
             1.0, 0);

  const size_t dstWidth = tmp.dims()[1];
  auto srcH = getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {
    // Copy the scale/offset from src to dst.
    float scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float>(i);
    dstH.setFusedScaleOffsetInRow<float16_t>(i, static_cast<float16_t>(scale),
                                             static_cast<float16_t>(offset));

    // Copy over the row's uint8 data from src to dst; scales and offsets were
    // already copied over above.
    for (dim_t j = 0, f = dstWidth - 2 * sizeof(float16_t); j < f; j++) {
      dstH.at({i, j}) = srcH.at({i, j});
    }
  }
  return tmp;
}
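
// A minimal usage sketch (hypothetical tensors): round-trip fp32 through
// fp16, one of the conversions permitted by the DCHECK list above.
//
//   Tensor a(ElemKind::FloatTy, {4, 4});
//   Tensor b = a.getCopyConvertedToType(ElemKind::Float16Ty);
//   Tensor c = b.getCopyConvertedToType(ElemKind::FloatTy);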

namespace glow {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor &t) {
  t.dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor *t) {
  assert(t != nullptr && "Null Pointer.");
  t->dump(os);
  return os;
}

void Tensor::moveToDevice(DeviceTensorTransferManager *deviceManager,
                          void *locationContext) {
  if (deviceResidency_ == nullptr) {
    deviceResidency_ = new DeviceResidencyInfo();
  }
  deviceResidency_->deviceManager_ = deviceManager;
  deviceResidency_->locationContext_ = locationContext;
  deviceResidency_->tensorResidency_ =
      DeviceResidencyInfo::TensorResidency::Device;
}

void Tensor::ensureOnHost() {
  if (deviceResidency_ == nullptr) {
    // Already on host.
    return;
  }
  if (deviceResidency_->isDeviceResident()) {
    deviceResidency_->deviceManager_->transferFromDevice(*this);
  }
  assert(!isDeviceResident());
}

void Tensor::copyRawToDevice(const Tensor *t) {
  assert(isDeviceResident());
  void *locationContext = deviceResidency_->locationContext_;
  DeviceTensorTransferManager *DM = deviceResidency_->deviceManager_;
  clearDeviceResidency();
  copyRawFrom(t);
  DM->transferToDevice(*this, locationContext);
}

bool Tensor::isTiled(unsigned_t axis, dim_t size, bool fractional) const {
  switch (getElementType()) {
  case ElemKind::FloatTy: {
    return isTiledImpl<float>(this, axis, size, fractional);
  }
  case ElemKind::Float16Ty: {
    return isTiledImpl<float16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int8QTy: {
    return isTiledImpl<int8_t>(this, axis, size, fractional);
  }
  case ElemKind::UInt8QTy: {
    return isTiledImpl<uint8_t>(this, axis, size, fractional);
  }
  case ElemKind::Int16QTy: {
    return isTiledImpl<int16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32QTy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32ITy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int64ITy: {
    return isTiledImpl<int64_t>(this, axis, size, fractional);
  }
  case ElemKind::BoolTy: {
    return isTiledImpl<bool>(this, axis, size, fractional);
  }
  default:
    llvm_unreachable("isTiled: Precision not supported!");
  }
}

bool Tensor::isTiled(llvm::ArrayRef<unsigned_t> axes,
                     llvm::ArrayRef<dim_t> sizes, bool fractional) const {
  assert(axes.size() == sizes.size() &&
         "Mismatch between axes and sizes length!");
  for (size_t idx = 0, end = axes.size(); idx < end; ++idx) {
    if (!isTiled(axes[idx], sizes[idx], fractional)) {
      return false;
    }
  }
  return true;
}

bool isSliceContiguous(llvm::ArrayRef<dim_t> sliceShape,
                       llvm::ArrayRef<dim_t> tensorShape) {
  assert(sliceShape.size() == tensorShape.size() &&
         "Array length mismatch for slice/tensor sizes!");
  // Find the first non-singleton dimension of the slice. If all the
  // dimensions are singleton, the loop below is skipped and the slice is
  // trivially contiguous.
  size_t firstNonSingleDim = sliceShape.size();
  for (size_t dim = 0, dimEnd = sliceShape.size(); dim < dimEnd; ++dim) {
    if (sliceShape[dim] != 1) {
      firstNonSingleDim = dim;
      break;
    }
  }
  // The first non-singleton slice dimension can be partially or fully
  // extracted; all following dimensions must be fully extracted.
  for (size_t dim = firstNonSingleDim + 1, dimEnd = sliceShape.size();
       dim < dimEnd; ++dim) {
    if (sliceShape[dim] != tensorShape[dim]) {
      return false;
    }
  }
  return true;
}
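
// Examples: within a {5, 2, 4} tensor, a {1, 2, 4} slice is contiguous (the
// first non-singleton dimension may be partial; all later ones are fully
// extracted), and so is {1, 1, 3} (a partial last dimension with only
// singletons before it). A {1, 2, 2} slice is not contiguous, because a
// dimension after the first non-singleton one is only partially extracted.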

} // namespace glow