1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Backends/Interpreter/Interpreter.h"
18
19#include "glow/Base/TensorSerialization.h"
20#include "glow/Base/Type.h"
21#include "glow/IR/Instrs.h"
22#include "glow/Quantization/Base/Base.h"
23#include "glow/Quantization/Base/Profile.h"
24
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28
29#include <chrono>
30#include <cmath>
31#include <numeric>
32#include <random>
33
34#ifdef WIN32
35#include <corecrt_math_defines.h>
36#endif
37
38using namespace glow;
39
40namespace IntNBitSplitEmbeddingBagsHelper {
41inline int32_t unpaddedRowSizeInBytes(int32_t dim,
42 SplitEmbeddingSparseType weight_ty) {
43 if (weight_ty == SplitEmbeddingSparseType::EST_FLOAT) {
44 return dim * sizeof(float);
45 }
46 if (weight_ty == SplitEmbeddingSparseType::EST_FLOAT16) {
47 return dim * sizeof(float16_t);
48 }
49 if (weight_ty == SplitEmbeddingSparseType::EST_INT8) {
50 return dim + 2 * sizeof(float16_t);
51 }
52 if (weight_ty == SplitEmbeddingSparseType::EST_INT4) {
53 return dim / 2 + 2 * sizeof(float16_t);
54 }
55 if (weight_ty == SplitEmbeddingSparseType::EST_INT2) {
56 return dim / 4 + 2 * sizeof(float16_t);
57 }
58 llvm_unreachable("Unsupported SparseType");
59}
60
61uint32_t roundUp(uint32_t a, uint32_t b) { return ((a + b - 1) / b) * b; }
62
63inline int32_t paddedRowSizeInBytes(int32_t dim,
64 SplitEmbeddingSparseType weight_ty) {
65 auto r = unpaddedRowSizeInBytes(dim, weight_ty);
66 return roundUp(r, 16);
67}
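
// A quick worked example (illustrative only) of the row sizes these helpers
// compute for dim = 64:
//   EST_FLOAT   : 64 * 4         = 256 bytes (already a multiple of 16)
//   EST_FLOAT16 : 64 * 2         = 128 bytes
//   EST_INT8    : 64 + 2 * 2     =  68 bytes, padded to 80
//   EST_INT4    : 64 / 2 + 2 * 2 =  36 bytes, padded to 48
//   EST_INT2    : 64 / 4 + 2 * 2 =  20 bytes, padded to 32
// The extra 2 * sizeof(float16_t) bytes in the quantized cases hold the
// per-row scale and offset consumed by add() below.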
68
69template <typename SumTy>
70SumTy add(SumTy a, const uint8_t *row, const uint8_t *data,
71 SplitEmbeddingSparseType dataTy, bool isMSB, float weight) {
72 float sum = a;
73
74 if (dataTy == SplitEmbeddingSparseType::EST_FLOAT) {
75 return sum + weight * static_cast<float>(
76 *(reinterpret_cast<const float *>(data)));
77 }
78
79 if (dataTy == SplitEmbeddingSparseType::EST_FLOAT16) {
80 return sum + weight * static_cast<float>(
81 *(reinterpret_cast<const float16_t *>(data)));
82 }
83
84 float scale = *(reinterpret_cast<const float16_t *>(row));
85 float offset = *(reinterpret_cast<const float16_t *>(row) + 1);
86
87 if (dataTy == SplitEmbeddingSparseType::EST_INT8) {
88 return sum + weight * (static_cast<float>(*data) * scale + offset);
89 }
90
91 if (dataTy == SplitEmbeddingSparseType::EST_INT4) {
92 if (isMSB) {
93 return sum + weight * (static_cast<float>(*data >> 4) * scale + offset);
94 }
95 return sum + weight * (static_cast<float>(*data & 0xF) * scale + offset);
96 }
97
  llvm_unreachable("Unsupported SplitEmbeddingSparseType");
99}
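
// Dequantization performed by add(): for the INT8/INT4 cases the first two
// float16 values of `row` are interpreted as the per-row (scale, offset)
// pair, and each stored value q contributes weight * (q * scale + offset) to
// the running sum. For INT4, two values share a byte and `isMSB` selects the
// high or low nibble. Note that EST_INT2 rows are sized by
// unpaddedRowSizeInBytes() above but are not handled by add().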
100
101template <typename DataTy>
102void save(DataTy a, uint8_t *data, SplitEmbeddingSparseType dataTy) {
103 if (dataTy == SplitEmbeddingSparseType::EST_FLOAT) {
104 *(reinterpret_cast<float *>(data)) = a;
105 } else if (dataTy == SplitEmbeddingSparseType::EST_FLOAT16) {
106 *(reinterpret_cast<float16_t *>(data)) = a;
107 }
108}
109
110} // namespace IntNBitSplitEmbeddingBagsHelper
111
112#define dispatchImpl(functionName, elemTy, ...) \
113 switch (elemTy) { \
114 case ElemKind::FloatTy: \
115 functionName<float>(__VA_ARGS__); \
116 break; \
117 case ElemKind::Float16Ty: \
118 functionName<float16_t>(__VA_ARGS__); \
119 break; \
120 case ElemKind::BFloat16Ty: \
121 functionName<bfloat16_t>(__VA_ARGS__); \
122 break; \
123 case ElemKind::Int8QTy: \
124 functionName<int8_t>(__VA_ARGS__); \
125 break; \
126 case ElemKind::Int16QTy: \
127 functionName<int16_t>(__VA_ARGS__); \
128 break; \
129 case ElemKind::Int32QTy: \
130 functionName<int32_t>(__VA_ARGS__); \
131 break; \
132 case ElemKind::Int32ITy: \
133 functionName<int32_t>(__VA_ARGS__); \
134 break; \
135 case ElemKind::Int64ITy: \
136 functionName<int64_t>(__VA_ARGS__); \
137 break; \
138 case ElemKind::BoolTy: \
139 functionName<bool>(__VA_ARGS__); \
140 break; \
141 default: \
142 llvm_unreachable("Type is not supported"); \
143 }
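
// Usage sketch for these dispatch macros (the instruction/impl names here are
// hypothetical): given a templated implementation such as
//   template <typename ElemTy> void fwdFooInstImpl(Value *in, Value *out);
// a non-templated forwarder selects the instantiation matching the runtime
// element kind, e.g.
//   dispatchImpl(fwdFooInstImpl, I->getSrc()->getElementType(),
//                I->getSrc(), I->getDest());
// The more restrictive variants below differ only in which ElemKind cases
// they accept.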
144
145#define dispatchFloatingPointImpl(functionName, elemTy, ...) \
146 switch (elemTy) { \
147 case ElemKind::FloatTy: \
148 functionName<float>(__VA_ARGS__); \
149 break; \
150 case ElemKind::Float16Ty: \
151 functionName<float16_t>(__VA_ARGS__); \
152 break; \
153 case ElemKind::BFloat16Ty: \
154 functionName<bfloat16_t>(__VA_ARGS__); \
155 break; \
156 default: \
157 llvm_unreachable("Type is not supported"); \
158 }
159
160#define dispatchFloatingPointAndInt32Impl(functionName, elemTy, ...) \
161 switch (elemTy) { \
162 case ElemKind::FloatTy: \
163 functionName<float>(__VA_ARGS__); \
164 break; \
165 case ElemKind::Float16Ty: \
166 functionName<float16_t>(__VA_ARGS__); \
167 break; \
168 case ElemKind::BFloat16Ty: \
169 functionName<bfloat16_t>(__VA_ARGS__); \
170 break; \
171 case ElemKind::Int32ITy: \
172 functionName<int>(__VA_ARGS__); \
173 break; \
174 default: \
175 llvm_unreachable("Type is not supported"); \
176 }
177
178#define dispatchFloatingPointAndIndexImpl(functionName, elemTy, elemTyIndex, \
179 ...) \
180 switch (elemTy) { \
181 case ElemKind::FloatTy: \
182 if (elemTyIndex == ElemKind::Int64ITy) { \
183 functionName<float, int64_t>(__VA_ARGS__); \
184 } else if (elemTyIndex == ElemKind::Int32ITy) { \
185 functionName<float, int32_t>(__VA_ARGS__); \
186 } \
187 break; \
188 case ElemKind::Float16Ty: \
189 if (elemTyIndex == ElemKind::Int64ITy) { \
190 functionName<float16, int64_t>(__VA_ARGS__); \
191 } else if (elemTyIndex == ElemKind::Int32ITy) { \
192 functionName<float16, int32_t>(__VA_ARGS__); \
193 } \
194 break; \
195 case ElemKind::BFloat16Ty: \
196 if (elemTyIndex == ElemKind::Int64ITy) { \
197 functionName<bfloat16, int64_t>(__VA_ARGS__); \
198 } else if (elemTyIndex == ElemKind::Int32ITy) { \
199 functionName<bfloat16, int32_t>(__VA_ARGS__); \
200 } \
201 break; \
202 default: \
203 llvm_unreachable("Type is not supported"); \
204 }
205
206#define dispatchIndexTypeImpl(functionName, elemTy, ...) \
207 switch (elemTy) { \
208 case ElemKind::Int32ITy: \
209 functionName<int32_t>(__VA_ARGS__); \
210 break; \
211 case ElemKind::Int64ITy: \
212 functionName<int64_t>(__VA_ARGS__); \
213 break; \
214 default: \
215 llvm_unreachable("Type is not supported"); \
216 }
217
218#define dispatchIndexAndOutputTypeImpl(functionName, indexTy, outputTy, ...) \
219 switch (indexTy) { \
220 case ElemKind::Int32ITy: \
221 if (outputTy == SplitEmbeddingSparseType::EST_FLOAT) { \
222 functionName<int32_t, float>(__VA_ARGS__); \
223 } else if (outputTy == SplitEmbeddingSparseType::EST_FLOAT16) { \
224 functionName<int32_t, float16>(__VA_ARGS__); \
225 } \
226 break; \
227 case ElemKind::Int64ITy: \
228 if (outputTy == SplitEmbeddingSparseType::EST_FLOAT) { \
229 functionName<int64_t, float>(__VA_ARGS__); \
230 } else if (outputTy == SplitEmbeddingSparseType::EST_FLOAT16) { \
231 functionName<int64_t, float16>(__VA_ARGS__); \
232 } \
233 break; \
234 default: \
235 llvm_unreachable("Type is not supported"); \
236 }
237
238#define dispatchArithmeticImpl(functionName, elemTy, ...) \
239 switch (elemTy) { \
240 case ElemKind::FloatTy: \
241 functionName<float>(__VA_ARGS__); \
242 break; \
243 case ElemKind::Float16Ty: \
244 functionName<float16_t>(__VA_ARGS__); \
245 break; \
246 case ElemKind::BFloat16Ty: \
247 functionName<bfloat16_t>(__VA_ARGS__); \
248 break; \
249 case ElemKind::Int32ITy: \
250 functionName<int32_t>(__VA_ARGS__); \
251 break; \
252 case ElemKind::Int64ITy: \
253 functionName<int64_t>(__VA_ARGS__); \
254 break; \
255 default: \
256 llvm_unreachable("Type is not supported"); \
257 }
258
259#define dispatchBitwiseImpl(functionName, elemTy, ...) \
260 switch (elemTy) { \
261 case ElemKind::Int32ITy: \
262 functionName<int32_t>(__VA_ARGS__); \
263 break; \
264 case ElemKind::Int64ITy: \
265 functionName<int64_t>(__VA_ARGS__); \
266 break; \
267 case ElemKind::BoolTy: \
268 functionName<bool>(__VA_ARGS__); \
269 break; \
270 default: \
271 llvm_unreachable("Type is not supported"); \
272 }
273
274#define dispatchQuantizedImpl(functionName, elemTy, ...) \
275 switch (elemTy) { \
276 case ElemKind::Int8QTy: \
277 functionName<int8_t>(__VA_ARGS__); \
278 break; \
279 case ElemKind::Int16QTy: \
280 functionName<int16_t>(__VA_ARGS__); \
281 break; \
282 case ElemKind::Int32QTy: \
283 functionName<int32_t>(__VA_ARGS__); \
284 break; \
285 default: \
286 llvm_unreachable("Type is not supported"); \
287 }
288
289#define dispatchQuantizedWithAccumulationImpl(functionName, elemTy, ...) \
290 switch (elemTy) { \
291 case ElemKind::Int8QTy: \
292 functionName<int8_t, int32_t>(__VA_ARGS__); \
293 break; \
294 case ElemKind::Int16QTy: \
295 functionName<int16_t, int64_t>(__VA_ARGS__); \
296 break; \
297 default: \
298 llvm_unreachable("Type is not supported"); \
299 }
300
301#define dispatchQuantizedWithAccumulationAndBiasImpl(functionName, elemTy, \
302 biasElemType, ...) \
303 if (elemTy == ElemKind::Int8QTy && biasElemType == ElemKind::Int8QTy) { \
304 functionName<int8_t, int32_t, int8_t>(__VA_ARGS__); \
305 } else if (elemTy == ElemKind::Int8QTy && \
306 biasElemType == ElemKind::Int32QTy) { \
307 functionName<int8_t, int32_t, int32_t>(__VA_ARGS__); \
308 } else if (elemTy == ElemKind::Int16QTy && \
309 biasElemType == ElemKind::Int16QTy) { \
310 functionName<int16_t, int64_t, int16_t>(__VA_ARGS__); \
311 } else if (elemTy == ElemKind::Int16QTy && \
312 biasElemType == ElemKind::Int32QTy) { \
313 functionName<int16_t, int64_t, int32_t>(__VA_ARGS__); \
314 } else { \
315 llvm_unreachable("Type is not supported"); \
316 }
317
318#define staticAssertFloatingPointType(ElemTy) \
319 static_assert( \
320 std::is_floating_point<ElemTy>::value || \
321 std::is_same<float16_t, \
322 typename std::remove_cv<ElemTy>::type>::value || \
323 std::is_same<bfloat16_t, \
324 typename std::remove_cv<ElemTy>::type>::value, \
325 "This implementation is for floating-point values only")
326
327#define staticAssertArithmeticType(ElemTy) \
328 static_assert( \
329 std::is_arithmetic<ElemTy>::value || \
330 std::is_same<float16_t, \
331 typename std::remove_cv<ElemTy>::type>::value || \
332 std::is_same<bfloat16_t, \
333 typename std::remove_cv<ElemTy>::type>::value, \
334 "This implementation is for arithmetic values only")
335
336#ifndef MIN
337#define MIN(a, b) (((a) < (b)) ? (a) : (b))
338#endif
339
340#ifndef MAX
341#define MAX(a, b) (((a) > (b)) ? (a) : (b))
342#endif
343
344//===----------------------------------------------------------------------===//
345// Convolution
346//===----------------------------------------------------------------------===//
347
348/// This is the floating point implementation of Convolution.
349template <typename ElemTy>
350void BoundInterpreterFunction::fwdConvolutionInstFloatImpl(
351 Value *inV, Value *outV, Value *filterV, Value *biasV,
352 llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
353 llvm::ArrayRef<unsigned_t> pads, size_t group,
354 llvm::ArrayRef<unsigned_t> dilation) {
355 staticAssertFloatingPointType(ElemTy);
356
357 auto inW = getWeightHandle<ElemTy>(inV);
358 auto outW = getWeightHandle<ElemTy>(outV);
359 auto filterW = getWeightHandle<ElemTy>(filterV);
360 auto biasW = getWeightHandle<ElemTy>(biasV);
361
362 ShapeNHWC odim(outW.dims());
363 ShapeNHWC idim(inW.dims());
364 ShapeHW kdim(kernelSizes);
365 ShapeHW sdim(strides);
366
367 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
368 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
369 dim_t inCperG = idim.c / group;
370 dim_t outCperG = odim.c / group;
371
372 PaddingTLBR pdim(pads);
373
374 // For each input in the batch:
375 for (dim_t n = 0; n < idim.n; n++) {
376
377 // For each group of input channels:
378 for (dim_t g = 0; g < group; g++) {
379
380 // For each output channel in the group:
381 for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
382
383 // For each convolution 'jump' in the input tensor:
384 ssize_t x = -ssize_t(pdim.top);
385 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
386 ssize_t y = -ssize_t(pdim.left);
387 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
388
389 // For each element in the convolution-filter:
390 float sum = 0;
391 for (dim_t fx = 0; fx < kdim.height; fx++) {
392 for (dim_t fy = 0; fy < kdim.width; fy++) {
393 sdim_t ox = x + fx * dilation[0];
394 sdim_t oy = y + fy * dilation[1];
395
396 // Ignore index access below zero (this is due to padding).
397 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
398 oy >= ssize_t(idim.w)) {
399 continue;
400 }
401 for (dim_t fd = 0; fd < inCperG; fd++) {
402 sum += float(
403 filterW.at({d, fx, fy, fd}) *
404 inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}));
405 }
406 }
407 }
408
409 sum += float(biasW.at({d}));
410 outW.at({n, ax, ay, d}) = ElemTy(sum);
411 } // W
412 } // H
413 } // C
414 } // G
415 } // N
416}
417
418/// This is the quantized implementation of Convolution.
419template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
420void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl(
421 Value *inV, Value *outV, Value *filterV, Value *biasV,
422 llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
423 llvm::ArrayRef<unsigned_t> pads, size_t group,
424 llvm::ArrayRef<unsigned_t> dilation) {
425 auto inW = getWeightHandle<ElemTy>(inV);
426 auto outW = getWeightHandle<ElemTy>(outV);
427 auto filterW = getWeightHandle<ElemTy>(filterV);
428 auto biasW = getWeightHandle<BiasElemTy>(biasV);
429
430 ShapeNHWC odim(outW.dims());
431 ShapeNHWC idim(inW.dims());
432 ShapeHW kdim(kernelSizes);
433 ShapeHW sdim(strides);
434
435 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
436 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
437 dim_t inCperG = idim.c / group;
438 dim_t outCperG = odim.c / group;
439
440 PaddingTLBR pdim(pads);
441 auto outTy = outV->getType();
442 auto inTy = inV->getType();
443 auto filterTy = filterV->getType();
444 auto biasTy = biasV->getType();
445
446 int32_t outOffset = outTy->getOffset();
447 int32_t inOffset = inTy->getOffset();
448 int32_t filterOffset = filterTy->getOffset();
449 int32_t biasOffset = biasTy->getOffset();
450
451 float outScale = outTy->getScale();
452 float inScale = inTy->getScale();
453 float filterScale = filterTy->getScale();
454 float biasScale = biasTy->getScale();
455
456 // Calculate the scale of the values that come out of the matrix
457 // multiplication part of the calculation.
458 float matMulScale = inScale * filterScale;
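  // In other words: realIn = inScale * (I - inOffset) and
  // realFilter = filterScale * (F - filterOffset), so the integer accumulator
  // below holds (real dot product) / matMulScale. The requantization at the
  // end multiplies by matMulScale / outScale and adds outOffset to express
  // the result in the destination's quantized domain.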
459
460 // For each input in the batch:
461 for (dim_t n = 0; n < idim.n; n++) {
462 // For each group of input channels:
463 for (dim_t g = 0; g < group; g++) {
464
465 // For each output channel in the group:
466 for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
467
468 // For each convolution 'jump' in the input tensor:
469 ssize_t x = -ssize_t(pdim.top);
470 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
471 ssize_t y = -ssize_t(pdim.left);
472 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
473
474 // For each element in the convolution-filter:
475 AccumulatorTy sum = 0;
476 for (dim_t fx = 0; fx < kdim.height; fx++) {
477 for (dim_t fy = 0; fy < kdim.width; fy++) {
478 sdim_t ox = x + fx * dilation[0];
479 sdim_t oy = y + fy * dilation[1];
480
481 // Ignore index access below zero (this is due to padding).
482 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
483 oy >= sdim_t(idim.w)) {
484 continue;
485 }
486 for (dim_t fd = 0; fd < inCperG; fd++) {
487
488 AccumulatorTy F = filterW.at({d, fx, fy, fd});
489 AccumulatorTy I =
490 inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
491 // We represent the element multiplication with offset as
492 // (value - offset).
493 sum += (F - filterOffset) * (I - inOffset);
494 }
495 }
496 }
497
498 // Scale the bias to match the scale of the matrix multiplication.
499 AccumulatorTy B = std::round(float(biasW.at({d}) - biasOffset) *
500 (biasScale / matMulScale));
501
502 // Add the bias.
503 sum += B;
504
505 // Scale the result back to the expected destination scale.
506 outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
507 std::round(float(sum) * (matMulScale / outScale) + outOffset));
508 } // W
509 } // H
510 } // C
511 } // G
512 } // N
513}
514
515/// This is the floating point implementation of ConvTranspose.
516template <typename ElemTy>
517void BoundInterpreterFunction::fwdConvTransposeInstFloatImpl(
518 Value *inV, Value *outV, Value *filterV, Value *biasV,
519 llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
520 llvm::ArrayRef<unsigned_t> pads, size_t group,
521 llvm::ArrayRef<unsigned_t> dilation) {
522 staticAssertFloatingPointType(ElemTy);
523
524 auto inW = getWeightHandle<ElemTy>(inV);
525 auto outW = getWeightHandle<ElemTy>(outV);
526 auto filterW = getWeightHandle<ElemTy>(filterV);
527 auto biasW = getWeightHandle<ElemTy>(biasV);
528
529 ShapeNHWC odim(outW.dims());
530 ShapeNHWC idim(inW.dims());
531 ShapeHW kdim(kernelSizes);
532 ShapeHW sdim(strides);
533
534 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
535 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
536
537 dim_t inCperG = idim.c / group;
538 dim_t outCperG = odim.c / group;
539
540 PaddingTLBR pdim(pads);
541
542 // For each input in the batch:
543 for (dim_t n = 0; n < idim.n; n++) {
544
545 // Initialize bias (TODO take out to a separate function when quant is in).
546 for (dim_t ax = 0; ax < odim.h; ax++) {
547 for (dim_t ay = 0; ay < odim.w; ay++) {
548 for (dim_t d = 0; d < odim.c; d++) {
549 outW.at({n, ax, ay, d}) = static_cast<ElemTy>(biasW.at({d}));
550 }
551 }
552 }
553
554 // For each group of input channels:
555 for (dim_t g = 0; g < group; g++) {
556
557 // For each input channel in the group:
558 for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) {
559
560 // For each transposed convolution 'jump' in the input tensor:
561 ssize_t x = -ssize_t(pdim.top);
562 for (dim_t bx = 0; bx < idim.h; bx++, x += sdim.height) {
563 ssize_t y = -ssize_t(pdim.left);
564 for (dim_t by = 0; by < idim.w; by++, y += sdim.width) {
565
          // For each element in the transposed convolution filter:
567 ElemTy input = inW.at({n, bx, by, d});
568
569 for (dim_t kx = 0; kx < kdim.height; kx++) {
570 for (dim_t ky = 0; ky < kdim.width; ky++) {
571 ssize_t ax = x + kx * dilation[0];
572 ssize_t ay = y + ky * dilation[1];
573
574 // Ignore index access below zero (this is due to padding).
575 if (ax < 0 || ay < 0 || ax >= ssize_t(odim.h) ||
576 ay >= ssize_t(odim.w)) {
577 continue;
578 }
579 for (dim_t c = 0; c < outCperG; c++) {
580 outW.at({n, (dim_t)ax, (dim_t)ay, g * outCperG + c}) +=
581 filterW.at({c, kx, ky, d}) * input;
582 }
583 }
584 }
585 } // W
586 } // H
587 } // C
588 } // G
589 } // N
590}
591
592void BoundInterpreterFunction::fwdConvTransposeInst(
593 const ConvTransposeInst *I) {
594 auto kernelSizes = I->getKernels();
595 auto pads = I->getPads();
596 auto strides = I->getStrides();
597 size_t group = I->getGroup();
598
599 if (I->getSrc()->getType()->isQuantizedType()) {
    llvm_unreachable("Quantized ConvTranspose not supported");
602 }
603
604 dispatchFloatingPointImpl(
605 fwdConvTransposeInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
606 I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
607 group, I->getDilation());
608}
609
610void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) {
611 auto kernelSizes = I->getKernels();
612 auto pads = I->getPads();
613 auto strides = I->getStrides();
614 size_t group = I->getGroup();
615
616 if (I->getSrc()->getType()->isQuantizedType()) {
617 dispatchQuantizedWithAccumulationAndBiasImpl(
618 fwdConvolutionInstQuantizedImpl, I->getSrc()->getElementType(),
619 I->getBias()->getElementType(), I->getSrc(), I->getDest(),
620 I->getFilter(), I->getBias(), kernelSizes, strides, pads, group,
621 I->getDilation());
622 return;
623 }
624
625 dispatchFloatingPointImpl(
626 fwdConvolutionInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
627 I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
628 group, I->getDilation());
629}
630
631void BoundInterpreterFunction::fwdConcatInst(const ConcatInst *I) {
632 (void)I;
633 // TODO
634 llvm_unreachable("not yet implemented");
635}
636
637void BoundInterpreterFunction::fwdConvolutionGradInst(
638 const ConvolutionGradInst *I) {
639 auto inW = getWeightHandle(I->getSrc());
640 auto inG = getWeightHandle(I->getSrcGrad());
641 auto outG = getWeightHandle(I->getDestGrad());
642
643 auto filterW = getWeightHandle(I->getFilter());
644 auto filterG = getWeightHandle(I->getFilterGrad());
645 auto biasG = getWeightHandle(I->getBiasGrad());
646
647 size_t group = I->getGroup();
648 auto dilation = I->getDilation();
649
650 inG.clear();
651 filterG.clear();
652 biasG.clear();
653
654 ShapeNHWC odim(outG.dims());
655 ShapeNHWC idim(inW.dims());
656 ShapeHW kdim(I->getKernels());
657 ShapeHW sdim(I->getStrides());
658 PaddingTLBR pdim(I->getPads());
659
660 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
661 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
662 dim_t inCperG = idim.c / group;
663 dim_t outCperG = odim.c / group;
664
665 // For each input in the batch:
666 for (dim_t n = 0; n < odim.n; n++) {
667
668 // For each group of input channels:
669 for (dim_t g = 0; g < group; g++) {
670
671 // Compute the gradient. For each layer in the output tensor:
672 for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
673
674 // For each convolution 'jump' in the input tensor:
675 sdim_t x = -sdim_t(pdim.top);
676 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
677 sdim_t y = -sdim_t(pdim.left);
678 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
679
680 float chainGrad = outG.at({n, ax, ay, d});
681
682 // For each element in the convolution-filter:
683 for (dim_t fx = 0; fx < kdim.height; fx++) {
684 for (dim_t fy = 0; fy < kdim.width; fy++) {
685 sdim_t ox = x + fx * dilation[0];
686 sdim_t oy = y + fy * dilation[1];
687
688 // Ignore index access below zero (this is due to padding).
689 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
690 oy >= sdim_t(idim.w)) {
691 continue;
692 }
693
694 for (dim_t fd = 0; fd < inCperG; fd++) {
695 filterG.at({d, fx, fy, fd}) +=
696 inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) *
697 chainGrad;
698 inG.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) +=
699 filterW.at({d, fx, fy, fd}) * chainGrad;
700 }
701 }
702 }
703
704 biasG.at({d}) += chainGrad;
705 } // W
706 } // H
707 } // C
708 } // G
709 } // N
710}
711
712/// This is the floating point implementation of Convolution3D.
713template <typename ElemTy>
714void BoundInterpreterFunction::fwdConvolution3DInstFloatImpl(
715 Value *inV, Value *outV, Value *filterV, Value *biasV,
716 llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
717 llvm::ArrayRef<unsigned_t> pads, size_t group) {
718 staticAssertFloatingPointType(ElemTy);
719
720 auto inW = getWeightHandle<ElemTy>(inV);
721 auto outW = getWeightHandle<ElemTy>(outV);
722 auto filterW = getWeightHandle<ElemTy>(filterV);
723 auto biasW = getWeightHandle<ElemTy>(biasV);
724
725 ShapeNTHWC odim(outW.dims());
726 ShapeNTHWC idim(inW.dims());
727 ShapeTHW kdim(kernelSizes);
728 ShapeTHW sdim(strides);
729
730 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
731 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
732 dim_t inCperG = idim.c / group;
733 dim_t outCperG = odim.c / group;
734
735 PaddingNFTBLR pdim(pads);
736
737 // For each input in the batch:
738 for (dim_t n = 0; n < idim.n; n++) {
739
740 // For each group of input channels:
741 for (dim_t ig = 0; ig < group; ig++) {
742
743 // For each output channel in the group:
744 for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {
745
746 ssize_t t = -ssize_t(pdim.near);
747 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
748 // For each convolution 'jump' in the input tensor:
749 ssize_t x = -ssize_t(pdim.top);
750 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
751 ssize_t y = -ssize_t(pdim.left);
752 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
753 // For each element in the 3D convolution-filter:
754 float sum = 0;
755 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
756 for (dim_t fx = 0; fx < kdim.height; fx++) {
757 for (dim_t fy = 0; fy < kdim.width; fy++) {
758 sdim_t ot = t + ft;
759 sdim_t ox = x + fx;
760 sdim_t oy = y + fy;
761
762 // Ignore index access below zero (this is due to padding).
763 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
764 ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
765 continue;
766 }
767 for (dim_t fg = 0; fg < inCperG; fg++) {
768 sum += float(filterW.at({og, ft, fx, fy, fg}) *
769 inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy,
770 ig * inCperG + fg}));
771 }
772 }
773 }
774 }
775
776 sum += float(biasW.at({og}));
777 outW.at({n, at, ax, ay, og}) = ElemTy(sum);
778 } // D
779 } // W
780 } // H
781 } // C
782 } // G
783 } // N
784}
785
786/// This is the quantized implementation of Convolution3D.
787template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
788void BoundInterpreterFunction::fwdConvolution3DInstQuantizedImpl(
789 Value *inV, Value *outV, Value *filterV, Value *biasV,
790 llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
791 llvm::ArrayRef<unsigned_t> pads, size_t group) {
792 auto inW = getWeightHandle<ElemTy>(inV);
793 auto outW = getWeightHandle<ElemTy>(outV);
794 auto filterW = getWeightHandle<ElemTy>(filterV);
795 auto biasW = getWeightHandle<BiasElemTy>(biasV);
796
797 ShapeNTHWC odim(outW.dims());
798 ShapeNTHWC idim(inW.dims());
799 ShapeTHW kdim(kernelSizes);
800 ShapeTHW sdim(strides);
801
802 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
803 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
804 dim_t inCperG = idim.c / group;
805 dim_t outCperG = odim.c / group;
806
807 PaddingNFTBLR pdim(pads);
808
809 auto outTy = outV->getType();
810 auto inTy = inV->getType();
811 auto filterTy = filterV->getType();
812 auto biasTy = biasV->getType();
813
814 int32_t outOffset = outTy->getOffset();
815 int32_t inOffset = inTy->getOffset();
816 int32_t filterOffset = filterTy->getOffset();
817 int32_t biasOffset = biasTy->getOffset();
818
819 float outScale = outTy->getScale();
820 float inScale = inTy->getScale();
821 float filterScale = filterTy->getScale();
822 float biasScale = biasTy->getScale();
823
824 // Calculate the scale of the values that come out of the matrix
825 // multiplication part of the calculation.
826 float matMulScale = inScale * filterScale;
827
828 // For each input in the batch:
829 for (dim_t n = 0; n < idim.n; n++) {
830
831 // For each group of input channels:
832 for (dim_t ig = 0; ig < group; ig++) {
833
834 // For each output channel in the group:
835 for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {
836
837 // For each convolution 'jump' in the input tensor:
838 ssize_t t = -ssize_t(pdim.near);
839 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
840 ssize_t x = -ssize_t(pdim.top);
841 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
842 ssize_t y = -ssize_t(pdim.left);
843 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
844
845 // For each element in the convolution-filter:
846 AccumulatorTy sum = 0;
847 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
848 for (dim_t fx = 0; fx < kdim.height; fx++) {
849 for (dim_t fy = 0; fy < kdim.width; fy++) {
850 ssize_t ot = t + ft;
851 ssize_t ox = x + fx;
852 ssize_t oy = y + fy;
853
854 // Ignore index access below zero (this is due to padding).
855 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
856 ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
857 continue;
858 }
859 for (dim_t fg = 0; fg < inCperG; fg++) {
860
861 AccumulatorTy F = filterW.at({og, ft, fx, fy, fg});
862 AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
863 (dim_t)oy, ig * inCperG + fg});
864 // We represent the element multiplication with offset as
865 // (value - offset).
866 sum += (F - filterOffset) * (I - inOffset);
867 }
868 }
869 }
870 }
871
872 // Scale the bias to match the scale of the matrix multiplication.
873 AccumulatorTy B = std::round(float(biasW.at({og}) - biasOffset) *
874 (biasScale / matMulScale));
875
876 // Add the bias:
877 sum += B;
878
879 // Scale the result back to the expected destination scale.
880 outW.at({n, at, ax, ay, og}) =
881 quantization::clip<AccumulatorTy, ElemTy>(std::round(
882 float(sum) * (matMulScale / outScale) + outOffset));
883 } // D
884 } // W
885 } // H
886 } // C
887 } // G
888 } // N
889}
890
891void BoundInterpreterFunction::fwdConvolution3DInst(
892 const Convolution3DInst *I) {
893 auto kernelSizes = I->getKernels();
894 auto pads = I->getPads();
895 auto strides = I->getStrides();
896 size_t group = I->getGroup();
897
898 if (I->getSrc()->getType()->isQuantizedType()) {
899 dispatchQuantizedWithAccumulationAndBiasImpl(
900 fwdConvolution3DInstQuantizedImpl, I->getSrc()->getElementType(),
901 I->getBias()->getElementType(), I->getSrc(), I->getDest(),
902 I->getFilter(), I->getBias(), kernelSizes, strides, pads, group);
903 return;
904 }
905
906 dispatchFloatingPointImpl(fwdConvolution3DInstFloatImpl,
907 I->getSrc()->getElementType(), I->getSrc(),
908 I->getDest(), I->getFilter(), I->getBias(),
909 kernelSizes, strides, pads, group);
910}
911
912void BoundInterpreterFunction::fwdConvolution3DGradInst(
913 const Convolution3DGradInst *I) {
914 (void)I;
915 // TODO
916 llvm_unreachable("not yet implemented");
917}
918
919//===----------------------------------------------------------------------===//
920// Channelwise quantized Convolution
921//===----------------------------------------------------------------------===//
922template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
923void BoundInterpreterFunction::fwdChannelwiseQuantizedConv2DInstImpl(
924 const ChannelwiseQuantizedConvolutionInst *I) {
925 auto inW = getWeightHandle<ElemTy>(I->getSrc());
926 auto outW = getWeightHandle<ElemTy>(I->getDest());
927 auto filterW = getWeightHandle<ElemTy>(I->getFilter());
928 auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
929 auto filterScales = getWeightHandle<float>(I->getFilterScales());
930 auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
931 auto biasScales = getWeightHandle<float>(I->getBiasScales());
932 auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());
933
934 llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
935 llvm::ArrayRef<unsigned_t> pads = I->getPads();
936 llvm::ArrayRef<unsigned_t> strides = I->getStrides();
937 dim_t group = I->getGroup();
938 llvm::ArrayRef<unsigned_t> dilation = I->getDilation();
939
940 ShapeNHWC odim(outW.dims());
941 ShapeNHWC idim(inW.dims());
942 ShapeHW kdim(kernelSizes);
943 ShapeHW sdim(strides);
944
945 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
946 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
947 dim_t inCperG = idim.c / group;
948 dim_t outCperG = odim.c / group;
949
950 PaddingTLBR pdim(pads);
951
952 auto &inTy = inW.getType();
953 auto &outTy = outW.getType();
954
955 float inScale = inTy.getScale();
956 float outScale = outTy.getScale();
957
958 int32_t inOffset = inTy.getOffset();
959 int32_t outOffset = outTy.getOffset();
960
961 // For each input in the batch:
962 for (dim_t n = 0; n < idim.n; n++) {
963 // For each group of input channels:
964 for (dim_t g = 0; g < group; g++) {
965 // For each output channel in the group:
966 for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
967
        // Get channel-wise quantization params.
969 int32_t filterOffset = filterOffsets.at(d);
970 float filterScale = filterScales.at(d);
971 int32_t biasOffset = biasOffsets.at(d);
972 float biasScale = biasScales.at(d);
973 float matMulScale = inScale * filterScale;
974
975 // For each convolution 'jump' in the input tensor:
976 sdim_t x = -sdim_t(pdim.top);
977 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
978 sdim_t y = -sdim_t(pdim.left);
979 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
980
981 // For each element in the convolution-filter:
982 AccumulatorTy sum = 0;
983 for (dim_t fx = 0; fx < kdim.height; fx++) {
984 for (dim_t fy = 0; fy < kdim.width; fy++) {
985 sdim_t ox = x + fx * dilation[0];
986 sdim_t oy = y + fy * dilation[1];
987
988 // Ignore index access below zero (this is due to padding).
989 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
990 oy >= sdim_t(idim.w)) {
991 continue;
992 }
993
994 // Accumulate along the filter depth.
995 for (dim_t fd = 0; fd < inCperG; fd++) {
996 AccumulatorTy F = filterW.at({d, fx, fy, fd});
997 AccumulatorTy I =
998 inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
999 // We represent the element multiplication with offset as
1000 // (value - offset).
1001 sum += (F - filterOffset) * (I - inOffset);
1002 }
1003 }
1004 }
1005
1006 // Scale the bias to match the scale of the matrix multiplication.
1007 sum += std::round(float(biasW.at({d}) - biasOffset) *
1008 (biasScale / matMulScale));
1009
1010 // Scale the result back to the expected destination scale.
1011 outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
1012 std::round(float(sum) * (matMulScale / outScale) + outOffset));
1013 } // W
1014 } // H
1015 } // C
1016 } // G
1017 } // N
1018}
1019
1020template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
1021void BoundInterpreterFunction::fwdChannelwiseQuantizedConv3DInstImpl(
1022 const ChannelwiseQuantizedConvolutionInst *I) {
1023 auto inW = getWeightHandle<ElemTy>(I->getSrc());
1024 auto outW = getWeightHandle<ElemTy>(I->getDest());
1025 auto filterW = getWeightHandle<ElemTy>(I->getFilter());
1026 auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
1027 auto filterScales = getWeightHandle<float>(I->getFilterScales());
1028 auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
1029 auto biasScales = getWeightHandle<float>(I->getBiasScales());
1030 auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());
1031
1032 llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
1033 llvm::ArrayRef<unsigned_t> pads = I->getPads();
1034 llvm::ArrayRef<unsigned_t> strides = I->getStrides();
1035 dim_t group = I->getGroup();
1036
1037 ShapeNTHWC odim(outW.dims());
1038 ShapeNTHWC idim(inW.dims());
1039 ShapeTHW kdim(kernelSizes);
1040 ShapeTHW sdim(strides);
1041
1042 assert(idim.c % group == 0 && "Input channels must be divisible by group.");
1043 assert(odim.c % group == 0 && "Output channels must be divisible by group.");
1044 dim_t inCperG = idim.c / group;
1045 dim_t outCperG = odim.c / group;
1046
1047 PaddingNFTBLR pdim(pads);
1048
1049 auto &inTy = inW.getType();
1050 auto &outTy = outW.getType();
1051
1052 float inScale = inTy.getScale();
1053 float outScale = outTy.getScale();
1054
1055 int32_t inOffset = inTy.getOffset();
1056 int32_t outOffset = outTy.getOffset();
1057
1058 // For each input in the batch:
1059 for (dim_t n = 0; n < idim.n; n++) {
1060 // For each group of input channels:
1061 for (dim_t g = 0; g < group; g++) {
1062 // For each output channel in the group:
1063 for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
1064
        // Get channel-wise quantization params.
1066 int32_t filterOffset = filterOffsets.at(d);
1067 float filterScale = filterScales.at(d);
1068 int32_t biasOffset = biasOffsets.at(d);
1069 float biasScale = biasScales.at(d);
1070 float matMulScale = inScale * filterScale;
1071
1072 // For each convolution 'jump' in the input tensor:
1073 sdim_t t = -sdim_t(pdim.near);
1074 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1075 sdim_t x = -sdim_t(pdim.top);
1076 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1077 sdim_t y = -sdim_t(pdim.left);
1078 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1079
1080 // For each element in the convolution-filter:
1081 AccumulatorTy sum = 0;
1082 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1083 for (dim_t fx = 0; fx < kdim.height; fx++) {
1084 for (dim_t fy = 0; fy < kdim.width; fy++) {
1085 sdim_t ot = t + ft;
1086 sdim_t ox = x + fx;
1087 sdim_t oy = y + fy;
1088
1089 // Ignore index access below zero (this is due to
1090 // padding).
1091 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1092 ox >= ssize_t(idim.h) || oy >= sdim_t(idim.w)) {
1093 continue;
1094 }
1095
1096 // Accumulate along the filter depth.
1097 for (dim_t fd = 0; fd < inCperG; fd++) {
1098
1099 AccumulatorTy F = filterW.at({d, ft, fx, fy, fd});
1100 AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
1101 (dim_t)oy, g * inCperG + fd});
1102 // We represent the element multiplication with offset
1103 // as (value - offset).
1104 sum += (F - filterOffset) * (I - inOffset);
1105 }
1106 }
1107 }
1108 }
1109
1110 // Scale the bias to match the scale of the matrix multiplication.
1111 sum += std::round(float(biasW.at({d}) - biasOffset) *
1112 (biasScale / matMulScale));
1113
1114 // Scale the result back to the expected destination scale.
1115 outW.at({n, at, ax, ay, d}) =
1116 quantization::clip<AccumulatorTy, ElemTy>(std::round(
1117 float(sum) * (matMulScale / outScale) + outOffset));
1118 } // W
1119 } // H
1120 } // T
1121 } // C
1122 } // G
1123 } // N
1124}
1125
1126void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst(
1127 const ChannelwiseQuantizedConvolutionInst *I) {
1128 bool isConv3D = (I->getSrc()->dims().size() == 5);
1129 if (isConv3D) {
1130 dispatchQuantizedWithAccumulationAndBiasImpl(
1131 fwdChannelwiseQuantizedConv3DInstImpl, I->getSrc()->getElementType(),
1132 I->getBias()->getElementType(), I);
1133 } else {
1134 dispatchQuantizedWithAccumulationAndBiasImpl(
1135 fwdChannelwiseQuantizedConv2DInstImpl, I->getSrc()->getElementType(),
1136 I->getBias()->getElementType(), I);
1137 }
1138}
1139
1140template <typename ElemTy>
1141void BoundInterpreterFunction::fwdBatchNormalizationFloatImpl(
1142 const BatchNormalizationInst *I, int numDims) {
1143 staticAssertFloatingPointType(ElemTy);
1144
1145 // input
1146 auto inH = getWeightHandle<ElemTy>(I->getSrc());
1147 auto scaleH = getWeightHandle<ElemTy>(I->getScale());
1148 auto biasH = getWeightHandle<ElemTy>(I->getBias());
1149 auto meanH = getWeightHandle<ElemTy>(I->getMean());
1150 auto varH = getWeightHandle<ElemTy>(I->getVar());
1151 unsigned_t channelIdx = I->getChannelIdx();
1152 float epsilon = I->getEpsilon();
1153
1154 // output
1155 auto outW = getWeightHandle<ElemTy>(I->getDest());
1156
1157 dim_t N, C, sizeN, sizeImg;
1158 bool isCMinor;
1159 if (numDims == 3) {
1160 if (channelIdx == 4) {
1161 ShapeNTHWC idim(I->getSrc()->dims());
1162 N = idim.n;
1163 C = idim.c;
1164 sizeImg = idim.t * idim.h * idim.w;
1165 sizeN = idim.c * sizeImg;
1166 isCMinor = true;
1167 } else {
1168 ShapeNCTHW idim(I->getSrc()->dims());
1169 N = idim.n;
1170 C = idim.c;
1171 sizeImg = idim.t * idim.h * idim.w;
1172 sizeN = idim.c * sizeImg;
1173 isCMinor = false;
1174 }
1175 } else if (numDims == 2) {
1176 if (channelIdx == 3) {
1177 ShapeNHWC idim(I->getSrc()->dims());
1178 N = idim.n;
1179 C = idim.c;
1180 sizeImg = idim.h * idim.w;
1181 sizeN = idim.c * sizeImg;
1182 isCMinor = true;
1183 } else {
1184 ShapeNCHW idim(I->getSrc()->dims());
1185 N = idim.n;
1186 C = idim.c;
1187 sizeImg = idim.h * idim.w;
1188 sizeN = idim.c * sizeImg;
1189 isCMinor = false;
1190 }
1191 } else if (numDims == 1) {
1192 N = I->getSrc()->dims()[0];
1193 C = I->getSrc()->dims()[channelIdx];
1194 sizeImg = I->getSrc()->dims()[channelIdx == 2 ? 1 : 2];
1195 sizeN = C * sizeImg;
1196 isCMinor = (channelIdx == 2);
1197 } else {
1198 N = I->getSrc()->dims()[0];
1199 C = I->getSrc()->dims()[channelIdx];
1200 sizeImg = 1;
1201 sizeN = C;
1202 isCMinor = false;
1203 }
1204
1205 std::vector<float> scale(C), mean(C), bias(C);
1206 for (dim_t c = 0; c < C; c++) {
1207 scale[c] = float(scaleH.at({c})) / std::sqrt(float(varH.at({c})) + epsilon);
1208 bias[c] = biasH.at({c});
1209 mean[c] = meanH.at({c});
1210 }
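
  // This folds the batchnorm formula
  //   y = scale * (x - mean) / sqrt(var + epsilon) + bias
  // into a single multiply-add per element: scale[c] holds
  // scale / sqrt(var + epsilon) and bias[c] holds the bias, so below we
  // compute y = scale[c] * (x - mean[c]) + bias[c].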
1211
1212 // For each input in the batch:
1213 for (dim_t n = 0; n < N; n++) {
1214 if (isCMinor) {
1215 // For each H*W{*T} of the image
1216 for (dim_t i = 0; i < sizeImg; i++) {
1217 // For each channel
1218 for (dim_t c = 0; c < C; c++) {
1219 int index = n * sizeN + i * C + c;
1220 outW.raw(index) =
1221 ElemTy(scale[c] * (float(inH.raw(index)) - mean[c]) + bias[c]);
1222 } // C
1223 } // image
1224 } else {
1225 // For each channel
1226 for (dim_t c = 0; c < C; c++) {
1227 // For each H*W{*T} of the image
1228 for (dim_t i = 0; i < sizeImg; i++) {
1229 int index = n * sizeN + c * sizeImg + i;
1230 outW.raw(index) =
1231 ElemTy(scale[c] * (float(inH.raw(index)) - mean[c]) + bias[c]);
1232 } // image
1233 } // C
1234 }
1235 } // N
1236}
1237
1238template <typename ParamTy>
1239void BoundInterpreterFunction::fwdBatchNormalizationI8Impl(
1240 const BatchNormalizationInst *I, int numDims) {
1241
1242 // input
1243 auto inH = getWeightHandle<int8_t>(I->getSrc());
1244 auto scaleH = getWeightHandle<ParamTy>(I->getScale());
1245 auto biasH = getWeightHandle<ParamTy>(I->getBias());
1246 auto meanH = getWeightHandle<ParamTy>(I->getMean());
1247 auto varH = getWeightHandle<ParamTy>(I->getVar());
1248 unsigned_t channelIdx =
1249 I->getChannelIdx(); // NOTE: We only support NTHWC, NHWC, NWC and NCW
1250 float epsilon = I->getEpsilon();
1251 auto inScale = float(I->getSrc()->getType()->getScale());
1252 auto inZero = I->getSrc()->getType()->getOffset();
1253
1254 // output
1255 auto outH = getWeightHandle<int8_t>(I->getDest());
1256 auto outScale = float(I->getDest()->getType()->getScale());
1257 auto outZero = I->getDest()->getType()->getOffset();
1258
1259 dim_t N, C, sizeN, sizeImg;
1260 bool isCMinor;
1261 if (numDims == 3) {
1262 if (channelIdx == 4) {
1263 ShapeNTHWC idim(I->getSrc()->dims());
1264 N = idim.n;
1265 C = idim.c;
1266 sizeImg = idim.t * idim.h * idim.w;
1267 sizeN = idim.c * sizeImg;
1268 isCMinor = true;
1269
1270 } else {
1271 ShapeNCTHW idim(I->getSrc()->dims());
1272 N = idim.n;
1273 C = idim.c;
1274 sizeImg = idim.t * idim.h * idim.w;
1275 sizeN = idim.c * sizeImg;
1276 isCMinor = false;
1277 }
1278 } else if (numDims == 2) {
1279 if (channelIdx == 3) {
1280 ShapeNHWC idim(I->getSrc()->dims());
1281 N = idim.n;
1282 C = idim.c;
1283 sizeImg = idim.h * idim.w;
1284 sizeN = idim.c * sizeImg;
1285 isCMinor = true;
1286
1287 } else {
1288 ShapeNCHW idim(I->getSrc()->dims());
1289 N = idim.n;
1290 C = idim.c;
1291 sizeImg = idim.h * idim.w;
1292 sizeN = idim.c * sizeImg;
1293 isCMinor = false;
1294 }
1295
1296 } else {
    // numDims == 1. This can happen due to an optimization pass that sinks
    // reshape below batchnorm.
1299 N = I->getSrc()->dims()[0];
1300 C = I->getSrc()->dims()[channelIdx];
1301 sizeImg = I->getSrc()->dims()[channelIdx == 2 ? 1 : 2];
1302 sizeN = C * sizeImg;
1303 isCMinor = (channelIdx == 2);
1304 }
1305
1306 // See qbatch_norm.cpp/compute_fused_params() for FBGEMM implementation
1307 std::vector<ParamTy> alpha(C), beta(C);
1308 for (dim_t c = 0; c < C; c++) {
1309 float invSigma = 1 / std::sqrt(float(varH.at({c})) + epsilon);
1310 alpha[c] = ParamTy(invSigma * float(scaleH.at({c})) * (inScale / outScale));
1311 beta[c] = ParamTy((float(biasH.at({c})) - float(meanH.at({c})) * invSigma *
1312 float(scaleH.at({c}))) /
1313 outScale);
1314 }
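
  // With these folded parameters the per-element computation below reduces to
  //   q_out = clip(round(alpha[c] * (q_in - inZero) + beta[c]) + outZero),
  // which is equivalent to dequantizing the input, applying batchnorm, and
  // requantizing to the output scale/offset.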
1315
1316 // See QuantizedOpKernels.cpp/q_batch_norm_kernel() for FBGEMM implementation
1317 TensorQuantizationParams outputQ{1.0f, outZero};
1318 // For each input in the batch:
1319 for (dim_t n = 0; n < N; n++) {
1320 if (isCMinor) {
1321 // For each H*W{*T} of the image
1322 for (dim_t i = 0; i < sizeImg; i++) {
1323 // For each channel
1324 for (dim_t c = 0; c < C; c++) {
1325 int index = n * sizeN + i * C + c;
1326 ParamTy x = inH.raw(index) - inZero;
1327 ParamTy y = alpha[c] * x + beta[c];
1328 outH.raw(index) = quantization::quantize(y, outputQ);
        } // C
      } // image
1331 } else {
1332 // For each channel
1333 for (dim_t c = 0; c < C; c++) {
1334 // For each H*W{*T} of the image
1335 for (dim_t i = 0; i < sizeImg; i++) {
1336 int index = n * sizeN + c * sizeImg + i;
1337 auto x = ParamTy(inH.raw(index) - inZero);
1338 ParamTy y = alpha[c] * x + beta[c];
1339 outH.raw(index) = quantization::quantize(y, outputQ);
1340 } // image
1341 } // C
1342 }
1343 } // N
1344}
1345
1346void BoundInterpreterFunction::fwdBatchNormalizationInst(
1347 const BatchNormalizationInst *I) {
1348 int numDims = I->getSrc()->dims().size() - 2;
1349 bool isQuantized = I->getSrc()->getType()->isQuantizedType();
1350
1351 if (isQuantized) {
1352 if (I->getScale()->getType()->getElementType() == ElemKind::FloatTy) {
1353 fwdBatchNormalizationI8Impl<float>(I, numDims);
1354 } else {
1355 fwdBatchNormalizationI8Impl<float16_t>(I, numDims);
1356 }
1357 } else {
1358 dispatchFloatingPointImpl(fwdBatchNormalizationFloatImpl,
1359 I->getSrc()->getElementType(), I, numDims);
1360 }
1361}
1362
1363//===----------------------------------------------------------------------===//
1364// LayerNormalization
1365//===----------------------------------------------------------------------===//
1366
1367template <typename ElemTy>
1368void BoundInterpreterFunction::fwdLayerNormalizationInstFloatImpl(
1369 const glow::LayerNormalizationInst *I) {
1370 staticAssertFloatingPointType(ElemTy);
1371
1372 // input
1373 auto inH = getWeightHandle<ElemTy>(I->getSrc());
1374 auto scaleH = getWeightHandle<ElemTy>(I->getScale());
1375 auto biasH = getWeightHandle<ElemTy>(I->getBias());
1376 float epsilon = I->getEpsilon();
1377
1378 // output
1379 auto outW = getWeightHandle<ElemTy>(I->getDest());
1380
1381 auto N = I->getSrc()->dims()[0];
1382 auto K = I->getSrc()->dims()[1];
1383
1384 std::vector<float> val(K);
1385 for (dim_t n = 0; n < N; n++) {
1386 // 1. mean = x.mean(dim=-1, keepdim=True)
1387 float sum = 0.0f;
1388 for (dim_t k = 0; k < K; k++) {
1389 val[k] = inH.at({n, k});
1390 sum += val[k];
1391 }
1392 float mean = sum / K;
1393
1394 // 2. var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
1395 float diff_sqr_sum = 0.0f;
1396 for (dim_t k = 0; k < K; k++) {
1397 float diff = val[k] - mean;
1398 diff_sqr_sum += diff * diff;
1399 }
1400 float var = diff_sqr_sum / K;
1401
1402 // 3. std = (var + epsilon).sqrt()
1403 float std = std::sqrt(var + epsilon);
1404
1405 for (dim_t k = 0; k < K; k++) {
1406 // 4. y = ((x - mean) / std) * scale + bias
1407 float scale = scaleH.at({k});
1408 float bias = biasH.at({k});
1409 outW.at({n, k}) = ElemTy((((val[k] - mean) / std) * scale) + bias);
1410 }
1411 }
1412}
1413
1414void BoundInterpreterFunction::fwdLayerNormalizationInst(
1415 const LayerNormalizationInst *I) {
1416 dispatchFloatingPointImpl(fwdLayerNormalizationInstFloatImpl,
1417 I->getSrc()->getElementType(), I);
1418}
1419
1420//===----------------------------------------------------------------------===//
1421// Pooling
1422//===----------------------------------------------------------------------===//
1423template <class T>
1424static void fwdMaxPool(Tensor *inW, Tensor *outW, Tensor *argmaxW,
1425 llvm::ArrayRef<unsigned_t> kernelSizes,
1426 llvm::ArrayRef<unsigned_t> strides,
1427 llvm::ArrayRef<unsigned_t> pads) {
1428 ShapeNHWC odim(outW->dims());
1429 ShapeNHWC idim(inW->dims());
1430 Handle<T> inHandle = inW->getHandle<T>();
1431 Handle<T> outHandle = outW->getHandle<T>();
1432 PaddingTLBR pdim(pads);
1433 ShapeHW kdim(kernelSizes);
1434 ShapeHW sdim(strides);
1435
1436 llvm::Optional<Handle<int64_t>> argmaxH;
1437 if (argmaxW) {
1438 argmaxH = argmaxW->getHandle<int64_t>();
1439 }
1440 // For each input in the batch:
1441 for (dim_t n = 0; n < odim.n; n++) {
1442
1443 // For each layer in the output tensor:
1444 for (dim_t z = 0; z < idim.c; z++) {
1445 // For each convolution 'jump' in the input tensor:
1446 sdim_t x = -sdim_t(pdim.top);
1447 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1448 sdim_t y = -sdim_t(pdim.left);
1449 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1450
          // When the MaxPool window contains only padding pixels, by
          // convention we return 0 for that window (the offset for quantized
          // types, which represents a real value of 0).
1453 bool first = true;
1454 T max_value = outW->getType().isQuantizedType()
1455 ? static_cast<T>(outW->getType().getOffset())
1456 : static_cast<T>(0);
1457 dim_t argmaxNHWC = 0;
1458
1459 for (dim_t fx = 0; fx < kdim.height; fx++) {
1460 for (dim_t fy = 0; fy < kdim.width; fy++) {
1461 sdim_t ox = x + fx;
1462 sdim_t oy = y + fy;
1463
1464 // Ignore index access below zero (this is due to padding).
1465 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1466 oy >= ssize_t(idim.w)) {
1467 continue;
1468 }
1469
1470 T val = inHandle.at({n, (dim_t)ox, (dim_t)oy, z});
1471 if (first || (val >= max_value)) {
1472 first = false;
1473 max_value = val;
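                // Record the argmax as a flat (raw) index into the NHWC
                // input tensor, computed as the element's offset from the
                // start of the input buffer.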
1474 if (argmaxW) {
1475 argmaxNHWC = &inHandle.at({n, (dim_t)ox, (dim_t)oy, z}) -
1476 &inHandle.raw(0);
1477 }
1478 }
1479 }
1480 }
1481
1482 outHandle.at({n, ax, ay, z}) = max_value;
1483
1484 if (argmaxW) {
1485 (*argmaxH).at({n, ax, ay, z}) = argmaxNHWC;
1486 }
1487 } // W
1488 } // H
1489 } // C
1490 } // N
1491}
1492
1493void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) {
1494 auto inW = getTensor(I->getSrc());
1495 auto outW = getTensor(I->getDest());
1496
1497 if (inW->getType().isQuantizedType()) {
1498 dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1499 outW, nullptr, I->getKernels(), I->getStrides(),
1500 I->getPads());
1501 return;
1502 }
1503
1504 dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1505 outW, nullptr, I->getKernels(), I->getStrides(),
1506 I->getPads());
1507}
1508
1509void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst(
1510 const MaxPoolWithArgmaxInst *I) {
1511 auto inW = getTensor(I->getSrc());
1512 auto outW = getTensor(I->getDest());
1513 auto argmaxW = getTensor(I->getArgmax());
1514
1515 if (inW->getType().isQuantizedType()) {
1516 dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1517 outW, argmaxW, I->getKernels(), I->getStrides(),
1518 I->getPads());
1519 return;
1520 }
1521 dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1522 outW, argmaxW, I->getKernels(), I->getStrides(),
1523 I->getPads());
1524}
1525
1526template <typename ElemTy>
1527void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) {
1528 staticAssertFloatingPointType(ElemTy);
1529
1530 ShapeNHWC odim(I->getDest()->dims());
1531 ShapeNHWC idim(I->getSrc()->dims());
1532
1533 PaddingTLBR pdim(I->getPads());
1534 ShapeHW kdim(I->getKernels());
1535 ShapeHW sdim(I->getStrides());
1536 // Implement the avg pooling operation as defined here:
1537 // https://arxiv.org/abs/1312.4400
1538 float rawFilterArea = kdim.height * kdim.width;
1539
1540 auto inW = getWeightHandle<ElemTy>(I->getSrc());
1541 auto outW = getWeightHandle<ElemTy>(I->getDest());
1542
1543 // For each input in the batch:
1544 for (dim_t n = 0; n < odim.n; n++) {
1545 // For each layer in the output tensor:
1546 for (dim_t z = 0; z < idim.c; z++) {
1547 // For each convolution 'jump' in the input tensor:
1548 ssize_t x = -ssize_t(pdim.top);
1549 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1550 ssize_t y = -ssize_t(pdim.left);
1551 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1552 float sum = 0;
1553 float filterArea = rawFilterArea;
1554
1555 for (dim_t fx = 0; fx < kdim.height; fx++) {
1556 for (dim_t fy = 0; fy < kdim.width; fy++) {
1557 sdim_t ox = x + fx;
1558 sdim_t oy = y + fy;
1559
1560 // Ignore index access below zero (this is due to padding).
1561 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1562 oy >= ssize_t(idim.w)) {
1563 if (!I->getCountIncludePads()) {
1564 filterArea--;
1565 }
1566
1567 continue;
1568 }
1569
1570 sum += float(inW.at({n, (dim_t)ox, (dim_t)oy, z}));
1571 }
1572 }
1573 if (filterArea == 0) {
1574 outW.at({n, ax, ay, z}) = ElemTy(0);
1575 } else {
1576 outW.at({n, ax, ay, z}) = ElemTy(sum / filterArea);
1577 }
1578 } // W
1579 } // H
1580 } // C
1581 } // N
1582}
1583
1584void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) {
1585 ShapeNHWC odim(I->getDest()->dims());
1586 ShapeNHWC idim(I->getSrc()->dims());
1587
1588 PaddingTLBR pdim(I->getPads());
1589 ShapeHW kdim(I->getKernels());
1590 ShapeHW sdim(I->getStrides());
1591 // Implement the avg pooling operation as defined here:
1592 // https://arxiv.org/abs/1312.4400
1593 float rawFilterArea = kdim.height * kdim.width;
1594
1595 auto inW = getWeightHandle<int8_t>(I->getSrc());
1596 auto outW = getWeightHandle<int8_t>(I->getDest());
1597 TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1598 I->getSrc()->getType()->getOffset()};
1599 TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1600 I->getDest()->getType()->getOffset()};
1601
1602 // For each input in the batch:
1603 for (dim_t n = 0; n < odim.n; n++) {
1604 // For each layer in the output tensor:
1605 for (dim_t z = 0; z < idim.c; z++) {
1606 // For each convolution 'jump' in the input tensor:
1607 ssize_t x = -ssize_t(pdim.top);
1608 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1609 ssize_t y = -ssize_t(pdim.left);
1610 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1611 int32_t sum = 0;
1612 float filterArea = rawFilterArea;
1613
1614 for (dim_t fx = 0; fx < kdim.height; fx++) {
1615 for (dim_t fy = 0; fy < kdim.width; fy++) {
1616 sdim_t ox = x + fx;
1617 sdim_t oy = y + fy;
1618
1619 // Ignore index access below zero (this is due to padding).
1620 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1621 oy >= ssize_t(idim.w)) {
1622 if (!I->getCountIncludePads()) {
1623 filterArea--;
1624 }
1625
1626 continue;
1627 }
1628
1629 sum += inW.at({n, (dim_t)ox, (dim_t)oy, z}) - inQP.offset;
1630 }
1631 }
1632 if (filterArea == 0) {
1633 outW.at({n, ax, ay, z}) =
1634 quantization::clip<int32_t, int8_t>(outQP.offset);
1635 } else {
1636 // Instead of dividing by filterArea, just change scale.
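            // (Folding the division into the requantization scale:
            // out = clip(round(sum * inScale / (outScale * filterArea) +
            // outOffset)), which equals averaging and then requantizing.)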
1637 outW.at({n, ax, ay, z}) =
1638 quantization::clip<int32_t, int8_t>(std::round(
1639 float(sum) * (inQP.scale / outQP.scale / filterArea) +
1640 outQP.offset));
1641 }
1642 } // W
1643 } // H
1644 } // C
1645 } // N
1646}
1647
1648template <typename ElemTy>
1649void BoundInterpreterFunction::fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I) {
1650 staticAssertFloatingPointType(ElemTy);
1651
1652 ShapeNTHWC odim(I->getDest()->dims());
1653 ShapeNTHWC idim(I->getSrc()->dims());
1654
1655 PaddingNFTBLR pdim(I->getPads());
1656 ShapeTHW kdim(I->getKernels());
1657 ShapeTHW sdim(I->getStrides());
1658 // Implement the avg pooling operation as defined here:
1659 // https://arxiv.org/abs/1312.4400
1660 float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width;
1661
1662 auto inW = getWeightHandle<ElemTy>(I->getSrc());
1663 auto outW = getWeightHandle<ElemTy>(I->getDest());
1664
1665 // For each input in the batch:
1666 for (dim_t n = 0; n < odim.n; n++) {
1667 // For each layer in the output tensor:
1668 for (dim_t z = 0; z < idim.c; z++) {
1669 // For each convolution 'jump' in the input tensor:
1670 ssize_t t = -ssize_t(pdim.near);
1671 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1672 ssize_t x = -ssize_t(pdim.top);
1673 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1674 ssize_t y = -ssize_t(pdim.left);
1675 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1676 float sum = 0;
1677 float filterArea = rawFilterArea;
1678
1679 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1680 for (dim_t fx = 0; fx < kdim.height; fx++) {
1681 for (dim_t fy = 0; fy < kdim.width; fy++) {
1682 sdim_t ot = t + ft;
1683 sdim_t ox = x + fx;
1684 sdim_t oy = y + fy;
1685
1686 // Ignore index access below zero (this is due to padding).
1687 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1688 ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
1689 if (!I->getCountIncludePads()) {
1690 filterArea--;
1691 }
1692
1693 continue;
1694 }
1695
1696 sum += float(inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}));
1697 }
1698 }
1699 }
1700 assert(filterArea != 0 && "filterArea can't be 0");
1701 outW.at({n, at, ax, ay, z}) = ElemTy(sum / filterArea);
1702 } // W
1703 } // H
1704 } // T
1705 } // C
1706 } // N
1707}
1708
1709void BoundInterpreterFunction::fwdAvgPool3DInstI8Impl(const AvgPoolInst *I) {
1710 ShapeNTHWC odim(I->getDest()->dims());
1711 ShapeNTHWC idim(I->getSrc()->dims());
1712
1713 PaddingNFTBLR pdim(I->getPads());
1714 ShapeTHW kdim(I->getKernels());
1715 ShapeTHW sdim(I->getStrides());
1716 // Implement the avg pooling operation as defined here:
1717 // https://arxiv.org/abs/1312.4400
1718 float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width;
1719
1720 auto inW = getWeightHandle<int8_t>(I->getSrc());
1721 auto outW = getWeightHandle<int8_t>(I->getDest());
1722 TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1723 I->getSrc()->getType()->getOffset()};
1724 TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1725 I->getDest()->getType()->getOffset()};
1726
1727 // For each input in the batch:
1728 for (dim_t n = 0; n < odim.n; n++) {
1729 // For each layer in the output tensor:
1730 for (dim_t z = 0; z < idim.c; z++) {
1731 // For each convolution 'jump' in the input tensor:
1732 ssize_t t = -ssize_t(pdim.near);
1733 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1734 ssize_t x = -ssize_t(pdim.top);
1735 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1736 ssize_t y = -ssize_t(pdim.left);
1737 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1738 int32_t sum = 0;
1739 float filterArea = rawFilterArea;
1740
1741 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1742 for (dim_t fx = 0; fx < kdim.height; fx++) {
1743 for (dim_t fy = 0; fy < kdim.width; fy++) {
1744 sdim_t ot = t + ft;
1745 sdim_t ox = x + fx;
1746 sdim_t oy = y + fy;
1747
                  // Ignore accesses outside the input tensor (due to padding).
1749 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1750 ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
1751 if (!I->getCountIncludePads()) {
1752 filterArea--;
1753 }
1754
1755 continue;
1756 }
1757
1758 sum += inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) -
1759 inQP.offset;
1760 }
1761 }
1762 }
1763 // Instead of dividing by filterArea, just change scale.
1764 assert(filterArea != 0 && "filterArea can't be 0");
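            // quantization::quantize(x, {scale, offset}) computes
            // clip(round(x / scale) + offset), so passing 1 / multiplier as
            // the scale folds the division by filterArea and the
            // requantization into one step: clip(round(sum * multiplier) + o_out).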
1765
1766 float multiplier = inQP.scale / outQP.scale / filterArea;
1767 TensorQuantizationParams outputQ{1.0f / multiplier, outQP.offset};
1768 outW.at({n, at, ax, ay, z}) =
1769 quantization::quantize(float(sum), outputQ);
1770 } // W
1771 } // H
1772 } // T
1773 } // C
1774 } // N
1775}
1776
1777void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) {
1778 bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));
1779 bool isQuantized = I->getSrc()->getType()->isQuantizedType();
1780
1781 if (isConv3D) {
1782 if (isQuantized) {
1783 fwdAvgPool3DInstI8Impl(I);
1784 } else {
1785 dispatchFloatingPointImpl(fwdAvgPool3DInstFloatImpl,
1786 I->getSrc()->getElementType(), I);
1787 }
1788 } else {
1789 if (isQuantized) {
1790 fwdAvgPoolInstI8Impl(I);
1791 } else {
1792 dispatchFloatingPointImpl(fwdAvgPoolInstFloatImpl,
1793 I->getSrc()->getElementType(), I);
1794 }
1795 }
1796}
1797
1798template <typename ElemTy>
1799void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstFloatImpl(
1800 const AdaptiveAvgPoolInst *I) {
1801 staticAssertFloatingPointType(ElemTy);
1802
1803 ShapeNHWC odim(I->getDest()->dims());
1804 ShapeNHWC idim(I->getSrc()->dims());
1805
1806 auto inW = getWeightHandle<ElemTy>(I->getSrc());
1807 auto outW = getWeightHandle<ElemTy>(I->getDest());
1808
1809// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
1810#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1811#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
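// For example, with an input height of 5 and an output height of 3 the pooling
// windows along H are [0, 2), [1, 4) and [3, 5): START_IND(ax, 3, 5) evaluates
// to floor(ax * 5 / 3) and END_IND(ax, 3, 5) to ceil((ax + 1) * 5 / 3).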
1812
1813 // For each input in the batch:
1814 for (dim_t n = 0; n < odim.n; n++) {
1815 // For each layer in the output tensor:
1816 for (dim_t z = 0; z < idim.c; z++) {
1817 // For each value in the output tensor:
1818 for (dim_t ax = 0; ax < odim.h; ax++) {
1819
1820 dim_t x = START_IND(ax, odim.h, idim.h);
1821 dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1822
1823 for (dim_t ay = 0; ay < odim.w; ay++) {
1824
1825 dim_t y = START_IND(ay, odim.w, idim.w);
1826 dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1827
1828 float sum = 0;
1829 for (dim_t fx = 0; fx < kH; fx++) {
1830 for (dim_t fy = 0; fy < kW; fy++) {
1831 dim_t ox = x + fx;
1832 dim_t oy = y + fy;
1833
1834 sum += float(inW.at({n, ox, oy, z}));
1835 }
1836 }
1837 outW.at({n, ax, ay, z}) = ElemTy(sum / kW / kH);
1838 } // W
1839 } // H
1840 } // C
1841 } // N
1842#undef START_IND
1843#undef END_IND
1844}
1845
1846void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstI8Impl(
1847 const AdaptiveAvgPoolInst *I) {
1848 ShapeNHWC odim(I->getDest()->dims());
1849 ShapeNHWC idim(I->getSrc()->dims());
1850
1851 auto inW = getWeightHandle<int8_t>(I->getSrc());
1852 auto outW = getWeightHandle<int8_t>(I->getDest());
1853
1854 TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1855 I->getSrc()->getType()->getOffset()};
1856 TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1857 I->getDest()->getType()->getOffset()};
1858
1859// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
1860#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1861#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
1862
1863 // For each input in the batch:
1864 for (dim_t n = 0; n < odim.n; n++) {
1865 // For each layer in the output tensor:
1866 for (dim_t z = 0; z < idim.c; z++) {
1867 // For each value in the output tensor:
1868 for (dim_t ax = 0; ax < odim.h; ax++) {
1869
1870 dim_t x = START_IND(ax, odim.h, idim.h);
1871 dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1872
1873 for (dim_t ay = 0; ay < odim.w; ay++) {
1874
1875 dim_t y = START_IND(ay, odim.w, idim.w);
1876 dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1877
1878 int32_t sum = 0;
1879 for (dim_t fx = 0; fx < kH; fx++) {
1880 for (dim_t fy = 0; fy < kW; fy++) {
1881 dim_t ox = x + fx;
1882 dim_t oy = y + fy;
1883
1884 sum += inW.at({n, ox, oy, z}) - inQP.offset;
1885 }
1886 }
1887
1888 outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>(
1889 std::round(float(sum) * (inQP.scale / outQP.scale / kW / kH) +
1890 outQP.offset));
1891 } // W
1892 } // H
1893 } // C
1894 } // N
1895#undef START_IND
1896#undef END_IND
1897}
1898
1899void BoundInterpreterFunction::fwdAdaptiveAvgPoolInst(
1900 const AdaptiveAvgPoolInst *I) {
1901 if (I->getSrc()->getType()->isQuantizedType()) {
1902 fwdAdaptiveAvgPoolInstI8Impl(I);
1903 return;
1904 }
1905
1906 dispatchFloatingPointImpl(fwdAdaptiveAvgPoolInstFloatImpl,
1907 I->getSrc()->getElementType(), I);
1908}
1909
1910void BoundInterpreterFunction::fwdAdaptiveAvgPoolGradInst(
1911 const AdaptiveAvgPoolGradInst *I) {
1912 auto inG = getWeightHandle(I->getSrcGrad());
1913 auto outW = getWeightHandle(I->getDest());
1914 auto outG = getWeightHandle(I->getDestGrad());
1915
1916 inG.clear();
1917
1918 ShapeNHWC odim(outW.dims());
1919 ShapeNHWC idim(inG.dims());
1920
1921 const float gradCoefficient = 1. / (odim.h * odim.w);
1922
1923#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1924#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
1925
1926 // https://software.intel.com/en-us/daal-programming-guide-2d-average-pooling-backward-layer
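  // Each output gradient is pre-scaled by gradCoefficient and then accumulated
  // into every input location of its pooling window; the window bounds are
  // computed with the same START_IND/END_IND rule as the forward pass.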
1927 // For each input in the batch:
1928 for (dim_t n = 0; n < odim.n; n++) {
1929 // For each layer in the output tensor:
1930 for (dim_t z = 0; z < idim.c; z++) {
1931 // For each value in the output tensor:
1932 for (dim_t ax = 0; ax < odim.h; ax++) {
1933
1934 dim_t x = START_IND(ax, odim.h, idim.h);
1935 dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1936
1937 for (dim_t ay = 0; ay < odim.w; ay++) {
1938
1939 dim_t y = START_IND(ay, odim.w, idim.w);
1940 dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1941
1942 const float chainGrad = outG.at({n, ax, ay, z}) * gradCoefficient;
1943
1944 for (dim_t fx = 0; fx < kH; fx++) {
1945 for (dim_t fy = 0; fy < kW; fy++) {
1946 dim_t ox = x + fx;
1947 dim_t oy = y + fy;
1948
1949 inG.at({n, ox, oy, z}) += chainGrad;
1950 }
1951 }
1952 } // W
1953 } // H
1954 } // C
1955 } // N
1956#undef START_IND
1957#undef END_IND
1958}
1959
1960void BoundInterpreterFunction::fwdMaxPoolWithArgmaxGradInst(
1961 const MaxPoolWithArgmaxGradInst *I) {
1962 auto inG = getWeightHandle(I->getSrcGrad());
1963 auto outW = getWeightHandle(I->getDest());
1964 auto outG = getWeightHandle(I->getDestGrad());
1965
1966 inG.clear();
1967
1968 ShapeNHWC idim(inG.dims());
1969 ShapeNHWC odim(outW.dims());
1970
1971 auto argmax = getWeightHandle<int64_t>(I->getArgmax());
1972
1973 // For each input in the batch:
1974 for (dim_t n = 0; n < odim.n; n++) {
1975
1976 // Compute the gradient. For each layer in the output tensor:
1977 for (dim_t z = 0; z < odim.c; z++) {
1978
1979 // For each convolution 'jump' in the input tensor:
1980 for (dim_t ax = 0; ax < odim.h; ax++) {
1981 for (dim_t ay = 0; ay < odim.w; ay++) {
1982 // Reuse precomputed linear index of max element from argmax.
1983 float chainGrad = outG.at({n, ax, ay, z});
1984 inG.raw(argmax.at({n, ax, ay, z})) += chainGrad;
1985 } // W
1986 } // H
1987 } // C
1988 } // N
1989}
1990
1991void BoundInterpreterFunction::fwdAvgPool2DGradInst(const AvgPoolGradInst *I) {
1992 auto inG = getWeightHandle(I->getSrcGrad());
1993 auto outW = getWeightHandle(I->getDest());
1994 auto outG = getWeightHandle(I->getDestGrad());
1995
1996 ShapeNHWC odim(outW.dims());
1997 ShapeNHWC idim(inG.dims());
1998
1999 PaddingTLBR pdim(I->getPads());
2000 ShapeHW kdim(I->getKernels());
2001 ShapeHW sdim(I->getStrides());
2002
2003 inG.clear();
2004
2005 float rawFilterArea = kdim.height * kdim.width;
2006
2007 // For each input in the batch:
2008 for (dim_t n = 0; n < odim.n; n++) {
2009
2010 // For each layer in the output tensor:
2011 for (dim_t z = 0; z < odim.c; z++) {
2012 // For each convolution 'jump' in the input tensor:
2013 ssize_t x = -ssize_t(pdim.top);
2014 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
2015 ssize_t y = -ssize_t(pdim.left);
2016 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
2017 float filterArea = rawFilterArea;
2018
          // Exclude the padding area from filterArea when countIncludePads is
          // false.
2020 if (!I->getCountIncludePads()) {
2021 ssize_t pad_x = (-x > 0 ? -x : 0) +
2022 ((x + ssize_t(kdim.height) - ssize_t(idim.h)) > 0
2023 ? (x + ssize_t(kdim.height) - ssize_t(idim.h))
2024 : 0);
2025 ssize_t pad_y = (-y > 0 ? -y : 0) +
2026 ((y + ssize_t(kdim.width) - ssize_t(idim.w)) > 0
2027 ? (y + ssize_t(kdim.width) - ssize_t(idim.w))
2028 : 0);
2029 filterArea = rawFilterArea - pad_x * kdim.width -
2030 pad_y * kdim.height + pad_x * pad_y;
2031 }
2032 assert(filterArea != 0 && "filterArea can't be 0");
2033 float dy = outG.at({n, ax, ay, z}) / filterArea;
2034
2035 for (dim_t fx = 0; fx < kdim.height; fx++) {
2036 for (dim_t fy = 0; fy < kdim.width; fy++) {
2037 ssize_t ox = x + fx;
2038 ssize_t oy = y + fy;
2039
              // Ignore accesses outside the input tensor (due to padding).
2041 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
2042 oy >= ssize_t(idim.w)) {
2043 continue;
2044 }
2045 inG.at({n, (dim_t)ox, (dim_t)oy, z}) += dy;
2046 }
2047 }
2048 } // W
2049 } // H
2050 } // C
2051 } // N
2052}
2053
2054void BoundInterpreterFunction::fwdAvgPool3DGradInst(const AvgPoolGradInst *I) {
2055 auto inG = getWeightHandle(I->getSrcGrad());
2056 auto outW = getWeightHandle(I->getDest());
2057 auto outG = getWeightHandle(I->getDestGrad());
2058
2059 ShapeNTHWC odim(outW.dims());
2060 ShapeNTHWC idim(inG.dims());
2061
2062 PaddingNFTBLR pdim(I->getPads());
2063 ShapeTHW kdim(I->getKernels());
2064 ShapeTHW sdim(I->getStrides());
2065
2066 inG.clear();
2067
2068 float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width;
2069
2070 // For each input in the batch:
2071 for (dim_t n = 0; n < odim.n; n++) {
2072
2073 // For each layer in the output tensor:
2074 for (dim_t z = 0; z < odim.c; z++) {
2075 // For each convolution 'jump' in the input tensor:
2076 ssize_t t = -ssize_t(pdim.near);
2077 for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
2078 ssize_t x = -ssize_t(pdim.top);
2079 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
2080 ssize_t y = -ssize_t(pdim.left);
2081 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
2082 float filterArea = rawFilterArea;
2083
            // Exclude the padding area from filterArea when countIncludePads
            // is false.
2085 if (!I->getCountIncludePads()) {
2086 ssize_t pad_x =
2087 (-x > 0 ? -x : 0) +
2088 ((x + ssize_t(kdim.height) - ssize_t(idim.h)) > 0
2089 ? (x + ssize_t(kdim.height) - ssize_t(idim.h))
2090 : 0);
2091 ssize_t pad_y = (-y > 0 ? -y : 0) +
2092 ((y + ssize_t(kdim.width) - ssize_t(idim.w)) > 0
2093 ? (y + ssize_t(kdim.width) - ssize_t(idim.w))
2094 : 0);
2095 ssize_t pad_z =
2096 (-t > 0 ? -t : 0) +
2097 ((t + ssize_t(kdim.temporal_frames) - ssize_t(idim.t)) > 0
2098 ? (t + ssize_t(kdim.temporal_frames) - ssize_t(idim.t))
2099 : 0);
2100 filterArea = rawFilterArea -
2101 pad_x * kdim.width * kdim.temporal_frames -
2102 pad_y * kdim.height * kdim.temporal_frames -
2103 pad_z * kdim.height * kdim.width +
2104 pad_x * pad_y * kdim.temporal_frames +
2105 pad_x * pad_z * kdim.width +
2106 pad_y * pad_z * kdim.height - pad_x * pad_y * pad_z;
2107 }
2108 assert(filterArea != 0 && "filterArea can't be 0");
2109 float dy = outG.at({n, at, ax, ay, z}) / filterArea;
2110
2111 for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
2112 for (dim_t fx = 0; fx < kdim.height; fx++) {
2113 for (dim_t fy = 0; fy < kdim.width; fy++) {
2114 ssize_t ot = t + ft;
2115 ssize_t ox = x + fx;
2116 ssize_t oy = y + fy;
2117
                  // Ignore accesses outside the input tensor (due to padding).
2119 if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
2120 ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
2121 continue;
2122 }
2123 inG.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) += dy;
2124 }
2125 }
2126 }
2127 } // W
2128 } // H
2129 } // T
2130 } // C
2131 } // N
2132}
2133
2134void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) {
2135 bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));
2136
2137 if (isConv3D) {
2138 fwdAvgPool3DGradInst(I);
2139 } else {
2140 fwdAvgPool2DGradInst(I);
2141 }
2142}
2143
2144//===----------------------------------------------------------------------===//
2145// Activation functions
2146//===----------------------------------------------------------------------===//
2147
2148void BoundInterpreterFunction::fwdReluInst(const ReluInst *) {
2149 DCHECK(!"Found ReluInst but Relu is lowered on Interpreter");
2150}
2151
2152void BoundInterpreterFunction::fwdClipInst(const ClipInst *) {
2153 DCHECK(!"Found ClipInst but Clip is lowered on Interpreter");
2154}
2155
2156void BoundInterpreterFunction::fwdLeakyReluInst(const LeakyReluInst *) {
2157 DCHECK(!"Found LeakyReluInst but LeakyRelu is lowered on Interpreter");
2158}
2159
2160template <typename ElemTy>
2161void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) {
2162 staticAssertFloatingPointType(ElemTy);
2163
2164 auto inW = getWeightHandle<ElemTy>(I->getSrc());
2165 auto outW = getWeightHandle<ElemTy>(I->getDest());
2166
2167 for (dim_t i = 0, e = outW.size(); i < e; i++) {
2168 float val = inW.raw(i);
2169 outW.raw(i) = ElemTy(1 / (1 + std::exp(-val)));
2170 }
2171}
2172
2173void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) {
2174 dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl,
2175 I->getSrc()->getElementType(), I);
2176}
2177
2178template <typename ElemTy>
2179void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) {
2180 staticAssertFloatingPointType(ElemTy);
2181
2182 auto inW = getWeightHandle<ElemTy>(I->getSrc());
2183 auto outW = getWeightHandle<ElemTy>(I->getDest());
2184
2185 for (dim_t i = 0, e = inW.size(); i < e; i++) {
2186 float val = inW.raw(i);
2187 outW.raw(i) = ElemTy(std::tanh(val));
2188 }
2189}
2190
2191void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) {
2192 dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(),
2193 I);
2194}
2195
2196template <typename ElemTy>
2197void BoundInterpreterFunction::fwdSoftPlusInstFloatImpl(const SoftPlusInst *I) {
2198 staticAssertFloatingPointType(ElemTy);
2199
2200 auto inW = getWeightHandle<ElemTy>(I->getSrc());
2201 auto outW = getWeightHandle<ElemTy>(I->getDest());
2202
2203 for (dim_t i = 0, e = outW.size(); i < e; i++) {
2204 float val = inW.raw(i);
2205 outW.raw(i) = ElemTy(std::log(1 + std::exp(val)));
2206 }
2207}
2208
2209void BoundInterpreterFunction::fwdSoftPlusInst(const SoftPlusInst *I) {
2210 dispatchFloatingPointImpl(fwdSoftPlusInstFloatImpl,
2211 I->getSrc()->getElementType(), I);
2212}
2213
2214//===----------------------------------------------------------------------===//
2215// Loss Functions (Softmax/regression/...)
2216//===----------------------------------------------------------------------===//
2217
2218template <typename ElemTy>
2219void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) {
2220 staticAssertFloatingPointType(ElemTy);
2221
2222 auto inW = getWeightHandle<ElemTy>(I->getSrc());
2223 auto outW = getWeightHandle<ElemTy>(I->getDest());
2224 auto idim = inW.dims();
2225
2226 for (dim_t n = 0; n < idim[0]; n++) {
    // Find the max value; it is subtracted before exponentiation below for
    // numerical stability.
2228 float max = float(inW.at({n, 0}));
2229 for (dim_t i = 1; i < idim[1]; i++) {
2230 max = std::max(max, float(inW.at({n, i})));
2231 }
2232
2233 // Compute exp.
2234 float sum = 0;
2235 for (dim_t i = 0; i < idim[1]; i++) {
2236 float e = std::exp(float(inW.at({n, i})) - max);
2237 sum += e;
2238 outW.at({n, i}) = ElemTy(e);
2239 }
2240
2241 // Normalize the output.
2242 for (dim_t i = 0; i < idim[1]; i++) {
2243 outW.at({n, i}) = ElemTy(float(outW.at({n, i})) / sum);
2244 }
2245 } // N
2246}
2247
2248void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) {
2249 dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(),
2250 I);
2251}
2252
2253void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) {
2254 auto inG = getWeightHandle(I->getSrcGrad());
2255 auto idim = inG.dims();
2256 auto outW = getWeightHandle(I->getOrigDest());
2257 auto selectedH = getWeightHandle<int64_t>(I->getSelected());
2258
2259 inG.clear();
2260
2261 // http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/
2262 // https://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
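  // For a one-hot target y the gradient of the cross-entropy loss w.r.t. the
  // softmax input is dL/dx_i = softmax(x)_i - 1{i == y}, which is exactly what
  // is computed below (outW holds the forward softmax output).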
2263 for (dim_t n = 0; n < idim[0]; n++) {
2264 for (dim_t i = 0; i < idim[1]; i++) {
2265 float delta = (selectedH.at({n, 0}) == (int64_t)i);
2266 inG.at({n, i}) = outW.at({n, i}) - delta;
2267 }
2268 }
2269}
2270
2271template <typename ElemTy>
2272void BoundInterpreterFunction::fwdLogSoftMaxInstImpl(const LogSoftMaxInst *I) {
2273 staticAssertFloatingPointType(ElemTy);
2274
2275 auto inW = getWeightHandle<ElemTy>(I->getSrc());
2276 auto outW = getWeightHandle<ElemTy>(I->getDest());
2277 auto idim = inW.dims();
2278 // using log(softmax(x)) = x - max_x - log(sum)
2279 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L14-L64
2280 for (dim_t n = 0; n < idim[0]; n++) {
    // Find the max value; it is subtracted before exponentiation below for
    // numerical stability.
2282 float max = float(inW.at({n, 0}));
2283 for (dim_t i = 1; i < idim[1]; i++) {
2284 max = std::max(max, float(inW.at({n, i})));
2285 }
2286
2287 // Compute sum of exp(x-max).
2288 float sum = 0;
2289 for (dim_t i = 0; i < idim[1]; i++) {
2290 float e = std::exp(float(inW.at({n, i})) - max);
2291 sum += e;
2292 }
2293
2294 // Output = x - max - log(sum)
2295 for (dim_t i = 0; i < idim[1]; i++) {
2296 outW.at({n, i}) = ElemTy(float(inW.at({n, i})) - max - std::log(sum));
2297 }
2298 }
2299}
2300
2301void BoundInterpreterFunction::fwdLogSoftMaxInst(const LogSoftMaxInst *I) {
2302 dispatchFloatingPointImpl(fwdLogSoftMaxInstImpl,
2303 I->getSrc()->getElementType(), I);
2304}
2305
2306template <typename ElemTy>
2307void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl(
2308 const CrossEntropyLossInst *I) {
2309 staticAssertFloatingPointType(ElemTy);
2310
2311 auto P = getWeightHandle<ElemTy>(I->getP());
2312 auto labels = getWeightHandle<int64_t>(I->getLabels());
2313 auto CE = getWeightHandle<ElemTy>(I->getCE());
2314 auto dims = P.dims();
2315 CE.clear();
2316 for (dim_t n = 0; n < dims[0]; ++n) {
2317 assert(labels.raw(n) >= 0 && "Cannot use negative index.");
2318 dim_t y = labels.raw(n);
2319 float p_n = P.at({n, y});
2320 CE.at({0}) -= log(p_n);
2321 }
2322}
2323
2324void BoundInterpreterFunction::fwdCrossEntropyLossInst(
2325 const CrossEntropyLossInst *I) {
2326 dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl,
2327 I->getP()->getElementType(), I);
2328}
2329
2330void BoundInterpreterFunction::fwdCrossEntropyLossGradInst(
2331 const CrossEntropyLossGradInst *I) {
2332 auto P = getWeightHandle(I->getP());
2333 auto Labels = getWeightHandle<int64_t>(I->getLabels());
2334 auto PGrad = getWeightHandle(I->getPgrad());
2335 auto dims = PGrad.dims();
2336 PGrad.clear();
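  // The forward pass computes CE = -sum_n log(P[n, y_n]), so the gradient
  // w.r.t. P[n, y_n] is -1 / P[n, y_n]. The incoming gradient CEGrad.at({0})
  // is left out here (see the trailing comment), i.e. it is assumed to be 1.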
2337 for (dim_t n = 0; n < dims[0]; ++n) {
2338 assert(Labels.raw(n) >= 0 && "Cannot use negative index.");
2339 dim_t y = Labels.raw(n);
2340 PGrad.at({n, y}) = -1 / P.at({n, y}); // * CEGrad.at({0})
2341 }
2342}
2343
2344//===----------------------------------------------------------------------===//
2345// Tensor shape (copy/transpose/concat/...)
2346//===----------------------------------------------------------------------===//
2347
2348void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) {
2349 auto inT = getTensor(I->getSrc());
2350 auto outT = getTensor(I->getDest());
2351 outT->copyRawFrom(inT);
2352}
2353
void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) {
  auto inT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());

  assert(outT->size() == inT->size() && "Invalid tensor dimensions");

  // Tensor::transpose() handles quantized and floating-point element types
  // alike, so no dispatch on the element kind is needed here.
  inT->transpose(outT, I->getShuffle());
}
2367
2368void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) {
2369 getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets());
2370}
2371
2372void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) {
2373 auto *T = getTensor(I->getDest());
2374 ElemKind k = T->getElementType();
2375
2376 if (k == ElemKind::Int32ITy) {
2377 return T->getHandle<int32_t>().clear(I->getValue());
2378 }
2379
2380 if (k == ElemKind::Int64ITy) {
2381 return T->getHandle<int64_t>().clear(I->getValue());
2382 }
2383
2388 if (k == ElemKind::FloatTy) {
2389 return T->getHandle<float>().clear(I->getValue());
2390 }
2391
2392 if (k == ElemKind::Float16Ty) {
2393 return T->getHandle<float16_t>().clear(I->getValue());
2394 }
2395
2396 if (k == ElemKind::BFloat16Ty) {
2397 return T->getHandle<bfloat16_t>().clear(I->getValue());
2398 }
2399
2400 if (k == ElemKind::BoolTy) {
2401 return T->getHandle<bool>().clear(static_cast<bool>(I->getValue()));
2402 }
2403
2404 if (k == ElemKind::Int8QTy) {
2405 // Quantize the requested floating point splat value into the correct
2406 // integer representation.
2407 auto destTy = I->getDest()->getType();
2408 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
2409 float val = I->getValue();
2410 return T->getHandle<int8_t>().clear(quantization::quantize(val, destQ));
2411 }
2412
2413 if (k == ElemKind::Int16QTy) {
2414 // Quantize the requested floating point splat value into the correct
2415 // integer representation.
2416 auto destTy = I->getDest()->getType();
2417 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
2418 float val = I->getValue();
2419 return T->getHandle<int16_t>().clear(quantization::quantize(val, destQ));
2420 }
2421
2426 llvm_unreachable("Unsupported tensor type");
2427}
2428
2429void BoundInterpreterFunction::fwdTouchInst(const glow::TouchInst *) {
2430 // Do nothing for a TouchInst
2431}
2432
2433void BoundInterpreterFunction::fwdInsertTensorInst(
2434 const glow::InsertTensorInst *I) {
2435 Tensor *outT = getTensor(I->getDest());
2436 Tensor *inT = getTensor(I->getSrc());
2437 ElemKind k = outT->getElementType();
2438#define TYPED_INSERT(TY, TYPEKIND) \
2439 if (k == TYPEKIND) { \
2440 auto OH = outT->getHandle<TY>(); \
2441 auto IH = inT->getHandle<TY>(); \
2442 return OH.insertTensors(IH, I->getOffsets(), I->getCount(), I->getAxis()); \
2443 }
2444
2445 TYPED_INSERT(int64_t, ElemKind::Int64ITy);
2446 TYPED_INSERT(int32_t, ElemKind::Int32ITy);
2447 TYPED_INSERT(float, ElemKind::FloatTy);
2448 TYPED_INSERT(float16_t, ElemKind::Float16Ty);
2449 TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
2450 TYPED_INSERT(int8_t, ElemKind::Int8QTy);
2451 TYPED_INSERT(int16_t, ElemKind::Int16QTy);
2452 TYPED_INSERT(bool, ElemKind::BoolTy);
2453#undef TYPED_INSERT
2454
2455 llvm_unreachable("Unsupported tensor type");
2456}
2457
2458void BoundInterpreterFunction::fwdExtractTensorInst(
2459 const glow::ExtractTensorInst *I) {
2460 Tensor *outT = getTensor(I->getDest());
2461 Tensor *inT = getTensor(I->getSrc());
2462 ElemKind k = outT->getElementType();
2463#define TYPED_INSERT(TY, TYPEKIND) \
2464 if (k == TYPEKIND) { \
2465 auto OH = outT->getHandle<TY>(); \
2466 auto IH = inT->getHandle<TY>(); \
2467 return IH.extractTensors(OH, I->getOffsets()); \
2468 }
2469
2470 TYPED_INSERT(int64_t, ElemKind::Int64ITy);
2471 TYPED_INSERT(float, ElemKind::FloatTy);
2472 TYPED_INSERT(float16_t, ElemKind::Float16Ty);
2473 TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
2474 TYPED_INSERT(int8_t, ElemKind::Int8QTy);
2475 TYPED_INSERT(int32_t, ElemKind::Int32QTy);
2476 TYPED_INSERT(int32_t, ElemKind::Int32ITy);
2477 TYPED_INSERT(bool, ElemKind::BoolTy);
2478#undef TYPED_INSERT
2479
2480 llvm_unreachable("Unsupported tensor type");
2481}
2482
2483template <typename ElemTy>
2484void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) {
2485 Tensor *dataT = getTensor(I->getData());
2486 auto &dataTy = dataT->getType();
2487 Tensor *indicesT = getTensor(I->getIndices());
2488 Tensor *outT = getTensor(I->getDest());
2489 unsigned_t axis = I->getBatchDims();
2490
2491 size_t out_p = 0;
2492 dim_t elementSize = dataTy.getElementSize();
2493 // The size of the sample in the batch.
2494 dim_t dataSampleSize = dataTy.getSliceSize(axis) * elementSize;
2495 // The size of the slices that we gather.
2496 dim_t dataSliceSize = dataTy.getSliceSize(axis + 1) * elementSize;
2497
  // Calculate the number of samples in the batch.
  dim_t numSamples = (dataT->size() * elementSize) / dataSampleSize;

  // The size of the dimension we gather along, used to validate the indices.
  dim_t batchSize = dataTy.dims()[axis];
2503 (void)batchSize;
2504
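  // Conceptually the data is viewed as a [numSamples, batchSize, slice] tensor
  // where a slice holds dataSliceSize bytes. For every sample we copy out the
  // slice selected by each index, so the output ends up holding
  // numSamples * indicesT->size() slices.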
2505 // For each sample in the batch:
2506 for (dim_t sample = 0; sample < numSamples; sample++) {
2507 dim_t sampleStart = sample * dataSampleSize;
2508
2509 // For each slice (small fragment) that we copy from the source memory:
2510 for (dim_t i = 0, end = indicesT->size(); i < end; i++) {
2511 dim_t slice = indicesT->getHandle<ElemTy>().raw(i);
2512 assert(slice < batchSize && "Invalid index seen during Gather operation");
2513 std::copy(
2514 &dataT->getUnsafePtr()[sampleStart + dataSliceSize * slice],
2515 &dataT->getUnsafePtr()[sampleStart + dataSliceSize * (slice + 1)],
2516 &outT->getUnsafePtr()[out_p]);
2517 out_p += dataSliceSize;
2518 }
2519 }
2520}
2521
2522void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) {
2523 switch (I->getIndices()->getElementType()) {
2524 case ElemKind::Int64ITy:
2525 fwdGatherInstImpl<int64_t>(I);
2526 break;
2527 case ElemKind::Int32ITy:
2528 fwdGatherInstImpl<int32_t>(I);
2529 break;
2530 default:
2531 llvm_unreachable("Unsupported type for indices input of Gather.");
2532 }
2533}
2534
2535//===----------------------------------------------------------------------===//
2536// Gather Elements
2537//===----------------------------------------------------------------------===//
2538
2539template <typename IndexTy>
2540void BoundInterpreterFunction::fwdGatherElementsInstImpl(
2541 const glow::GatherElementsInst *I) {
2542 Tensor *outT = getTensor(I->getDest());
2543 Tensor *dataT = getTensor(I->getData());
2544 auto dataDims = dataT->getType().dims();
2545 Tensor *indicesT = getTensor(I->getIndices());
2546 auto indicesDims = indicesT->getType().dims();
2547 const auto dim = I->getDim();
2548 const auto numElems = outT->getRealNumElements();
2549 auto ndims = indicesDims.size();
2550
2551 std::vector<dim_t> ind_dim_off(ndims);
2552 std::vector<dim_t> data_dim_off(ndims);
2553
2554 ind_dim_off[0] = 1;
2555 data_dim_off[0] = 1;
2556 for (dim_t i = 1; i < ndims; i++) {
2557 ind_dim_off[i] =
2558 std::accumulate(indicesDims.begin() + ndims - i, indicesDims.end(), 1,
2559 std::multiplies<dim_t>());
2560 data_dim_off[i] =
2561 std::accumulate(dataDims.begin() + ndims - i, dataDims.end(), 1,
2562 std::multiplies<dim_t>());
2563 }
2564
2565 const auto dimPos = ndims - 1 - dim;
2566 const auto elemSize = dataT->getType().getElementSize();
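  // For every flat output index we reconstruct its multi-dimensional
  // coordinates, drop the coordinate along `dim` (via the (dimPos != i)
  // factor) and then add the gathered index along `dim` to form the flat
  // offset into the data tensor.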
2567 // Loop over number of elements in indices
2568 for (size_t idx = 0; idx < numElems; idx++) {
2569 unsigned_t offset = 0;
2570 // Loop over number of dimensions to calculate offset
2571 for (dim_t i = 0; i < ndims; i++) {
2572 // Calculate axis index i.e. (i, j, k)
2573 const dim_t dim_idx =
2574 static_cast<dim_t>(std::floor(idx / ind_dim_off[i])) %
2575 indicesDims[ndims - (i + 1)];
2576 offset += (dimPos != i) * dim_idx * data_dim_off[i];
2577 }
2578 auto indIdx = indicesT->getHandle<IndexTy>().raw(idx);
    assert((indIdx < 0 ? -indIdx <= dataDims[dim] : indIdx < dataDims[dim]) &&
           "[GatherElements] Got out of bounds index");
2582 // In case of negative indices
2583 if (indIdx < 0) {
2584 indIdx += dataDims[dim];
2585 }
2586 const auto dataRawIdx = indIdx * data_dim_off[dimPos] + offset;
2587 std::copy(&dataT->getUnsafePtr()[dataRawIdx * elemSize],
2588 &dataT->getUnsafePtr()[(dataRawIdx + 1) * elemSize],
2589 &outT->getUnsafePtr()[idx * elemSize]);
2590 }
2591}
2592
2593void BoundInterpreterFunction::fwdGatherElementsInst(
2594 const glow::GatherElementsInst *I) {
2595 switch (I->getIndices()->getElementType()) {
2596 case ElemKind::Int64ITy:
2597 fwdGatherElementsInstImpl<int64_t>(I);
2598 break;
2599 case ElemKind::Int32ITy:
2600 fwdGatherElementsInstImpl<int32_t>(I);
2601 break;
2602 default:
2603 llvm_unreachable("[GatherElements] Unsupported type for indices input of "
2604 "GatherElements.");
2605 }
2606}
2607
2608template <typename ElemTy>
2609void BoundInterpreterFunction::fwdGatherNDInstImpl(
2610 const glow::GatherNDInst *I) {
2611
2612 Tensor *dataT = getTensor(I->getData());
2613 Tensor *indicesT = getTensor(I->getIndices());
2614 Tensor *outT = getTensor(I->getDest());
2615 auto batchDims = I->getBatchDims();
2616
2617 auto dataDims = I->getData()->dims();
2618 auto indicesDims = I->getIndices()->dims();
2619 dim_t indicesDimLast = indicesDims.back();
2620
2621 // Compute batch count.
2622 dim_t batchCount = 1;
2623 for (size_t idx = 0; idx < batchDims; ++idx) {
2624 batchCount *= dataDims[idx];
2625 }
2626
2627 // Compute input slice count.
2628 dim_t inpSliceCount = 1;
2629 for (size_t idx = batchDims; idx < batchDims + indicesDimLast; ++idx) {
2630 inpSliceCount *= dataDims[idx];
2631 }
2632
2633 // Compute output slice count.
2634 dim_t outSliceCount = 1;
2635 for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) {
2636 outSliceCount *= indicesDims[idx];
2637 }
2638
2639 // Compute slice size (in bytes).
2640 dim_t sliceSize = dataT->getType().getElementSize();
2641 for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size(); idx++) {
2642 sliceSize *= dataDims[idx];
2643 }
2644
2645 // Get indices dimension products.
2646 std::vector<dim_t> indicesDimProd(indicesDimLast);
2647 indicesDimProd[indicesDimLast - 1] = 1;
2648 for (ssize_t idx = static_cast<ssize_t>(indicesDimLast) - 2; idx >= 0;
2649 idx--) {
2650 indicesDimProd[idx] =
2651 indicesDimProd[idx + 1] * dataDims[batchDims + idx + 1];
2652 }
2653
2654 // We will view the tensors as equivalent 3D tensors with the dimensions:
2655 // data - batchCount x inpSliceCount x sliceSize
2656 // indices - batchCount x outSliceCount x indicesDimLast
2657 // output - batchCount x outSliceCount x sliceSize
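  // For example, with batchDims = 0, data of shape [2, 3, 4] and indices of
  // shape [5, 2], each index pair selects one row of 4 elements and the output
  // has shape [5, 4]: batchCount = 1, inpSliceCount = 6, outSliceCount = 5 and
  // sliceSize = 4 * sizeof(element).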
2658
2659 char *dataPtr = dataT->getUnsafePtr();
2660 ElemTy *indicesPtr = (ElemTy *)indicesT->getUnsafePtr();
2661 char *outPtr = outT->getUnsafePtr();
2662
2663 for (size_t batchIdx = 0; batchIdx < batchCount; ++batchIdx) {
2664 for (size_t outSliceIdx = 0; outSliceIdx < outSliceCount; ++outSliceIdx) {
2665
2666 // Compute input slice index.
2667 dim_t inpSliceIdx = 0;
2668 for (size_t idx = 0; idx < indicesDimLast; ++idx) {
2669 inpSliceIdx += (*indicesPtr++) * indicesDimProd[idx];
2670 }
2671
2672 // Copy data.
2673 std::copy(dataPtr + (inpSliceIdx + 0) * sliceSize,
2674 dataPtr + (inpSliceIdx + 1) * sliceSize, outPtr);
2675 outPtr += sliceSize;
2676 }
2677
2678 // Increment input pointer for next batch.
2679 dataPtr += inpSliceCount * sliceSize;
2680 }
2681}
2682
2683void BoundInterpreterFunction::fwdGatherNDInst(const glow::GatherNDInst *I) {
2684 switch (I->getIndices()->getElementType()) {
2685 case ElemKind::Int64ITy:
2686 fwdGatherNDInstImpl<int64_t>(I);
2687 break;
2688 case ElemKind::Int32ITy:
2689 fwdGatherNDInstImpl<int32_t>(I);
2690 break;
2691 default:
2692 llvm_unreachable("Unsupported type for indices input of Gather.");
2693 }
2694}
2695
2696template <typename ElemTy>
2697void BoundInterpreterFunction::fwdGatherRangesInstImpl(
2698 const glow::GatherRangesInst *I) {
2699 Tensor *dataT = getTensor(I->getData());
2700 auto &dataTy = dataT->getType();
2701 Tensor *rangesT = getTensor(I->getRanges());
2702 auto &rangesTy = rangesT->getType();
2703 Tensor *outT = getTensor(I->getOutput());
2704 Tensor *lengthsT = getTensor(I->getLengths());
2705
2706 // Offset into the output tensor that keeps track of where to start
2707 // copying data.
2708 size_t outP = 0;
2709
2710 unsigned dataElementSize = dataTy.getElementSize();
2711 dim_t numExamples = rangesTy.dims()[0];
2712 dim_t exampleSize = rangesTy.dims()[1];
2713
2714 // Keep track of the total number of elements gathered across all
2715 // examples for a sanity check later.
2716 dim_t grandTotalLen = 0;
2717
2718 // For each example in ranges:
2719 for (dim_t example = 0; example < numExamples; ++example) {
2720 // Keep a running total of the lengths of all ranges in this example
2721 // to record into lengthsT once the entire example is processed.
2722 ElemTy totalLen = 0;
2723
2724 // For each range in the example:
2725 for (dim_t range = 0; range < exampleSize; ++range) {
2726 // Get the start index and range length.
2727 ElemTy startIdx = rangesT->getHandle<ElemTy>().at({example, range, 0});
2728 ElemTy len = rangesT->getHandle<ElemTy>().at({example, range, 1});
2729
2730 // Add the length of this current range to the example length counter.
2731 totalLen += len;
2732
2733 // Compute the start and end offsets.
2734 dim_t startOffset = startIdx * dataElementSize;
2735 dim_t endOffset = startOffset + (len * dataElementSize);
2736
2737 // Sanity checks on the offsets.
2738 assert(startOffset < dataT->getSizeInBytes());
2739 assert(endOffset <= dataT->getSizeInBytes());
2740 assert(endOffset >= startOffset);
2741 assert(outP < outT->getSizeInBytes());
2742 assert((outP + (len * dataElementSize)) <= outT->getSizeInBytes());
2743
2744 // Copy the specified data to outT.
2745 std::copy(&dataT->getUnsafePtr()[startOffset],
2746 &dataT->getUnsafePtr()[endOffset], &outT->getUnsafePtr()[outP]);
2747
2748 // Advance the offset into outT.
2749 outP += len * dataElementSize;
2750 }
2751
2752 // Record the total number of elements gathered for the example in
2753 // lengthsT.
2754 lengthsT->getHandle<ElemTy>().at({example}) = totalLen;
2755
2756 // Add the total length of the entire example to the grand total.
2757 grandTotalLen += static_cast<size_t>(totalLen);
2758 }
2759
2760 // Make sure that number of elements written to outT is equal to the
2761 // total of all elements in lengthsT.
2762 assert(grandTotalLen == (outP / dataElementSize));
2763}
2764
2765void BoundInterpreterFunction::fwdGatherRangesInst(
2766 const glow::GatherRangesInst *I) {
2767 switch (I->getRanges()->getElementType()) {
2768 case ElemKind::Int64ITy:
2769 fwdGatherRangesInstImpl<int64_t>(I);
2770 break;
2771 case ElemKind::Int32ITy:
2772 fwdGatherRangesInstImpl<int32_t>(I);
2773 break;
2774 default:
2775 llvm_unreachable("Unsupported type for ranges input of GatherRanges.");
2776 }
2777}
2778
2779template <typename ElemTy, typename IndicesElemTy>
2780void BoundInterpreterFunction::fwdScatterDataInstCopyImpl(
2781 const glow::ScatterDataInst *I) {
2782 Tensor *dataT = getTensor(I->getData());
2783 Tensor *indicesT = getTensor(I->getIndices());
2784 Tensor *slicesT = getTensor(I->getSlices());
2785
2786 assert(indicesT->dims().size() == 2 &&
2787 "Index should be stored in 2D tensor!");
2788 const dim_t dataSliceSize = slicesT->size() / slicesT->dims()[0] *
2789 slicesT->getType().getElementSize();
2790
2791 auto IH = indicesT->getHandle<IndicesElemTy>();
2792 // For each index, copy from the slice at that index into the location in
2793 // data given the offset from the indices tensor.
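  // destDataIdx below is the row-major flat index over the leading
  // indicesT->dims()[1] dimensions of the data tensor, accumulated with
  // Horner's scheme.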
2794 for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2795 dim_t destDataIdx = 0;
2796 for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2797 destDataIdx *= dataT->dims()[j];
2798 destDataIdx += IH.at({i, j});
2799 }
2800 std::copy(&slicesT->getUnsafePtr()[i * dataSliceSize],
2801 &slicesT->getUnsafePtr()[(i + 1) * dataSliceSize],
2802 &dataT->getUnsafePtr()[dataSliceSize * destDataIdx]);
2803 }
2804}
2805
2806template <typename ElemTy, typename IndicesElemTy>
2807void BoundInterpreterFunction::fwdScatterDataInstAddFloatImpl(
2808 const glow::ScatterDataInst *I) {
2809 Tensor *dataT = getTensor(I->getData());
2810 Tensor *indicesT = getTensor(I->getIndices());
2811 Tensor *slicesT = getTensor(I->getSlices());
2812
2813 assert(!dataT->getType().isQuantizedType() && "Should be float type!");
2814 assert(!slicesT->getType().isQuantizedType() && "Should be float type!");
2815
2816 const size_t numSlices = slicesT->size() / slicesT->dims()[0];
2817
2818 auto IH = indicesT->getHandle<IndicesElemTy>();
2819 // For each index, copy from the slice at that index into the location in
2820 // data given the offset from the indices tensor.
2821 assert(indicesT->dims().size() == 2 &&
2822 "Multi-dimensional index should be stored in 2D tensor!");
2823 auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
2824 for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2825 size_t destDataIdx = 0;
2826 for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2827 destDataIdx *= dataT->dims()[j];
2828 destDataIdx += IH.at({i, j});
2829 }
2830 for (dim_t j = 0; j < numSlices; j++) {
2831 D.raw(destDataIdx * numSlices + j) += S.raw(i * numSlices + j);
2832 }
2833 }
2834}
2835
2836template <typename ElemTy, typename IndicesElemTy>
2837void BoundInterpreterFunction::fwdScatterDataInstAddQuantizedImpl(
2838 const glow::ScatterDataInst *I) {
2839 Tensor *dataT = getTensor(I->getData());
2840 Tensor *indicesT = getTensor(I->getIndices());
2841 Tensor *slicesT = getTensor(I->getSlices());
2842
2843 assert(dataT->getType().isQuantizedType() && "Should be quantized type!");
2844 assert(slicesT->getType().isQuantizedType() && "Should be quantized type!");
2845
2846 const dim_t numSlices = slicesT->size() / slicesT->dims()[0];
2847
2848 TensorQuantizationParams dataQ{dataT->getType().getScale(),
2849 dataT->getType().getOffset()};
2850 TensorQuantizationParams sliceQ{slicesT->getType().getScale(),
2851 slicesT->getType().getOffset()};
2852
2853 auto IH = indicesT->getHandle<IndicesElemTy>();
2854 // For each index, copy from the slice at that index into the location in
2855 // data given the offset from the indices tensor.
2856 assert(indicesT->dims().size() == 2 &&
2857 "Multi-dimensional index should be stored in 2D tensor!");
2858 auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
2859 for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2860 dim_t destDataIdx = 0;
2861 for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2862 destDataIdx *= dataT->dims()[j];
2863 destDataIdx += IH.at({i, j});
2864 }
2865 for (dim_t j = 0; j < numSlices; j++) {
2866 float lhs =
2867 quantization::dequantize(D.raw(destDataIdx * numSlices + j), dataQ);
2868 float rhs = quantization::dequantize(S.raw(i * numSlices + j), sliceQ);
2869 ElemTy result = quantization::quantize(lhs + rhs, dataQ);
2870 D.raw(destDataIdx * numSlices + j) = result;
2871 }
2872 }
2873}
2874
2875void BoundInterpreterFunction::fwdScatterDataInst(
2876 const glow::ScatterDataInst *I) {
2877 const auto indicesAreInt64 =
2878 I->getIndices()->getElementType() == ElemKind::Int64ITy;
2879
2880 if (I->getCumulative()) {
2881 switch (I->getData()->getElementType()) {
2882 case ElemKind::FloatTy:
2883 if (indicesAreInt64) {
2884 fwdScatterDataInstAddFloatImpl<float, int64_t>(I);
2885 } else {
2886 fwdScatterDataInstAddFloatImpl<float, int32_t>(I);
2887 }
2888
2889 break;
2890 case ElemKind::Int8QTy:
2891 if (indicesAreInt64) {
2892 fwdScatterDataInstAddQuantizedImpl<int8_t, int64_t>(I);
2893 } else {
2894 fwdScatterDataInstAddQuantizedImpl<int8_t, int32_t>(I);
2895 }
2896 break;
2897 default:
2898 llvm_unreachable("Unsupported type for ScatterData.");
2899 }
2900 } else {
2901 switch (I->getData()->getElementType()) {
2902 case ElemKind::FloatTy:
2903 if (indicesAreInt64) {
2904 fwdScatterDataInstCopyImpl<float, int64_t>(I);
2905 } else {
2906 fwdScatterDataInstCopyImpl<float, int32_t>(I);
2907 }
2908 break;
2909 case ElemKind::Int8QTy:
2910 if (indicesAreInt64) {
2911 fwdScatterDataInstCopyImpl<int8_t, int64_t>(I);
2912 } else {
2913 fwdScatterDataInstCopyImpl<int8_t, int32_t>(I);
2914 }
2915 break;
2916 default:
2917 llvm_unreachable("Unsupported type for ScatterData.");
2918 }
2919 }
2920}
2921
2922template <typename ElemTy>
2923void BoundInterpreterFunction::fwdBatchOneHotImpl(
2924 const glow::BatchOneHotInst *I) {
2925 auto dataH = getWeightHandle<ElemTy>(I->getData());
2926 auto lengthsH = getWeightHandle<int32_t>(I->getLengths());
2927 auto valuesH = getWeightHandle<ElemTy>(I->getValues());
2928 auto destH = getWeightHandle<ElemTy>(I->getDest());
2929
2930 auto batchSize = dataH.dims()[0];
2931 auto featureCnt = dataH.dims()[1];
2932
2933 for (dim_t batchId = 0; batchId < batchSize; batchId++) {
2934 size_t offset = 0;
2935 for (dim_t featureId = 0; featureId < featureCnt; featureId++) {
2936 auto curValue = dataH.at({batchId, featureId});
2937 auto curLength = lengthsH.at({featureId});
2938 for (dim_t i = offset, e = offset + curLength; i != e; i++) {
2939 destH.at({batchId, i}) = curValue == valuesH.at({i});
2940 }
2941 offset += curLength;
2942 }
2943 assert(offset == destH.dims()[1] &&
2944 "Sum of Lengths must be equal to size of Values");
2945 }
2946}
2947
2948void BoundInterpreterFunction::fwdBatchOneHotInst(
2949 const glow::BatchOneHotInst *I) {
2950 switch (I->getData()->getElementType()) {
2951 case ElemKind::Int64ITy:
2952 fwdBatchOneHotImpl<int64_t>(I);
2953 break;
2954 case ElemKind::Int32ITy:
2955 fwdBatchOneHotImpl<int32_t>(I);
2956 break;
2957 case ElemKind::Int8QTy:
2958 fwdBatchOneHotImpl<int8_t>(I);
2959 break;
2960 default:
2961 dispatchFloatingPointImpl(fwdBatchOneHotImpl,
2962 I->getData()->getElementType(), I);
2963 }
2964}
2965
2966template <typename ElemTy>
2967void BoundInterpreterFunction::fwdSpaceToDepthInstImpl(
2968 const glow::SpaceToDepthInst *I) {
2969 auto *inT = getTensor(I->getSrc());
2970 auto *outT = getTensor(I->getDest());
2971
2972 auto inH = inT->getHandle<ElemTy>();
2973 auto outH = outT->getHandle<ElemTy>();
2974
2975 unsigned blockSize = I->getBlockSize();
2976
2977 dim_t inDepth = inT->dims()[3];
2978
2979 dim_t outBatch = outT->dims()[0];
2980 dim_t outHeight = outT->dims()[1];
2981 dim_t outWidth = outT->dims()[2];
2982 dim_t outDepth = outT->dims()[3];
2983
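  // SpaceToDepth moves blockSize x blockSize spatial blocks into the channel
  // dimension: with blockSize = 2, for example, an input of shape
  // (N, 2H, 2W, C) produces an output of shape (N, H, W, 4C).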
2984 for (dim_t ob = 0; ob < outBatch; ++ob) {
2985 for (dim_t oh = 0; oh < outHeight; ++oh) {
2986 for (dim_t ow = 0; ow < outWidth; ++ow) {
2987 for (dim_t oc = 0; oc < outDepth; ++oc) {
          // Which position inside the blockSize x blockSize block this output
          // channel corresponds to.
          dim_t blockDepthLayer = oc / inDepth;
          // Column offset inside the block (blockDepthLayer % blockSize).
          dim_t iw = ow * blockSize + blockDepthLayer % blockSize;
          // Row offset inside the block (blockDepthLayer / blockSize).
          dim_t ih = oh * blockSize + blockDepthLayer / blockSize;
          // Input channel, which repeats every inDepth output channels.
          dim_t ic = oc % inDepth;
2996
2997 outH.at({ob, oh, ow, oc}) = inH.at({ob, ih, iw, ic});
2998 }
2999 }
3000 }
3001 }
3002}
3003
3004void BoundInterpreterFunction::fwdSpaceToDepthInst(
3005 const glow::SpaceToDepthInst *I) {
3006 switch (I->getSrc()->getElementType()) {
3007 case ElemKind::FloatTy:
3008 fwdSpaceToDepthInstImpl<float>(I);
3009 break;
3010 case ElemKind::Int8QTy:
3011 fwdSpaceToDepthInstImpl<int8_t>(I);
3012 break;
3013 default:
3014 llvm_unreachable("Type is not supported");
3015 break;
3016 }
3017}
3018
3019template <typename ElemTy>
3020void BoundInterpreterFunction::fwdResizeNearestInstImpl(
3021 const ResizeNearestInst *I) {
3022 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3023 auto scale = I->getScale();
3024 auto outW = getWeightHandle<ElemTy>(I->getDest());
3025
3026 auto outputDims = outW.dims();
3027 auto inputDims = inW.dims();
3028
3029 for (dim_t oa = 0; oa < outputDims[0]; ++oa) {
3030 auto ia = std::min(dim_t(oa / scale[0]), inputDims[0] - 1);
3031 for (dim_t ob = 0; ob < outputDims[1]; ++ob) {
3032 auto ib = std::min(dim_t(ob / scale[1]), inputDims[1] - 1);
3033 for (dim_t oc = 0; oc < outputDims[2]; ++oc) {
3034 auto ic = std::min(dim_t(oc / scale[2]), inputDims[2] - 1);
3035 if (outputDims.size() > 3) {
3036 for (dim_t od = 0; od < outputDims[3]; ++od) {
3037 auto id = std::min(dim_t(od / scale[3]), inputDims[3] - 1);
3038 if (outputDims.size() > 4) {
3039 for (dim_t oe = 0; oe < outputDims[4]; ++oe) {
3040 auto ie = std::min(dim_t(oe / scale[4]), inputDims[4] - 1);
3041 if (outputDims.size() > 5) {
                  for (dim_t of = 0; of < outputDims[5]; ++of) {
3043 auto f = std::min(dim_t(of / scale[5]), inputDims[5] - 1);
3044 outW.at({oa, ob, oc, od, oe, of}) =
3045 inW.at({ia, ib, ic, id, ie, f});
3046 }
3047 } else {
3048 outW.at({oa, ob, oc, od, oe}) = inW.at({ia, ib, ic, id, ie});
3049 }
3050 }
3051 } else {
3052 outW.at({oa, ob, oc, od}) = inW.at({ia, ib, ic, id});
3053 }
3054 }
3055 } else {
3056 outW.at({oa, ob, oc}) = inW.at({ia, ib, ic});
3057 }
3058 }
3059 }
3060 }
3061}
3062
3063void BoundInterpreterFunction::fwdResizeNearestInst(
3064 const ResizeNearestInst *I) {
3065 if (getTensor(I->getSrc())->getType().isQuantizedType()) {
3066 dispatchQuantizedImpl(fwdResizeNearestInstImpl,
3067 I->getSrc()->getElementType(), I);
3068 return;
3069 }
3070
3071 dispatchImpl(fwdResizeNearestInstImpl, I->getSrc()->getElementType(), I);
3072}
3073
3074template <typename ElemTy>
3075void BoundInterpreterFunction::fwdResizeBilinearInstImpl(
3076 const ResizeBilinearInst *I) {
3077 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3078 auto scale = I->getScale();
3079 auto outW = getWeightHandle<ElemTy>(I->getDest());
3080
3081 ShapeNHWC odim(outW.dims());
3082 ShapeNHWC idim(inW.dims());
3083
3084 CHECK_EQ(scale[0], 1.0) << "Scaling batch not supported.";
3085 CHECK_EQ(scale[3], 1.0) << "Scaling channel not supported.";
3086
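  // Each output pixel maps back to a fractional input coordinate (ihf, iwf).
  // hd interpolates along the height at column iw0, hw does the same at column
  // iw1, and the final result interpolates between them along the width.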
3087 for (dim_t ob = 0; ob < odim.n; ++ob) {
3088 for (dim_t oh = 0; oh < odim.h; ++oh) {
3089 for (dim_t ow = 0; ow < odim.w; ++ow) {
3090
3091 float ihf = oh / scale[1];
3092 float iwf = ow / scale[2];
3093 dim_t ih = dim_t(ihf);
3094 dim_t iw = dim_t(iwf);
3095
3096 auto ih0 = std::min(ih, idim.h - 1);
3097 auto ih1 = std::min(ih + 1, idim.h - 1);
3098 auto iw0 = std::min(iw, idim.w - 1);
3099 auto iw1 = std::min(iw + 1, idim.w - 1);
3100
3101 for (dim_t oc = 0; oc < odim.c; ++oc) {
3102 auto v00 = inW.at({ob, ih0, iw0, oc});
3103 auto v01 = inW.at({ob, ih0, iw1, oc});
3104 auto v10 = inW.at({ob, ih1, iw0, oc});
3105 auto v11 = inW.at({ob, ih1, iw1, oc});
3106
3107 auto hd = (float)v00 + (float)(v10 - v00) * (ihf - ih);
3108 auto hw = (float)v01 + (float)(v11 - v01) * (ihf - ih);
3109 float result = hd + (hw - hd) * (iwf - iw);
          outW.at({ob, oh, ow, oc}) = ElemTy(result);
3111 }
3112 }
3113 }
3114 }
3115}
3116
3117void BoundInterpreterFunction::fwdResizeBilinearInst(
3118 const ResizeBilinearInst *I) {
3119 if (getTensor(I->getSrc())->getType().isQuantizedType()) {
3120 dispatchQuantizedImpl(fwdResizeBilinearInstImpl,
3121 I->getSrc()->getElementType(), I);
3122 return;
3123 }
3124
3125 dispatchImpl(fwdResizeBilinearInstImpl, I->getSrc()->getElementType(), I);
3126}
3127
3128//===----------------------------------------------------------------------===//
3129// Local Response Normalization
3130//===----------------------------------------------------------------------===//
3131
3132template <typename ElemTy>
3133void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl(
3134 const glow::LocalResponseNormalizationInst *I) {
3135 staticAssertFloatingPointType(ElemTy);
3136
3137 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3138 auto outW = getWeightHandle<ElemTy>(I->getDest());
3139 auto scaleCache = getWeightHandle<ElemTy>(I->getScale());
3140
3141 ShapeNHWC odim(outW.dims());
3142 ShapeNHWC idim(inW.dims());
3143
3144 (void)odim;
3145
3146 // LRN node does not change the shape of the input.
3147 assert(odim == idim && "Output of LRN node must be same shape as input");
3148
3149 // LRN node normalizes across channels, so the input must have a minimum
3150 // depth of 1.
3151 assert(idim.c > 0 && "Input of LRN node must have a minimum depth of 1");
3152
3153 auto halfWindowSize = (size_t)I->getHalfWindowSize();
3154 auto k = I->getK();
3155 auto beta = I->getBeta();
3156 auto windowSize = 2 * halfWindowSize + 1;
3157 auto normedAlpha = I->getAlpha() / windowSize;
3158
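  // out(n, h, w, c) = in(n, h, w, c) *
  //     (k + normedAlpha * sum_{i in window(c)} in(n, h, w, i)^2)^(-beta)
  // where window(c) spans [c - halfWindowSize, c + halfWindowSize] clamped to
  // the valid channel range.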
3159 // For every input in the batch:
3160 for (dim_t n = 0; n < idim.n; n++) {
3161
3162 // For every row:
3163 for (dim_t h = 0; h < idim.h; h++) {
3164
3165 // For every column:
3166 for (dim_t w = 0; w < idim.w; w++) {
3167
3168 // For every channel:
3169 for (dim_t c = 0; c < idim.c; c++) {
3170 float squareSum = 0.0;
3171 for (dim_t i = (c >= halfWindowSize ? c - halfWindowSize : 0);
3172 i <= std::min<dim_t>(c + halfWindowSize, (size_t)idim.c - 1);
3173 i++) {
3174 float val = inW.at({n, h, w, i});
3175 squareSum += val * val;
3176 }
3177
3178 auto scale = k + normedAlpha * squareSum;
3179
3180 // This will be used to accelerate the backward pass.
3181 scaleCache.at({n, h, w, c}) = ElemTy(scale);
3182
3183 auto normFactor = std::pow(scale, -beta);
3184 outW.at({n, h, w, c}) =
3185 ElemTy(float(inW.at({n, h, w, c})) * normFactor);
3186 }
3187 }
3188 }
3189 }
3190}
3191
3192void BoundInterpreterFunction::fwdLocalResponseNormalizationInst(
3193 const LocalResponseNormalizationInst *I) {
3194 dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl,
3195 I->getSrc()->getElementType(), I);
3196}
3197
3198void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst(
3199 const glow::LocalResponseNormalizationGradInst *I) {
3200 auto inW = getWeightHandle(I->getSrc());
3201 auto inG = getWeightHandle(I->getSrcGrad());
3202 auto outW = getWeightHandle(I->getDest());
3203 auto outG = getWeightHandle(I->getDestGrad());
3204 auto scaleCache = getWeightHandle(I->getScale());
3205
3206 ShapeNHWC odim(outW.dims());
3207
3208 auto halfWindowSize = I->getHalfWindowSize();
3209 auto beta = I->getBeta();
3210 auto windowSize = 2 * halfWindowSize + 1;
3211 auto normedAlpha = I->getAlpha() / windowSize;
3212
3213 // For every input in the batch:
3214 for (dim_t n = 0; n < odim.n; n++) {
3215
3216 // For every row:
3217 for (dim_t h = 0; h < odim.h; h++) {
3218
3219 // For every column:
3220 for (dim_t w = 0; w < odim.w; w++) {
3221
3222 float sum = 0.0;
3223
3224 // Compute sum for first channel.
3225 for (dim_t c = 0; c <= halfWindowSize && c < odim.c; c++) {
3226 auto outw = outW.at({n, h, w, c});
3227 auto scale = scaleCache.at({n, h, w, c});
3228 auto outg = outG.at({n, h, w, c});
3229 sum += (outg * (outw / scale));
3230 }
3231
3232 // For every channel:
3233 for (dim_t c = 0; c < odim.c; c++) {
3234 auto outg = outG.at({n, h, w, c});
3235 auto scale = scaleCache.at({n, h, w, c});
3236 auto inw = inW.at({n, h, w, c});
3237
3238 inG.at({n, h, w, c}) = outg * std::pow(scale, -beta) -
3239 2 * normedAlpha * beta * inw * sum;
3240
3241 // Modify sum for next channel.
3242 auto subIndex = c - halfWindowSize;
3243 auto addIndex = c + halfWindowSize + 1;
3244
3245 if (c >= halfWindowSize) {
3246 auto outw = outW.at({n, h, w, subIndex});
3247 auto scale = scaleCache.at({n, h, w, subIndex});
3248 auto outg = outG.at({n, h, w, subIndex});
3249
3250 // Subtract "rear" end of this window.
3251 sum -= (outg * (outw / scale));
3252 }
3253
3254 if (addIndex < odim.c) {
3255 auto outw = outW.at({n, h, w, addIndex});
3256 auto scale = scaleCache.at({n, h, w, addIndex});
3257 auto outg = outG.at({n, h, w, addIndex});
3258
3259 // Add "front" end of next window.
3260 sum += (outg * (outw / scale));
3261 }
3262 }
3263 }
3264 }
3265 }
3266}
3267
3268//===--------------------------------------------------------------------===//
3269// Bucketing
3270//===--------------------------------------------------------------------===//
3271
3272void BoundInterpreterFunction::fwdBucketizeInst(const BucketizeInst *I) {
3273 auto inputH = getTensor(I->getSrc())->getHandle<float>();
3274 auto outputH = getTensor(I->getDest())->getHandle<int32_t>();
3275 const auto boundaries = I->getBoundaries();
3276
3277 const auto numItems = inputH.size();
3278
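  // std::lower_bound returns the position of the first boundary that is not
  // less than the input, so with boundaries {1.0, 2.0}, for example, inputs
  // <= 1.0 map to bucket 0, inputs in (1.0, 2.0] map to bucket 1 and inputs
  // > 2.0 map to bucket 2.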
3279 for (size_t i = 0; i < numItems; ++i) {
3280 outputH.raw(i) =
3281 std::lower_bound(boundaries.begin(), boundaries.end(), inputH.raw(i)) -
3282 boundaries.begin();
3283 }
3284}
3285
3286//===----------------------------------------------------------------------===//
3287// Arithmetic operations
3288//===----------------------------------------------------------------------===//
3289void BoundInterpreterFunction::fwdElementAddInstI8Impl(
3290 const ElementAddInst *I) {
3291 assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
3292 "Wrong function");
3293 auto lhsTy = I->getLHS()->getType();
3294 auto rhsTy = I->getRHS()->getType();
3295 auto destTy = I->getDest()->getType();
3296
3297 float lhsScale = lhsTy->getScale();
3298 float rhsScale = rhsTy->getScale();
3299 float destScale = destTy->getScale();
3300
3301 int32_t lhsOffset = lhsTy->getOffset();
3302 int32_t rhsOffset = rhsTy->getOffset();
3303 int32_t destOffset = destTy->getOffset();
3304
3305 auto outW = getWeightHandle<int8_t>(I->getDest());
3306 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3307 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3308 for (dim_t i = 0, e = outW.size(); i < e; i++) {
3309 int32_t L = lhsW.raw(i);
3310 int32_t R = rhsW.raw(i);
3311
3312 // We increase the size of the integer up to 16 bits to prevent overflow.
3313 const float largeScale = float(1) / (1 << 15);
3314 // Scale both sides from 8-bit to 16-bits.
3315 int32_t L32 = std::round(float(L - lhsOffset) * (lhsScale / largeScale));
3316 int32_t R32 = std::round(float(R - rhsOffset) * (rhsScale / largeScale));
3317 int32_t sum32 = L32 + R32;
3318 sum32 = std::round(float(sum32) * (largeScale / destScale) + destOffset);
3319 outW.raw(i) = quantization::clip<int32_t, int8_t>(sum32);
3320 }
3321}
3322
3323template <typename ElemTy>
3324void BoundInterpreterFunction::fwdElementAddInstArithmeticImpl(
3325 const ElementAddInst *I) {
3326 staticAssertArithmeticType(ElemTy);
3327
3328 auto outW = getWeightHandle<ElemTy>(I->getDest());
3329 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3330 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3331 for (size_t i = 0, e = outW.size(); i < e; i++) {
3332 outW.raw(i) = lhsW.raw(i) + rhsW.raw(i);
3333 }
3334}
3335
3336void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) {
3337 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3338 fwdElementAddInstI8Impl(I);
3339 return;
3340 }
3341
3342 dispatchArithmeticImpl(fwdElementAddInstArithmeticImpl,
3343 I->getLHS()->getType()->getElementType(), I);
3344}
3345
3346template <typename ElemTy>
3347void BoundInterpreterFunction::fwdElementSubInstArithmeticImpl(
3348 const ElementSubInst *I) {
3349 staticAssertArithmeticType(ElemTy);
3350
3351 auto outW = getWeightHandle<ElemTy>(I->getDest());
3352 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3353 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3354 for (size_t i = 0, e = outW.size(); i < e; i++) {
3355 outW.raw(i) = lhsW.raw(i) - rhsW.raw(i);
3356 }
3357}
3358
3359void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) {
3360 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3361 auto destTy = I->getDest()->getType();
3362 auto lhsTy = I->getLHS()->getType();
3363 auto rhsTy = I->getRHS()->getType();
3364
3365 float destScale = destTy->getScale();
3366 float lhsScale = lhsTy->getScale();
3367 float rhsScale = rhsTy->getScale();
3368
3369 int32_t destOffset = destTy->getOffset();
3370 int32_t lhsOffset = lhsTy->getOffset();
3371 int32_t rhsOffset = rhsTy->getOffset();
3372
3373 auto outW = getWeightHandle<int8_t>(I->getDest());
3374 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3375 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3376 for (size_t i = 0, e = outW.size(); i < e; i++) {
3377 // s_d * (i_d - o_d) = s_l * (i_l - o_l) - s_r * (i_r - o_r)
3378 // => i_d = (s_l / s_d) * (i_l - o_l) - (s_r / s_d) * (i_r - o_r) + o_d
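      // For illustration: with s_l = s_r = 0.1, s_d = 0.2 and zero offsets,
      // i_l = 50 (5.0) and i_r = 20 (2.0) give l = 25, r = 10 and i_d = 15,
      // which dequantizes to 15 * 0.2 = 3.0 = 5.0 - 2.0.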
3379 float l = (lhsScale / destScale) * float(lhsW.raw(i) - lhsOffset);
3380 float r = (rhsScale / destScale) * float(rhsW.raw(i) - rhsOffset);
3381 int32_t q = std::round(l - r + destOffset);
3382 outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
3383 }
3384 return;
3385 }
3386
3387 dispatchArithmeticImpl(fwdElementSubInstArithmeticImpl,
3388 I->getDest()->getElementType(), I);
3389}
3390
3391template <typename ElemTy>
3392void BoundInterpreterFunction::fwdElementMulInstArithmeticImpl(
3393 const ElementMulInst *I) {
3394 staticAssertArithmeticType(ElemTy);
3395
3396 auto outW = getWeightHandle<ElemTy>(I->getDest());
3397 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3398 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3399 for (size_t i = 0, e = outW.size(); i < e; i++) {
3400 outW.raw(i) = lhsW.raw(i) * rhsW.raw(i);
3401 }
3402}
3403
3404void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) {
3405 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3406 auto lhsTy = I->getLHS()->getType();
3407 auto rhsTy = I->getRHS()->getType();
3408 auto destTy = I->getDest()->getType();
3409
3410 TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
3411 TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
3412 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
3413
3414 auto outW = getWeightHandle<int8_t>(I->getDest());
3415 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3416 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
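    // From s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r), the
    // requantization factor applied to the integer product below is
    // (s_l * s_r / s_d).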
3417 float scale = lhsQ.scale * rhsQ.scale / destQ.scale;
3418 for (size_t i = 0, e = outW.size(); i < e; i++) {
3419 int32_t mul = (lhsW.raw(i) - lhsQ.offset) * (rhsW.raw(i) - rhsQ.offset);
3420 outW.raw(i) = quantization::clip<int32_t, int8_t>(
3421 std::round(mul * scale) + destQ.offset);
3422 }
3423 return;
3424 }
3425
3426 dispatchArithmeticImpl(fwdElementMulInstArithmeticImpl,
3427 I->getDest()->getElementType(), I);
3428}
3429
3430void BoundInterpreterFunction::fwdElementFmodInst(const ElementFmodInst *I) {
3431 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3432 auto destTy = I->getDest()->getType();
3433 auto lhsTy = I->getLHS()->getType();
3434 auto rhsTy = I->getRHS()->getType();
3435
3436 float destScale = destTy->getScale();
3437 float lhsScale = lhsTy->getScale();
3438 float rhsScale = rhsTy->getScale();
3439
3440 int32_t destOffset = destTy->getOffset();
3441 int32_t lhsOffset = lhsTy->getOffset();
3442 int32_t rhsOffset = rhsTy->getOffset();
3443
3444 auto outW = getWeightHandle<int8_t>(I->getDest());
3445 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3446 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3447 for (size_t i = 0, e = outW.size(); i < e; i++) {
      // Dequantize both operands, compute fmod in the real domain, then
      // requantize the result: i_d = fmod(l, r) / s_d + o_d.
      float l = lhsScale * float(lhsW.raw(i) - lhsOffset);
      float r = rhsScale * float(rhsW.raw(i) - rhsOffset);
      int32_t q = std::round(std::fmod(l, r) / destScale + destOffset);
3451 outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
3452 }
3453 return;
3454 }
3455
3456#define FMOD_LOOP(TYPE_) \
3457 staticAssertArithmeticType(TYPE_); \
3458 auto outW = getWeightHandle<TYPE_>(I->getDest()); \
3459 auto lhsW = getWeightHandle<TYPE_>(I->getLHS()); \
3460 auto rhsW = getWeightHandle<TYPE_>(I->getRHS()); \
3461 for (size_t i = 0, e = outW.size(); i < e; i++) { \
3462 outW.raw(i) = std::fmod((float)lhsW.raw(i), (float)rhsW.raw(i)); \
3463 }
3464
3465 auto *T = getTensor(I->getDest());
3466 switch (T->getElementType()) {
3467 case ElemKind::Int64ITy: {
3468 FMOD_LOOP(int64_t);
3469 return;
3470 }
3471 case ElemKind::Int32ITy: {
3472 FMOD_LOOP(int32_t);
3473 return;
3474 }
3475 case ElemKind::FloatTy: {
3476 FMOD_LOOP(float);
3477 return;
3478 }
3479 case ElemKind::Float16Ty: {
3480 FMOD_LOOP(float16_t);
3481 return;
3482 }
3483 case ElemKind::BFloat16Ty: {
3484 FMOD_LOOP(bfloat16_t);
3485 return;
3486 }
3487
3488 default:
3489 llvm_unreachable("Unsupported type for Fmod.");
3490 }
3491}
3492
3493void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) {
3494 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3495 auto destTy = I->getDest()->getType();
3496 auto lhsTy = I->getLHS()->getType();
3497 auto rhsTy = I->getRHS()->getType();
3498
3499 float destScale = destTy->getScale();
3500 float lhsScale = lhsTy->getScale();
3501 float rhsScale = rhsTy->getScale();
3502
3503 int32_t destOffset = destTy->getOffset();
3504 int32_t lhsOffset = lhsTy->getOffset();
3505 int32_t rhsOffset = rhsTy->getOffset();
3506
3507 auto outW = getWeightHandle<int8_t>(I->getDest());
3508 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3509 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3510 for (size_t i = 0, e = outW.size(); i < e; i++) {
3511 // s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
3512 // => i_d = (s_l * (i_l - o_l)) / (s_d * s_r * (i_r - o_r)) + o_d
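      // For illustration: with s_l = s_r = 0.1, s_d = 0.05 and zero offsets,
      // i_l = 80 (8.0) and i_r = 20 (2.0) give l = 8.0, r = 0.1 and i_d = 80,
      // which dequantizes to 80 * 0.05 = 4.0 = 8.0 / 2.0.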
3513 float l = lhsScale * float(lhsW.raw(i) - lhsOffset);
3514 float r = rhsScale * destScale * float(rhsW.raw(i) - rhsOffset);
3515 int32_t q = std::round(l / r + destOffset);
3516 outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
3517 }
3518 return;
3519 }
3520
3521#define DIV_LOOP(TYPE_) \
3522 auto outW = getWeightHandle<TYPE_>(I->getDest()); \
3523 auto lhsW = getWeightHandle<TYPE_>(I->getLHS()); \
3524 auto rhsW = getWeightHandle<TYPE_>(I->getRHS()); \
3525 for (size_t i = 0, e = outW.size(); i < e; i++) { \
3526 outW.raw(i) = lhsW.raw(i) / rhsW.raw(i); \
3527 }
3528
3529 auto *T = getTensor(I->getDest());
3530 switch (T->getElementType()) {
3531 case ElemKind::Int64ITy: {
3532 DIV_LOOP(int64_t);
3533 return;
3534 }
3535 case ElemKind::Int32ITy: {
3536 DIV_LOOP(int32_t);
3537 return;
3538 }
3539 case ElemKind::FloatTy: {
3540 DIV_LOOP(float);
3541 return;
3542 }
3543 case ElemKind::Float16Ty: {
3544 DIV_LOOP(float16_t);
3545 return;
3546 }
3547 case ElemKind::BFloat16Ty: {
3548 DIV_LOOP(bfloat16_t);
3549 return;
3550 }
3551 default:
3552 llvm_unreachable("Unsupported type for Div.");
3553 }
3554}
3555
3556void BoundInterpreterFunction::fwdElementMaxInstI8Impl(
3557 const ElementMaxInst *I) {
3558 assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
3559 "Wrong function");
3560 auto lhsTy = I->getLHS()->getType();
3561 auto rhsTy = I->getRHS()->getType();
3562 auto destTy = I->getDest()->getType();
3563
3564 TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
3565 TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
3566 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
3567
3568 auto outW = getWeightHandle<int8_t>(I->getDest());
3569 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3570 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3571 for (size_t i = 0, e = outW.size(); i < e; i++) {
3572 // Convert both sides to the destination scale and perform a regular
3573 // comparison.
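    // Requantizing both operands to a common scale makes the integer
    // comparison below equivalent (up to rounding) to comparing the
    // dequantized values.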
3574 int8_t L = quantization::quantize(
3575 quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
3576 int8_t R = quantization::quantize(
3577 quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
3578 outW.raw(i) = std::max(L, R);
3579 }
3580}
3581
3582template <typename ElemTy>
3583void BoundInterpreterFunction::fwdElementMaxInstArithmeticImpl(
3584 const ElementMaxInst *I) {
3585 staticAssertArithmeticType(ElemTy);
3586
3587 auto outW = getWeightHandle<ElemTy>(I->getDest());
3588 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3589 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3590 for (size_t i = 0, e = outW.size(); i < e; i++) {
3591 outW.raw(i) = std::max(lhsW.raw(i), rhsW.raw(i));
3592 }
3593}
3594
3595void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) {
3596 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3597 fwdElementMaxInstI8Impl(I);
3598 return;
3599 }
3600
3601 dispatchArithmeticImpl(fwdElementMaxInstArithmeticImpl,
3602 I->getLHS()->getType()->getElementType(), I);
3603}
3604
3605template <typename ElemTy>
3606void BoundInterpreterFunction::fwdElementMinInstArithmeticImpl(
3607 const ElementMinInst *I) {
3608 staticAssertArithmeticType(ElemTy);
3609
3610 auto outW = getWeightHandle<ElemTy>(I->getDest());
3611 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3612 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3613 for (size_t i = 0, e = outW.size(); i < e; i++) {
3614 outW.raw(i) = std::min(lhsW.raw(i), rhsW.raw(i));
3615 }
3616}
3617
3618void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) {
3619 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3620 auto lhsTy = I->getLHS()->getType();
3621 auto rhsTy = I->getRHS()->getType();
3622 auto destTy = I->getDest()->getType();
3623
3624 TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
3625 TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
3626 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
3627
3628 auto outW = getWeightHandle<int8_t>(I->getDest());
3629 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3630 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3631 for (size_t i = 0, e = outW.size(); i < e; i++) {
3632 // Convert both sides to the destination scale and perform a regular
3633 // comparison.
3634 int8_t L = quantization::quantize(
3635 quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
3636 int8_t R = quantization::quantize(
3637 quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
3638 outW.raw(i) = std::min(L, R);
3639 }
3640 return;
3641 }
3642
3643 dispatchArithmeticImpl(fwdElementMinInstArithmeticImpl,
3644 I->getDest()->getElementType(), I);
3645}
3646
3647template <typename ElemTy>
3648void BoundInterpreterFunction::fwdElementBitwiseOrInstImpl(
3649 const ElementBitwiseOrInst *I) {
3650
3651 auto outW = getWeightHandle<ElemTy>(I->getDest());
3652 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3653 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3654 for (size_t i = 0, e = outW.size(); i < e; i++) {
3655 outW.raw(i) = lhsW.raw(i) | rhsW.raw(i);
3656 }
3657}
3658
3659void BoundInterpreterFunction::fwdElementBitwiseOrInst(
3660 const ElementBitwiseOrInst *I) {
3661
3662 dispatchBitwiseImpl(fwdElementBitwiseOrInstImpl,
3663 I->getDest()->getElementType(), I);
3664}
3665
3666template <typename ElemTy>
3667void BoundInterpreterFunction::fwdElementBitwiseAndInstImpl(
3668 const ElementBitwiseAndInst *I) {
3669
3670 auto outW = getWeightHandle<ElemTy>(I->getDest());
3671 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3672 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3673 for (size_t i = 0, e = outW.size(); i < e; i++) {
3674 outW.raw(i) = lhsW.raw(i) & rhsW.raw(i);
3675 }
3676}
3677
3678void BoundInterpreterFunction::fwdElementBitwiseAndInst(
3679 const ElementBitwiseAndInst *I) {
3680
3681 dispatchBitwiseImpl(fwdElementBitwiseAndInstImpl,
3682 I->getDest()->getElementType(), I);
3683}
3684
3685template <typename ElemTy>
3686void BoundInterpreterFunction::fwdElementBitwiseXorInstImpl(
3687 const ElementBitwiseXorInst *I) {
3688 staticAssertArithmeticType(ElemTy);
3689
3690 auto outW = getWeightHandle<ElemTy>(I->getDest());
3691 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3692 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3693 for (size_t i = 0, e = outW.size(); i < e; i++) {
3694 outW.raw(i) = lhsW.raw(i) ^ rhsW.raw(i);
3695 }
3696}
3697
3698void BoundInterpreterFunction::fwdElementBitwiseXorInst(
3699 const ElementBitwiseXorInst *I) {
3700
3701 dispatchBitwiseImpl(fwdElementBitwiseXorInstImpl,
3702 I->getDest()->getElementType(), I);
3703}
3704
3705//===----------------------------------------------------------------------===//
3706// Logical operations
3707//===----------------------------------------------------------------------===//
3708void BoundInterpreterFunction::fwdElementNotInst(const ElementNotInst *I) {
3709 auto inpW = getWeightHandle<bool>(I->getSrc());
3710 auto outW = getWeightHandle<bool>(I->getDest());
3711 for (size_t i = 0, e = outW.size(); i < e; ++i) {
3712 outW.raw(i) = (!inpW.raw(i));
3713 }
3714}
3715
3716void BoundInterpreterFunction::fwdElementAndInst(const ElementAndInst *I) {
3717 auto lhsW = getWeightHandle<bool>(I->getLHS());
3718 auto rhsW = getWeightHandle<bool>(I->getRHS());
3719 auto outW = getWeightHandle<bool>(I->getDest());
3720 for (size_t i = 0, e = outW.size(); i < e; ++i) {
3721 outW.raw(i) = (lhsW.raw(i) && rhsW.raw(i));
3722 }
3723}
3724
3725void BoundInterpreterFunction::fwdElementOrInst(const ElementOrInst *I) {
3726 auto lhsW = getWeightHandle<bool>(I->getLHS());
3727 auto rhsW = getWeightHandle<bool>(I->getRHS());
3728 auto outW = getWeightHandle<bool>(I->getDest());
3729 for (size_t i = 0, e = outW.size(); i < e; ++i) {
3730 outW.raw(i) = (lhsW.raw(i) || rhsW.raw(i));
3731 }
3732}
3733
3734void BoundInterpreterFunction::fwdElementXorInst(const ElementXorInst *I) {
3735 auto lhsW = getWeightHandle<bool>(I->getLHS());
3736 auto rhsW = getWeightHandle<bool>(I->getRHS());
3737 auto outW = getWeightHandle<bool>(I->getDest());
3738 for (size_t i = 0, e = outW.size(); i < e; ++i) {
3739 outW.raw(i) = (lhsW.raw(i) ^ rhsW.raw(i));
3740 }
3741}
3742
3743//===----------------------------------------------------------------------===//
3744// Unary arithmetic operations
3745//===----------------------------------------------------------------------===//
3746template <typename ElemTy, typename InstKind>
3747void BoundInterpreterFunction::fwdUnaryArithmeticImpl(
3748 const InstKind *I, std::function<float(float)> func) {
3749 Value *inpV = I->getSrc();
3750 Value *outV = I->getDest();
3751 auto inpTy = inpV->getType();
3752 auto outTy = outV->getType();
3753 auto inpH = getWeightHandle<ElemTy>(inpV);
3754 auto outH = getWeightHandle<ElemTy>(outV);
3755
3756 if (inpTy->isQuantizedType()) {
3757 float inpScale = inpTy->getScale();
3758 int32_t inpOffset = inpTy->getOffset();
3759 float outScale = outTy->getScale();
3760 int32_t outOffset = outTy->getOffset();
3761 for (size_t i = 0, e = outH.size(); i < e; ++i) {
3762 float inpVal =
3763 quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset});
3764 float outVal = func(inpVal);
3765 outH.raw(i) =
3766 quantization::quantize<ElemTy>(outVal, {outScale, outOffset});
3767 }
3768 } else {
3769 for (size_t i = 0, e = outH.size(); i < e; ++i) {
3770 float inpVal = static_cast<float>(inpH.raw(i));
3771 float outVal = func(inpVal);
3772 outH.raw(i) = static_cast<ElemTy>(outVal);
3773 }
3774 }
3775}
3776
3777void BoundInterpreterFunction::fwdElementBitwiseNotInst(
3778 const ElementBitwiseNotInst *I) {
3779 auto func = [](int64_t i) -> int64_t { return ~i; };
3780 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3781}
3782
3783void BoundInterpreterFunction::fwdElementAbsInst(const ElementAbsInst *I) {
3784 auto func = [](float x) -> float { return std::abs(x); };
3785 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3786}
3787
3788void BoundInterpreterFunction::fwdElementNegInst(const ElementNegInst *I) {
3789 auto func = [](float x) -> float { return -x; };
3790 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3791}
3792
3793void BoundInterpreterFunction::fwdElementFloorInst(const ElementFloorInst *I) {
3794 auto func = [](float x) -> float { return std::floor(x); };
3795 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3796}
3797
3798void BoundInterpreterFunction::fwdElementSignInst(const ElementSignInst *I) {
3799 auto func = [](float x) -> float { return ((x > 0) - (x < 0)); };
3800 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3801}
3802
3803void BoundInterpreterFunction::fwdElementCeilInst(const ElementCeilInst *I) {
3804 auto func = [](float x) -> float { return std::ceil(x); };
3805 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3806}
3807
3808void BoundInterpreterFunction::fwdElementTruncateInst(
3809 const ElementTruncateInst *I) {
3810 auto func = [](float x) -> float { return std::trunc(x); };
3811 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3812}
3813
3814void BoundInterpreterFunction::fwdElementRoundInst(const ElementRoundInst *I) {
  // ONNX, NumPy, and TensorFlow require round-half-to-even: values whose
  // fractional part is exactly 0.5 are rounded to the nearest even integer.
  // std::nearbyintf follows the current rounding mode, which defaults to
  // round-to-nearest-even.
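  // For illustration: 0.5 rounds to 0, 1.5 rounds to 2, and 2.5 rounds to 2.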
3817 auto func = [](float x) -> float { return std::nearbyintf(x); };
3818 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3819}
3820
3821void BoundInterpreterFunction::fwdElementSqrtInst(const ElementSqrtInst *I) {
3822 auto func = [](float x) -> float { return std::sqrt(x); };
3823 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3824}
3825
3826void BoundInterpreterFunction::fwdElementRsqrtInst(const ElementRsqrtInst *I) {
3827 auto func = [](float x) -> float { return 1 / std::sqrt(x); };
3828 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3829}
3830
3831void BoundInterpreterFunction::fwdElementReciprocalInst(
3832 const ElementReciprocalInst *I) {
3833 auto func = [](float x) -> float { return 1 / x; };
3834 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3835}
3836
3837void BoundInterpreterFunction::fwdElementSinInst(const ElementSinInst *I) {
3838 auto func = [](float x) -> float { return std::sin(x); };
3839 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3840}
3841
3842void BoundInterpreterFunction::fwdElementCosInst(const ElementCosInst *I) {
3843 auto func = [](float x) -> float { return std::cos(x); };
3844 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3845}
3846
3847void BoundInterpreterFunction::fwdElementErfInst(const ElementErfInst *I) {
3848 auto func = [](float x) -> float { return std::erf(x); };
3849 dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
3850}
3851
3852//===----------------------------------------------------------------------===//
3853// Compare operations
3854//===----------------------------------------------------------------------===//
3855template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3856 typename CmpTy, typename InstCmpKind>
3857void BoundInterpreterFunction::fwdElementCmpHelperImpl(
3858 const InstCmpKind *I, std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper) {
3859 Value *lhsV = I->getLHS();
3860 Value *rhsV = I->getRHS();
3861 Value *outV = I->getDest();
3862
3863 auto lhsH = getWeightHandle<ElemTy>(lhsV);
3864 auto rhsH = getWeightHandle<ElemTy>(rhsV);
3865 auto oH = getWeightHandle<bool>(outV);
3866
3867 ElemScaleTy lhsScale = 1.0f;
3868 ElemScaleTy rhsScale = 1.0f;
3869 ElemOffsetTy lhsOffset = 0;
3870 ElemOffsetTy rhsOffset = 0;
3871
3872 auto lhsTy = lhsV->getType();
3873 auto rhsTy = rhsV->getType();
3874
3875 if (lhsV->getType()->isQuantizedType()) {
3876 lhsScale = lhsTy->getScale();
3877 rhsScale = rhsTy->getScale();
3878
3879 lhsOffset = lhsTy->getOffset();
3880 rhsOffset = rhsTy->getOffset();
3881 }
3882
  // Compare each pair of elements after mapping them to the real domain:
3884 for (size_t i = 0, e = oH.size(); i < e; i++) {
3885 oH.raw(i) = cmpHelper(lhsScale * (lhsH.raw(i) - lhsOffset),
3886 rhsScale * (rhsH.raw(i) - rhsOffset));
3887 }
3888}
3889
3890template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3891 typename CmpTy>
3892void BoundInterpreterFunction::fwdElementCmpLTEInstImpl(
3893 const ElementCmpLTEInst *I) {
3894 auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS <= RHS; };
3895 fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3896 ElementCmpLTEInst>(I, cmpHelper);
3897}
3898
3899void BoundInterpreterFunction::fwdElementCmpLTEInst(
3900 const ElementCmpLTEInst *I) {
3901 auto *T = getTensor(I->getLHS());
3902
3903 if (T->getType().isQuantizedType()) {
3904 switch (T->getElementType()) {
3905 case ElemKind::Int8QTy:
3906 fwdElementCmpLTEInstImpl<int8_t, int32_t, float, int32_t>(I);
3907 break;
3908 case ElemKind::Int16QTy:
3909 fwdElementCmpLTEInstImpl<int16_t, int32_t, float, int32_t>(I);
3910 break;
3911 default:
3912 llvm_unreachable("Type is not supported");
3913 }
3914 return;
3915 }
3916
3917 switch (T->getElementType()) {
3918 case ElemKind::FloatTy:
3919 fwdElementCmpLTEInstImpl<float, float, float>(I);
3920 break;
3921 case ElemKind::Float16Ty:
3922 fwdElementCmpLTEInstImpl<float16_t, float16_t, float16_t>(I);
3923 break;
3924 case ElemKind::BFloat16Ty:
3925 fwdElementCmpLTEInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
3926 break;
3927 case ElemKind::Int32ITy:
3928 fwdElementCmpLTEInstImpl<int32_t, int32_t, float>(I);
3929 break;
3930 case ElemKind::Int64ITy:
3931 fwdElementCmpLTEInstImpl<int64_t, int64_t, float>(I);
3932 break;
3933 default:
3934 llvm_unreachable("Type is not supported");
3935 }
3936}
3937
3938template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3939 typename CmpTy>
3940void BoundInterpreterFunction::fwdElementCmpEQInstImpl(
3941 const ElementCmpEQInst *I) {
3942 auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS == RHS; };
3943 fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3944 ElementCmpEQInst>(I, cmpHelper);
3945}
3946
3947void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) {
3948 auto *T = getTensor(I->getLHS());
3949
3950 if (T->getType().isQuantizedType()) {
3951 switch (T->getElementType()) {
3952 case ElemKind::Int8QTy:
3953 fwdElementCmpEQInstImpl<int8_t, int32_t, float, int32_t>(I);
3954 break;
3955 case ElemKind::Int16QTy:
3956 fwdElementCmpEQInstImpl<int16_t, int32_t, float, int32_t>(I);
3957 break;
3958 default:
3959 llvm_unreachable("Type is not supported");
3960 }
3961 return;
3962 }
3963
3964 switch (T->getElementType()) {
3965 case ElemKind::FloatTy:
3966 fwdElementCmpEQInstImpl<float, float, float>(I);
3967 break;
3968 case ElemKind::Float16Ty:
3969 fwdElementCmpEQInstImpl<float16_t, float16_t, float16_t>(I);
3970 break;
3971 case ElemKind::BFloat16Ty:
3972 fwdElementCmpEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
3973 break;
3974 case ElemKind::Int32ITy:
3975 fwdElementCmpEQInstImpl<int32_t, int32_t, float>(I);
3976 break;
3977 case ElemKind::Int64ITy:
3978 fwdElementCmpEQInstImpl<int64_t, int64_t, float>(I);
3979 break;
3980 default:
3981 llvm_unreachable("Type is not supported");
3982 }
3983}
3984
3985template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3986 typename CmpTy>
3987void BoundInterpreterFunction::fwdElementCmpNEQInstImpl(
3988 const ElementCmpNEQInst *I) {
3989 auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return !(LHS == RHS); };
3990 fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3991 ElementCmpNEQInst>(I, cmpHelper);
3992}
3993
3994void BoundInterpreterFunction::fwdElementCmpNEQInst(
3995 const ElementCmpNEQInst *I) {
3996 auto *T = getTensor(I->getLHS());
3997
3998 if (T->getType().isQuantizedType()) {
3999 switch (T->getElementType()) {
4000 case ElemKind::Int8QTy:
4001 fwdElementCmpNEQInstImpl<int8_t, int32_t, float, int32_t>(I);
4002 break;
4003 case ElemKind::Int16QTy:
4004 fwdElementCmpNEQInstImpl<int16_t, int32_t, float, int32_t>(I);
4005 break;
4006 default:
4007 llvm_unreachable("Type is not supported");
4008 }
4009 return;
4010 }
4011
4012 switch (T->getElementType()) {
4013 case ElemKind::FloatTy:
4014 fwdElementCmpNEQInstImpl<float, float, float>(I);
4015 break;
4016 case ElemKind::Float16Ty:
4017 fwdElementCmpNEQInstImpl<float16_t, float16_t, float16_t>(I);
4018 break;
4019 case ElemKind::BFloat16Ty:
4020 fwdElementCmpNEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
4021 break;
4022 case ElemKind::Int32ITy:
4023 fwdElementCmpNEQInstImpl<int32_t, int32_t, float>(I);
4024 break;
4025 case ElemKind::Int64ITy:
4026 fwdElementCmpNEQInstImpl<int64_t, int64_t, float>(I);
4027 break;
4028 default:
4029 llvm_unreachable("Type is not supported");
4030 }
4031}
4032
4033template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
4034 typename CmpTy>
4035void BoundInterpreterFunction::fwdElementCmpLTInstImpl(
4036 const ElementCmpLTInst *I) {
4037 auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS < RHS; };
4038 fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
4039 ElementCmpLTInst>(I, cmpHelper);
4040}
4041
4042void BoundInterpreterFunction::fwdElementCmpLTInst(ElementCmpLTInst const *I) {
4043 auto *T = getTensor(I->getLHS());
4044 if (T->getType().isQuantizedType()) {
4045 switch (T->getElementType()) {
4046 case ElemKind::Int8QTy:
4047 fwdElementCmpLTInstImpl<int8_t, int32_t, float, int32_t>(I);
4048 break;
4049 case ElemKind::Int16QTy:
4050 fwdElementCmpLTInstImpl<int16_t, int32_t, float, int32_t>(I);
4051 break;
4052 default:
4053 llvm_unreachable("Type is not supported");
4054 }
4055 return;
4056 }
4057
4058 switch (T->getElementType()) {
4059 case ElemKind::FloatTy:
4060 fwdElementCmpLTInstImpl<float, float, float>(I);
4061 break;
4062 case ElemKind::Float16Ty:
4063 fwdElementCmpLTInstImpl<float16_t, float16_t, float16_t>(I);
4064 break;
4065 case ElemKind::BFloat16Ty:
4066 fwdElementCmpLTInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
4067 break;
4068 case ElemKind::Int32ITy:
4069 fwdElementCmpLTInstImpl<int32_t, int32_t, float>(I);
4070 break;
4071 case ElemKind::Int64ITy:
4072 fwdElementCmpLTInstImpl<int64_t, int64_t, float>(I);
4073 break;
4074 default:
4075 llvm_unreachable("Type is not supported");
4076 }
4077}
4078
4079template <typename ElemTy>
4080void BoundInterpreterFunction::fwdElementPowInstFloatImpl(
4081 const ElementPowInst *I) {
4082 staticAssertFloatingPointType(ElemTy);
4083
4084 auto baseW = getWeightHandle<ElemTy>(I->getLHS());
4085 auto expW = getWeightHandle<ElemTy>(I->getRHS());
4086 auto outW = getWeightHandle<ElemTy>(I->getDest());
4087 for (size_t i = 0, e = outW.size(); i < e; i++) {
4088 outW.raw(i) = ElemTy(pow(float(baseW.raw(i)), float(expW.raw(i))));
4089 }
4090}
4091
4092void BoundInterpreterFunction::fwdElementPowInstI8Impl(
4093 const ElementPowInst *I) {
4094 assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
4095 "Expect quantized type");
4096 auto baseTy = I->getLHS()->getType();
4097 auto expTy = I->getRHS()->getType();
4098 auto destTy = I->getDest()->getType();
4099
4100 float baseScale = baseTy->getScale();
4101 int32_t baseOffset = baseTy->getOffset();
4102 TensorQuantizationParams baseTQP{baseScale, baseOffset};
4103
4104 float expScale = expTy->getScale();
4105 int32_t expOffset = expTy->getOffset();
4106 TensorQuantizationParams expTQP{expScale, expOffset};
4107
4108 float destScale = destTy->getScale();
4109 int32_t destOffset = destTy->getOffset();
4110 TensorQuantizationParams destTQP{destScale, destOffset};
4111
4112 auto outW = getWeightHandle<int8_t>(I->getDest());
4113 auto baseW = getWeightHandle<int8_t>(I->getLHS());
4114 auto expW = getWeightHandle<int8_t>(I->getRHS());
4115
4116 for (dim_t i = 0, e = outW.size(); i < e; i++) {
4117 float base = quantization::dequantize(baseW.raw(i), baseTQP);
4118 float exp = quantization::dequantize(expW.raw(i), expTQP);
4119 outW.raw(i) = quantization::quantize(std::pow(base, exp), destTQP);
4120 }
4121}
4122
4123void BoundInterpreterFunction::fwdElementPowInst(
4124 const glow::ElementPowInst *I) {
4125 auto *T = getTensor(I->getLHS());
4126 if (T->getType().isQuantizedType()) {
4127 fwdElementPowInstI8Impl(I);
4128 return;
4129 }
4130
4131 dispatchFloatingPointImpl(fwdElementPowInstFloatImpl,
4132 I->getLHS()->getElementType(), I);
4133}
4134
4135template <typename ElemTy>
4136void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl(
4137 const ElementIsNaNInst *I) {
4138 staticAssertFloatingPointType(ElemTy);
4139
4140 auto inW = getWeightHandle<ElemTy>(I->getSrc());
4141 auto outW = getWeightHandle<bool>(I->getDest());
4142 for (size_t i = 0, e = inW.size(); i < e; i++) {
4143 float val = inW.raw(i);
4144 outW.raw(i) = std::isnan(val);
4145 }
4146}
4147
4148void BoundInterpreterFunction::fwdElementIsNaNInst(
4149 const glow::ElementIsNaNInst *I) {
4150 dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl,
4151 I->getSrc()->getElementType(), I);
4152}
4153
4154template <typename ElemTy>
4155void BoundInterpreterFunction::fwdElementLogInstFloatImpl(
4156 const ElementLogInst *I) {
4157 staticAssertFloatingPointType(ElemTy);
4158
4159 auto inW = getWeightHandle<ElemTy>(I->getSrc());
4160 auto outW = getWeightHandle<ElemTy>(I->getDest());
4161 for (size_t i = 0, e = inW.size(); i < e; i++) {
4162 float val = inW.raw(i);
4163 outW.raw(i) = ElemTy(log(val));
4164 }
4165}
4166
4167void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) {
4168 dispatchFloatingPointImpl(fwdElementLogInstFloatImpl,
4169 I->getSrc()->getElementType(), I);
4170}
4171
4172template <typename ElemTy>
4173void BoundInterpreterFunction::fwdElementExpInstFloatImpl(
4174 const ElementExpInst *I) {
4175 staticAssertFloatingPointType(ElemTy);
4176
4177 auto inW = getWeightHandle<ElemTy>(I->getSrc());
4178 auto outW = getWeightHandle<ElemTy>(I->getDest());
4179 for (size_t i = 0, e = inW.size(); i < e; i++) {
4180 float val = inW.raw(i);
4181 outW.raw(i) = ElemTy(exp(val));
4182 }
4183}
4184
4185void BoundInterpreterFunction::fwdElementExpInst(const ElementExpInst *I) {
4186 dispatchFloatingPointImpl(fwdElementExpInstFloatImpl,
4187 I->getSrc()->getElementType(), I);
4188}
4189
4190void BoundInterpreterFunction::fwdNonZeroInst(const NonZeroInst *I) {
4191 auto *T = getTensor(I->getDest());
4192 T->zero();
4193 auto outW = T->getHandle<int32_t>();
4194 auto condW = getWeightHandle<bool>(I->getCond());
4195 for (size_t condIdx = 0, outIdx = 0, n = condW.size(); condIdx < n;
4196 condIdx++) {
4197 if (condW.raw(condIdx)) {
4198 outW.raw(outIdx) = condIdx;
4199 outIdx++;
4200 }
4201 }
4202}
4203
4204template <typename ElemTy>
4205void BoundInterpreterFunction::fwdElementSelectInstFloatImpl(
4206 const glow::ElementSelectInst *I) {
4207 staticAssertFloatingPointType(ElemTy);
4208 auto outW = getWeightHandle<ElemTy>(I->getDest());
4209 auto condW = getWeightHandle<bool>(I->getCond());
4210 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
4211 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
4212 for (size_t i = 0, e = outW.size(); i < e; i++) {
4213 outW.raw(i) = condW.raw(i) ? lhsW.raw(i) : rhsW.raw(i);
4214 }
4215}
4216
4217void BoundInterpreterFunction::fwdElementSelectInst(
4218 const glow::ElementSelectInst *I) {
4219 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
4220 auto destTy = I->getDest()->getType();
4221 auto lhsTy = I->getLHS()->getType();
4222 auto rhsTy = I->getRHS()->getType();
4223
4224 float destScale = destTy->getScale();
4225 float lhsScale = lhsTy->getScale();
4226 float rhsScale = rhsTy->getScale();
4227
4228 int32_t destOffset = destTy->getOffset();
4229 int32_t lhsOffset = lhsTy->getOffset();
4230 int32_t rhsOffset = rhsTy->getOffset();
4231
4232 auto outW = getWeightHandle<int8_t>(I->getDest());
4233 auto condW = getWeightHandle<bool>(I->getCond());
4234 auto lhsW = getWeightHandle<int8_t>(I->getLHS());
4235 auto rhsW = getWeightHandle<int8_t>(I->getRHS());
4236 for (size_t i = 0, e = outW.size(); i < e; i++) {
4237 float val = condW.raw(i) ? lhsScale * (lhsW.raw(i) - lhsOffset)
4238 : rhsScale * (rhsW.raw(i) - rhsOffset);
4239 int32_t q = std::round(val / destScale + destOffset);
4240 outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
4241 }
4242 return;
4243 }
4244
4245 dispatchFloatingPointImpl(fwdElementSelectInstFloatImpl,
4246 I->getLHS()->getElementType(), I);
4247}
4248
4249template <typename ElemTy>
4250void BoundInterpreterFunction::fwdModuloInstImpl(glow::ModuloInst const *I) {
4251 auto srcH = getTensor(I->getSrc())->getHandle<ElemTy>();
4252 auto destH = getTensor(I->getDest())->getHandle<ElemTy>();
4253
4254 auto divisor = I->getDivisor();
4255 auto signFollowDivisor = I->getSignFollowDivisor();
4256
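  // For illustration: in C++, -3 % 4 is -3; with signFollowDivisor the result
  // is shifted to -3 + 4 = 1, matching Python-style modulo semantics.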
4257 for (size_t i = 0, e = srcH.size(); i < e; i++) {
4258 auto res = srcH.raw(i) % divisor;
4259 if (signFollowDivisor && res < 0) {
4260 res += divisor;
4261 }
4262 destH.raw(i) = res;
4263 }
4264}
4265
4266void BoundInterpreterFunction::fwdModuloInst(glow::ModuloInst const *I) {
4267 dispatchIndexTypeImpl(fwdModuloInstImpl, I->getSrc()->getElementType(), I);
4268}
4269
//===----------------------------------------------------------------------===//
// Trigonometric Operators
//===----------------------------------------------------------------------===//
4271template <typename ElemTy, typename InstKind>
4272void BoundInterpreterFunction::fwdUnaryTrigonometricImpl(
4273 const InstKind *I, std::function<float(float)> func) {
4274 Value *inpV = I->getSrc();
4275 Value *outV = I->getDest();
4276 auto inpTy = inpV->getType();
4277 auto outTy = outV->getType();
4278 auto inpH = getWeightHandle<ElemTy>(inpV);
4279 auto outH = getWeightHandle<ElemTy>(outV);
4280
4281 if (inpTy->isQuantizedType()) {
4282 float inpScale = inpTy->getScale();
4283 int32_t inpOffset = inpTy->getOffset();
4284 float outScale = outTy->getScale();
4285 int32_t outOffset = outTy->getOffset();
4286 for (size_t i = 0, e = outH.size(); i < e; ++i) {
4287 float inpVal =
4288 quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset});
4289 float outVal = func(inpVal);
4290 outH.raw(i) =
4291 quantization::quantize<ElemTy>(outVal, {outScale, outOffset});
4292 }
4293 } else {
4294 for (size_t i = 0, e = outH.size(); i < e; ++i) {
4295 float inpVal = static_cast<float>(inpH.raw(i));
4296 float outVal = func(inpVal);
4297 outH.raw(i) = static_cast<ElemTy>(outVal);
4298 }
4299 }
4300}
4301
4302void BoundInterpreterFunction::fwdElementAcosInst(const ElementAcosInst *I) {
4303 auto func = [](float x) -> float { return std::acos(x); };
4304 dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I,
4305 func);
4306}
4307
4308void BoundInterpreterFunction::fwdElementAsinInst(const ElementAsinInst *I) {
4309 auto func = [](float x) -> float { return std::asin(x); };
4310 dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I,
4311 func);
4312}
4313
4314void BoundInterpreterFunction::fwdElementAtanInst(const ElementAtanInst *I) {
4315 auto func = [](float x) -> float { return std::atan(x); };
4316 dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I,
4317 func);
4318}
4319//===----------------------------------------------------------------------===//
4320// Mat Mul
4321//===----------------------------------------------------------------------===//
4322template <typename ElemTy, typename AccumulatorTy>
4323void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl(
4324 const glow::MatMulInst *I) {
4325 assert(getTensor(I->getLHS())->getType().isQuantizedType());
4326 auto lhs = getWeightHandle<ElemTy>(I->getLHS());
4327 auto rhs = getWeightHandle<ElemTy>(I->getRHS());
4328
4329 auto dest = getWeightHandle<ElemTy>(I->getDest());
4330
4331 auto destDim = dest.dims();
4332 auto lhsDim = lhs.dims();
4333
4334 auto destTy = I->getDest()->getType();
4335 auto lhsTy = I->getLHS()->getType();
4336 auto rhsTy = I->getRHS()->getType();
4337
4338 dest.clear(0);
4339
  // For matrix multiplication the accumulated integer products are rescaled
  // to the destination domain by (L.scale * R.scale / D.scale). The offsets
  // are handled explicitly below by subtracting them from each operand.
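  // For illustration: with L.scale = 0.1, R.scale = 0.2 and D.scale = 0.05
  // the factor is 0.4, so an integer accumulator of 10 (real 0.2) becomes
  // round(0.4 * 10) = 4, which dequantizes back to 4 * 0.05 = 0.2.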
4343 float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
4344 int32_t lhsOffset = lhsTy->getOffset();
4345 int32_t rhsOffset = rhsTy->getOffset();
4346 int32_t destOffset = destTy->getOffset();
4347
4348 // For each (x,y) in the destination matrix:
4349 for (dim_t x = 0; x < destDim[0]; x++) {
4350 for (dim_t y = 0; y < destDim[1]; y++) {
4351
      // Perform a dot product of the row and column.
4353 AccumulatorTy sum = 0;
4354 for (dim_t i = 0; i < lhsDim[1]; i++) {
4355 AccumulatorTy L = lhs.at({x, i});
4356 AccumulatorTy R = rhs.at({i, y});
        // Subtract the offsets so the products are computed on the
        // zero-centered integer values.
4359 sum += (L - lhsOffset) * (R - rhsOffset);
4360 }
4361
4362 dest.at({x, y}) = quantization::clip<AccumulatorTy, ElemTy>(
4363 std::round(scale * sum + destOffset));
4364 }
4365 }
4366}
4367
4368template <typename ElemTy>
4369void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) {
4370 staticAssertFloatingPointType(ElemTy);
4371
4372 auto lhs = getWeightHandle<ElemTy>(I->getLHS());
4373 auto rhs = getWeightHandle<ElemTy>(I->getRHS());
4374 auto dest = getWeightHandle<ElemTy>(I->getDest());
4375
4376 auto destDim = dest.dims();
4377 auto lhsDim = lhs.dims();
4378
4379 dest.clear(0);
4380
4381 // For each (x,y) in the destination matrix:
4382 for (dim_t x = 0; x < destDim[0]; x++) {
4383 for (dim_t y = 0; y < destDim[1]; y++) {
4384
      // Perform a dot product of the row and column.
4386 float sum = 0;
4387 for (dim_t i = 0; i < lhsDim[1]; i++) {
4388 sum += float(lhs.at({x, i}) * rhs.at({i, y}));
4389 }
4390 dest.at({x, y}) = ElemTy(sum);
4391 }
4392 }
4393}
4394
4395template <typename ElemTy>
4396void BoundInterpreterFunction::fwdBatchMatMulInstFloatImpl(
4397 const BatchMatMulInst *I) {
4398 staticAssertFloatingPointType(ElemTy);
4399
4400 auto lhs = getWeightHandle<ElemTy>(I->getLHS());
4401 auto rhs = getWeightHandle<ElemTy>(I->getRHS());
4402 auto dest = getWeightHandle<ElemTy>(I->getDest());
4403
4404 auto destDim = dest.dims();
4405 auto lhsDim = lhs.dims();
4406
4407 dest.clear(0);
4408
4409 for (dim_t batch = 0; batch < destDim[0]; batch++) {
4410 // For each (x,y) in the destination matrix:
4411 for (dim_t x = 0; x < destDim[1]; x++) {
4412 for (dim_t y = 0; y < destDim[2]; y++) {
        // Perform a dot product of the row and column.
4414 float sum = 0;
4415 for (dim_t i = 0; i < lhsDim[2]; i++) {
4416 sum += float(lhs.at({batch, x, i})) * float(rhs.at({batch, i, y}));
4417 }
4418 dest.at({batch, x, y}) = ElemTy(sum);
4419 }
4420 }
4421 }
4422}
4423
4424void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) {
4425 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
4426 dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl,
4427 I->getLHS()->getElementType(), I);
4428 return;
4429 }
4430
4431 dispatchFloatingPointImpl(fwdMatMulInstFloatImpl,
4432 I->getLHS()->getElementType(), I);
4433}
4434
4435void BoundInterpreterFunction::fwdBatchMatMulInst(
4436 const glow::BatchMatMulInst *I) {
4437 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
4438 DCHECK(!"Quantized implementation for BatchMatmul not supported yet.");
4439 return;
4440 }
4441
4442 dispatchFloatingPointImpl(fwdBatchMatMulInstFloatImpl,
4443 I->getLHS()->getElementType(), I);
4444}
4445
4446void BoundInterpreterFunction::fwdReluGradInst(const glow::ReluGradInst *I) {
4447 DCHECK(!"Found ReluGradInst but ReluGrad is lowered on Interpreter");
4448}
4449
4450//===----------------------------------------------------------------------===//
4451// FC
4452//===----------------------------------------------------------------------===//
4453template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
4454void BoundInterpreterFunction::fwdFullyConnectedInstQuantizedImpl(
4455 const glow::FullyConnectedInst *I) {
4456 assert(getTensor(I->getSrc())->getType().isQuantizedType());
4457
4458 auto inW = getWeightHandle<ElemTy>(I->getSrc());
4459 auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
4460 auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
4461 auto outW = getWeightHandle<ElemTy>(I->getDest());
4462
4463 auto inTy = inW.getType();
4464 auto weightsTy = weightsW.getType();
4465 auto biasTy = biasW.getType();
4466 auto outTy = outW.getType();
4467
4468 int32_t inOffset = inTy.getOffset();
4469 int32_t weightsOffset = weightsTy.getOffset();
4470 int32_t biasOffset = biasTy.getOffset();
4471 int32_t outOffset = outTy.getOffset();
4472
4473 float outScale = outTy.getScale();
4474 float weightsScale = weightsTy.getScale();
4475 float biasScale = biasTy.getScale();
4476 float inScale = inTy.getScale();
4477
4478 ShapeHW idim(inW.dims());
4479 ShapeHW odim(outW.dims());
4480
4481 // Calculate the scale of the values that come out of the matrix
4482 // multiplication part of the calculation.
4483 float matMulScale = weightsScale * inScale;
4484
4485 outW.clear(0);
4486
4487 for (dim_t i = 0; i < idim.height; i++) {
4488 for (dim_t j = 0; j < odim.width; j++) {
4489 AccumulatorTy sum = 0;
4490 for (dim_t k = 0; k < idim.width; k++) {
4491 AccumulatorTy W = weightsW.at({k, j});
4492 AccumulatorTy A = inW.at({i, k});
4493 sum += (W - weightsOffset) * (A - inOffset);
4494 }
4495
4496 // Scale the bias to match the scale of the matrix multiplication.
4497 AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
4498 (biasScale / matMulScale));
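      // For illustration: with biasScale = 0.02 and matMulScale = 0.01, a
      // quantized bias of 5 (real 0.1) becomes B = 10 in the accumulator
      // domain, since 10 * 0.01 == 0.1.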
4499
4500 // Add the bias.
4501 sum += B;
4502
4503 // Scale the result back to the expected destination scale.
4504 outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
4505 std::round(float(sum) * (matMulScale / outScale)) + outOffset);
4506 }
4507 }
4508}
4509
4510template <typename ElemTy>
4511void BoundInterpreterFunction::fwdFullyConnectedInstFloatImpl(
4512 const FullyConnectedInst *I) {
4513 staticAssertFloatingPointType(ElemTy);
4514
4515 auto inW = getWeightHandle<ElemTy>(I->getSrc());
4516 auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
4517 auto biasW = getWeightHandle<ElemTy>(I->getBias());
4518 auto outW = getWeightHandle<ElemTy>(I->getDest());
4519
4520 ShapeHW idim(inW.dims());
4521 ShapeHW odim(outW.dims());
4522
4523 outW.clear(0);
4524
4525 for (dim_t i = 0; i < idim.height; i++) {
4526 for (dim_t j = 0; j < odim.width; j++) {
4527 float sum = 0;
4528 for (dim_t k = 0; k < idim.width; k++) {
4529 sum += float(inW.at({i, k})) * float(weightsW.at({k, j}));
4530 }
4531
4532 outW.at({i, j}) = sum + float(biasW.at({j}));
4533 }
4534 }
4535}
4536
4537void BoundInterpreterFunction::fwdFullyConnectedInst(
4538 const glow::FullyConnectedInst *I) {
4539
4540 if (getTensor(I->getSrc())->getType().isQuantizedType()) {
4541 dispatchQuantizedWithAccumulationAndBiasImpl(
4542 fwdFullyConnectedInstQuantizedImpl, I->getSrc()->getElementType(),
4543 I->getBias()->getElementType(), I);
4544 return;
4545 } else {
4546 dispatchFloatingPointImpl(fwdFullyConnectedInstFloatImpl,
4547 I->getSrc()->getElementType(), I);
4548 }
4549}
4550
4551//===----------------------------------------------------------------------===//
4552// Dynamic quantized FC
4553//===----------------------------------------------------------------------===//
4554
4555template <typename ElemTy, typename OutputTy, typename AccumulatorTy>
4556void BoundInterpreterFunction::fwdDynRowwiseQuantizedFullyConnectedInstImpl(
4557 Handle<ElemTy> inW, Handle<OutputTy> &outW, dim_t baseRow,
4558 Handle<ElemTy> weightsW, Handle<float> biasW, Handle<float> scalesW,
4559 Handle<int32_t> offsetsW) {
4560 ShapeHW idim(inW.dims());
4561 ShapeHW odim(outW.dims());
4562 auto inTy = inW.getType();
4563 int32_t inOffset = inTy.getOffset();
4564 float inScale = inTy.getScale();
4565
4566 for (dim_t i = 0; i < idim.height; i++) {
4567 for (dim_t j = 0; j < odim.width; j++) {
4568 float matMulScale = inScale * static_cast<float>(scalesW.raw(j));
4569 AccumulatorTy sum = 0;
4570 for (dim_t k = 0; k < idim.width; k++) {
4571 AccumulatorTy W = weightsW.at({k, j});
4572 AccumulatorTy A = inW.at({i, k});
4573 sum += (W - offsetsW.raw(j)) * (A - inOffset);
4574 }
4575
4576 float B = float(biasW.at({j}));
4577
      // Dequantize the accumulated sum back to floating point and add the
      // bias.
4580 outW.at({baseRow + i, j}) = float(sum) * matMulScale + B;
4581 }
4582 }
4583}
4584
4585void BoundInterpreterFunction::fwdDynRowwiseQuantizedFullyConnectedInstPreimpl(
4586 Tensor *inputTensor, Tensor *weightsTensor, Tensor *biasTensor,
4587 Tensor *resultTensor, Tensor *wScaleTensor, Tensor *wOffsetTensor,
4588 bool isSymmetric, bool isPerBatchElement) {
4589
4590 /* Check the options */
4591 assert(isSymmetric && "Only symmetric quantization is supported.");
4592 assert(isPerBatchElement && "Only quantized per batch element is supported.");
4593
4594 auto weightsW = weightsTensor->getHandle<int8_t>();
4595
4596 /* Dynamic Quantization */
4597 auto resultHandle = resultTensor->getHandle<float16_t>();
4598 auto offsetsW = wOffsetTensor->getHandle<int32_t>();
4599 dim_t N = inputTensor->dims()[0];
4600 dim_t L = inputTensor->dims()[1];
4601 if (isPerBatchElement && isSymmetric) {
    // Slice the N x L input tensor into N tensors of shape 1 x L. For each
    // batch element compute quantization parameters, quantize, run the FC,
    // and dequantize independently, then stitch the results back together.
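    // For illustration: a row spanning roughly [-2.0, 6.0] gets a symmetric
    // quantization range of about [-6.0, 6.0] with offset 0, so each batch
    // element carries its own scale.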
4605 for (dim_t i = 0; i < N; i++) {
4606 Tensor slicedInputTensor = inputTensor->getOwnedSlice({1, L}, {i, 0});
4607 auto slicedInputHandle = slicedInputTensor.getHandle<float16_t>();
4608 auto minMax = slicedInputHandle.minMaxArg();
4609 auto qMin = slicedInputHandle.raw(minMax.first);
4610 auto qMax = slicedInputHandle.raw(minMax.second);
4611
      // TODO: Currently we only support symmetric quantization; we should
      // support both symmetric and asymmetric quantization based on
      // isSymmetric.
4614 auto qParams = quantization::chooseQuantizationParams(
4615 {qMin, qMax}, quantization::Schema::Symmetric, ElemKind::Int8QTy);
4616 Tensor qInputTensor = quantization::quantizeTensor(
4617 slicedInputTensor, {qParams.scale, qParams.offset},
4618 ElemKind::Int8QTy);
4619 auto inW = qInputTensor.getHandle<int8_t>();
4620
4621 auto biasW = biasTensor->getHandle<float>();
4622 auto scalesW = wScaleTensor->getHandle<float>();
4623 fwdDynRowwiseQuantizedFullyConnectedInstImpl<int8_t, float16_t, int32_t>(
4624 inW, resultHandle, i, weightsW, biasW, scalesW, offsetsW);
4625 }
4626 }
4627}
4628
4629void BoundInterpreterFunction::fwdDynamicRowwiseQuantizedFullyConnectedInst(
4630 const glow::DynamicRowwiseQuantizedFullyConnectedInst *I) {
4631 auto *inputTensor = getTensor(I->getSrc());
4632 auto *weightsTensor = getTensor(I->getWeights());
4633 auto *biasTensor = getTensor(I->getBias());
4634 auto *resultTensor = getTensor(I->getDest());
4635 auto *scaleTensor = getTensor(I->getScales());
4636 auto *offsetTensor = getTensor(I->getOffsets());
4637 auto isSymmetric = I->getIsSymmetric();
4638 auto isPerBatchElement = I->getIsPerBatchElement();
4639 fwdDynRowwiseQuantizedFullyConnectedInstPreimpl(
4640 inputTensor, weightsTensor, biasTensor, resultTensor, scaleTensor,
4641 offsetTensor, isSymmetric, isPerBatchElement);
4642}
4643
4644void BoundInterpreterFunction::fwdDynamicQuantizedFullyConnectedInst(
4645 const glow::DynamicQuantizedFullyConnectedInst *I) {
4646
4647 auto *inputTensor = getTensor(I->getSrc());
4648 auto *weightsTensor = getTensor(I->getWeights());
4649 auto *biasTensor = getTensor(I->getBias());
4650 auto *resultTensor = getTensor(I->getDest());
4651 auto isSymmetric = I->getIsSymmetric();
4652 auto isPerBatchElement = I->getIsPerBatchElement();
4653 dim_t M = resultTensor->dims()[1];
4654
  // Compute per-channel quantization parameters. The weights carry a single
  // tensor-wide scale/offset, so every output channel receives the same
  // values.
4656 Tensor scaleTensor = Tensor(ElemKind::FloatTy, {M});
4657 Tensor offsetTensor = Tensor(ElemKind::Int32ITy, {M});
4658 auto scalesW = scaleTensor.getHandle<float>();
4659 auto offsetsW = offsetTensor.getHandle<int32_t>();
4660 auto weightsW = weightsTensor->getHandle<int8_t>();
  for (dim_t i = 0; i < M; i++) {
4662 scalesW.raw(i) = weightsW.getType().getScale();
4663 offsetsW.raw(i) = weightsW.getType().getOffset();
4664 }
4665 fwdDynRowwiseQuantizedFullyConnectedInstPreimpl(
4666 inputTensor, weightsTensor, biasTensor, resultTensor, &scaleTensor,
4667 &offsetTensor, isSymmetric, isPerBatchElement);
4668}
4669
4670//===----------------------------------------------------------------------===//
4671// Row-wise quantized FC
4672//===----------------------------------------------------------------------===//
4673template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
4674void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInstImpl(
4675 Value *inV, Value *outV, Value *weightsV, Value *biasV, Value *scalesV,
4676 Value *offsetsV) {
4677 auto inW = getWeightHandle<ElemTy>(inV);
4678 auto outW = getWeightHandle<ElemTy>(outV);
4679 auto weightsW = getWeightHandle<ElemTy>(weightsV);
4680 auto biasW = getWeightHandle<BiasElemTy>(biasV);
4681 auto scalesW = getWeightHandle<float>(scalesV);
4682 auto offsetsW = getWeightHandle<int32_t>(offsetsV);
4683 ShapeHW idim(inW.dims());
4684 ShapeHW odim(outW.dims());
4685 auto inTy = inW.getType();
4686 auto biasTy = biasW.getType();
4687 auto outTy = outW.getType();
4688 int32_t outOffset = outTy.getOffset();
4689 int32_t inOffset = inTy.getOffset();
4690 int32_t biasOffset = biasTy.getOffset();
4691 float outScale = outTy.getScale();
4692 float inScale = inTy.getScale();
4693 float biasScale = biasTy.getScale();
4694
4695 for (dim_t i = 0; i < idim.height; i++) {
4696 for (dim_t j = 0; j < odim.width; j++) {
4697 float matMulScale = scalesW.raw(j) * inScale;
4698 AccumulatorTy sum = 0;
4699 for (dim_t k = 0; k < idim.width; k++) {
4700 AccumulatorTy W = weightsW.at({j, k});
4701 AccumulatorTy A = inW.at({i, k});
4702 sum += (W - offsetsW.raw(j)) * (A - inOffset);
4703 }
4704
4705 // Scale the bias to match the scale of the matrix multiplication.
4706 AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
4707 (biasScale / matMulScale));
4708
4709 // Add the bias.
4710 sum += B;
4711
4712 // Scale the result back to the expected destination scale.
4713 outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
4714 std::round(float(sum) * (matMulScale / outScale) + outOffset));
4715 }
4716 }
4717}
4718
4719void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst(
4720 const RowwiseQuantizedFullyConnectedInst *I) {
4721 dispatchQuantizedWithAccumulationAndBiasImpl(
4722 fwdRowwiseQuantizedFullyConnectedInstImpl, I->getSrc()->getElementType(),
4723 I->getBias()->getElementType(), I->getSrc(), I->getDest(),
4724 I->getWeights(), I->getBias(), I->getScales(), I->getOffsets());
4725}
4726
4727//===----------------------------------------------------------------------===//
4728// Batched operations
4729//===----------------------------------------------------------------------===//
4730template <typename ElemTy, typename AccumulatorTy, typename SliceElemTy>
4731static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) {
4732 auto batchH = batch->getHandle<ElemTy>();
4733 auto sliceH = slice->getHandle<SliceElemTy>();
4734 auto destH = dest->getHandle<ElemTy>();
4735
4736 auto batchTy = batch->getType();
4737 auto sliceTy = slice->getType();
4738 auto destTy = dest->getType();
4739
4740 float sliceScale = sliceTy.getScale();
4741 float batchScale = batchTy.getScale();
4742 float destScale = destTy.getScale();
4743
4744 int32_t sliceOffset = sliceTy.getOffset();
4745 int32_t batchOffset = batchTy.getOffset();
4746 int32_t destOffset = destTy.getOffset();
4747
4748 auto bdim = flattenCdr(batchH.dims());
4749 assert(sliceH.size() == bdim.second && "Invalid slice size");
4750 assert(batchH.dims().drop_front() == sliceH.dims() && "Invalid batch size");
4751
4752 // For each layer in the batch:
4753 for (dim_t n = 0; n < bdim.first; n++) {
4754 size_t base = batchH.getElementPtr({n});
4755
4756 // For each element in the slice.
4757 for (dim_t i = 0; i < bdim.second; i++) {
4758 AccumulatorTy batchVal = batchH.raw(base + i);
4759 AccumulatorTy sliceVal = sliceH.raw(i);
4760 // We increase the size of the integer up to 16 bits for more accurate
4761 // arithmetic.
4762 const float largeScale = float(1) / (1 << 15);
      // Scale both sides from 8 bits to 16 bits.
4764 AccumulatorTy B =
4765 std::round(float(batchVal - batchOffset) * (batchScale / largeScale));
4766 AccumulatorTy S =
4767 std::round(float(sliceVal - sliceOffset) * (sliceScale / largeScale));
4768 AccumulatorTy R = B + S;
4769 destH.raw(base + i) = quantization::clip<AccumulatorTy, ElemTy>(
4770 std::round(float(R) * (largeScale / destScale) + destOffset));
4771 }
4772 }
4773}
4774
4775template <typename ElemTy>
4776void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl(
4777 const glow::BatchedAddInst *I) {
4778 staticAssertFloatingPointType(ElemTy);
4779
4780 auto batch = getWeightHandle<ElemTy>(I->getBatch());
4781 auto slice = getWeightHandle<ElemTy>(I->getSlice());
4782 auto dest = getWeightHandle<ElemTy>(I->getDest());
4783
4784 auto bdim = flattenCdr(batch.dims());
4785 assert(slice.size() == bdim.second && "Invalid slice size");
4786 assert(batch.dims().drop_front() == slice.dims() && "Invalid batch size");
4787
4788 // For each layer in the batch:
4789 for (dim_t n = 0; n < bdim.first; n++) {
4790 size_t base = batch.getElementPtr({n});
4791
4792 // For each element in the slice.
4793 for (dim_t i = 0; i < bdim.second; i++) {
4794 dest.raw(base + i) = batch.raw(base + i) + slice.raw(i);
4795 }
4796 }
4797}
4798
4799void BoundInterpreterFunction::fwdBatchedAddInst(
4800 const glow::BatchedAddInst *I) {
4801 if (getTensor(I->getBatch())->getType().isQuantizedType()) {
4802 dispatchQuantizedWithAccumulationAndBiasImpl(
4803 fwdBatchedAdd, I->getBatch()->getElementType(),
4804 I->getSlice()->getElementType(), getTensor(I->getBatch()),
4805 getTensor(I->getSlice()), getTensor(I->getDest()));
4806 return;
4807 }
4808 dispatchFloatingPointImpl(fwdBatchedAddInstFloatImpl,
4809 I->getBatch()->getElementType(), I);
4810}
4811
4812// Macro to define the ReduceAdd/Prod kernel implementation.
4813#define DEFINE_REDUCEADDPROD_INST_IMPL(func, init, op, inst) \
4814 template <typename ElemTy> \
4815 void BoundInterpreterFunction::fwdBatched##func##inst( \
4816 Value *batch, Value *dest, unsigned_t axis, \
4817 const ShapeVector &eBatchDims, const ShapeVector &eDestDims) { \
4818 /*Get unowned handles of the batch and dest with these new expanded \
4819 * dims.*/ \
4820 auto eBatch = getTensor(batch)->getUnowned(eBatchDims); \
4821 auto eDest = getTensor(dest)->getUnowned(eDestDims); \
4822 auto eBatchH = eBatch.getHandle<ElemTy>(); \
4823 auto eDestH = eDest.getHandle<ElemTy>(); \
4824 eDestH.clear(init); \
4825 \
4826 /* We can use this loop for all shapes. Use the same indices for both the \
4827 * batch and dest, except for setting the axis index in the dest to 0.*/ \
4828 for (dim_t x = 0; x < eBatchDims[0]; x++) { \
4829 for (dim_t y = 0; y < eBatchDims[1]; y++) { \
4830 for (dim_t z = 0; z < eBatchDims[2]; z++) { \
4831 for (dim_t w = 0; w < eBatchDims[3]; w++) { \
4832 for (dim_t q = 0; q < eBatchDims[4]; q++) { \
4833 for (dim_t r = 0; r < eBatchDims[5]; r++) { \
4834 dim_t destIndices[] = {x, y, z, w, q, r}; \
4835 destIndices[axis] = 0; \
4836 eDestH.at(destIndices) = \
4837 eDestH.at(destIndices) op eBatchH.at({x, y, z, w, q, r}); \
4838 } \
4839 } \
4840 } \
4841 } \
4842 } \
4843 } \
4844 }
4845
4846/// Define fwdBatchedReduceAddInstImpl
4847DEFINE_REDUCEADDPROD_INST_IMPL(ReduceAdd, 0, +, InstImpl)
4848
/// Define fwdBatchedReduceProdInstFloatImpl
4850DEFINE_REDUCEADDPROD_INST_IMPL(ReduceProd, 1, *, InstFloatImpl)
4851
4852#undef DEFINE_REDUCEADDPROD_INST_IMPL
4853
4854void BoundInterpreterFunction::fwdBatchedReduceAddInst(
4855 const glow::BatchedReduceAddInst *I) {
4856 static_assert(max_tensor_dimensions == 6,
4857 "Loops below assume max_tensor_dimensions = 6.");
4858
4859 auto *batch = I->getBatch();
4860 auto *dest = I->getDest();
4861 const auto axis = I->getAxis();
4862
  // Initialize both the expanded batch and dest dims to the expanded batch
  // dims. This allows us to iterate over the tensor regardless of its shape
  // using the max_tensor_dimensions loops below.
4866 ShapeVector eBatchDims = expandDimsToMax(batch->dims());
4867 ShapeVector eDestDims = eBatchDims;
4868
4869 // Set the destination axis dimension (the one we are reducing) to 1.
4870 eDestDims[axis] = 1;
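  // For example, reducing a {3, 4} batch over axis 1 uses
  // eBatchDims = {3, 4, 1, 1, 1, 1} and eDestDims = {3, 1, 1, 1, 1, 1}, so
  // the six nested loops below handle any input rank.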
4871
4872 if (getTensor(batch)->getType().isQuantizedType()) {
4873 auto destTy = dest->getType();
4874 auto batchTy = batch->getType();
4875
4876 float destScale = destTy->getScale();
4877 float batchScale = batchTy->getScale();
4878
4879 int32_t destOffset = destTy->getOffset();
4880 int32_t batchOffset = batchTy->getOffset();
4881
4882 // Get unowned handles of the batch and dest with these new expanded dims.
4883 auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
4884 auto eDest = getTensor(dest)->getUnowned(eDestDims);
4885 auto eBatchH = eBatch.getHandle<int8_t>();
4886 auto eDestH = eDest.getHandle<int8_t>();
4887 eDestH.clear();
4888
4889 // For quantization, we must accumulate in the inner-most loop into a
4890 // local float and then clip the result back into the dest tensor. Here
4891 // are the max_tensor_dimensions cases for this, to ensure the axis is
4892 // used as the inner-most loop.
4893 switch (axis) {
4894#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS) \
4895 case _D5_AXIS: \
4896 for (dim_t i##_D0 = 0; i##_D0 < eBatchDims[_D0]; i##_D0++) \
4897 for (dim_t i##_D1 = 0; i##_D1 < eBatchDims[_D1]; i##_D1++) \
4898 for (dim_t i##_D2 = 0; i##_D2 < eBatchDims[_D2]; i##_D2++) \
4899 for (dim_t i##_D3 = 0; i##_D3 < eBatchDims[_D3]; i##_D3++) \
4900 for (dim_t i##_D4 = 0; i##_D4 < eBatchDims[_D4]; i##_D4++) { \
4901 float sum = 0.0; \
4902 for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < eBatchDims[_D5_AXIS]; \
4903 i##_D5_AXIS++) { \
4904 sum += eBatchH.at({i0, i1, i2, i3, i4, i5}) - batchOffset; \
4905 } \
4906 dim_t i##_D5_AXIS = 0; \
4907 int32_t res = \
4908 std::round(sum * batchScale / destScale) + destOffset; \
4909 eDestH.at({i0, i1, i2, i3, i4, i5}) = \
4910 quantization::clip<int32_t, int8_t>(res); \
4911 } \
4912 return;
4913
4914 // Each loop order, with the inner-most dimension/index equal to the
4915 // axis.
4916 LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
4917 LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
4918 LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
4919 LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
4920 LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
4921 LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
4922#undef LOOP_AXIS_CASE
4923 default:
4924 llvm_unreachable("Axis should be less than max_tensor_dimensions.");
4925 }
4926 }
4927 dispatchFloatingPointAndInt32Impl(fwdBatchedReduceAddInstImpl,
4928 batch->getElementType(), batch, dest, axis,
4929 eBatchDims, eDestDims);
4930}
4931
4932void BoundInterpreterFunction::fwdBatchedReduceProdInst(
4933 const glow::BatchedReduceProdInst *I) {
4934 static_assert(max_tensor_dimensions == 6,
4935 "Loops below assume max_tensor_dimensions = 6.");
4936
4937 auto *batch = I->getBatch();
4938 auto *dest = I->getDest();
4939 const auto axis = I->getAxis();
4940
  // Initialize both the expanded batch and dest dims to the expanded batch
  // dims. This allows us to iterate over the tensor regardless of its shape
  // using the max_tensor_dimensions loops below.
4944 ShapeVector eBatchDims = expandDimsToMax(batch->dims());
4945 ShapeVector eDestDims = eBatchDims;
4946
4947 // Set the destination axis dimension (the one we are reducing) to 1.
4948 eDestDims[axis] = 1;
4949
4950 assert(!batch->getType()->isQuantizedType() &&
4951 "Quantized implementation for ReduceProd not supported yet.");
4952
4953 dispatchArithmeticImpl(fwdBatchedReduceProdInstFloatImpl,
4954 batch->getElementType(), batch, dest, axis, eBatchDims,
4955 eDestDims);
4956}
4957
4958/// Macro to define ReduceMin/Max kernel implementation.
4959#define DEFINE_REDUCEMINMAX_INST_IMPL(func, compare) \
4960 template <typename ElemTy> \
4961 void BoundInterpreterFunction::fwdBatched##func##InstImpl( \
4962 Value *batch, Value *dest, const ShapeVector &eBatchDims, \
4963 const ShapeVector &eDestDims, ElemTy init) { \
4964 static_assert(max_tensor_dimensions == 6, \
4965 "Loops below assume max_tensor_dimensions = 6."); \
4966 /* Get unowned handles of the batch and dest with these new expanded \
4967 * dims.*/ \
4968 auto eBatch = getTensor(batch)->getUnowned(eBatchDims); \
4969 auto eDest = getTensor(dest)->getUnowned(eDestDims); \
4970 auto eBatchH = eBatch.getHandle<ElemTy>(); \
4971 auto eDestH = eDest.getHandle<ElemTy>(); \
4972 eDestH.clear(init); \
4973 \
4974 unsigned int axes[max_tensor_dimensions]; \
4975 for (dim_t i = 0; i < max_tensor_dimensions; i++) { \
4976 axes[i] = (eDestDims[i] > 1); \
4977 } \
4978 \
4979 /* We can use this loop for all shapes. Use the same indices for both the \
4980 * batch and dest, except for setting the axis index in the dest to 0.*/ \
4981 for (dim_t x = 0, dx = 0; x < eBatchDims[0]; x++, dx += axes[0]) { \
4982 for (dim_t y = 0, dy = 0; y < eBatchDims[1]; y++, dy += axes[1]) { \
4983 for (dim_t z = 0, dz = 0; z < eBatchDims[2]; z++, dz += axes[2]) { \
4984 for (dim_t w = 0, dw = 0; w < eBatchDims[3]; w++, dw += axes[3]) { \
4985 for (dim_t q = 0, dq = 0; q < eBatchDims[4]; q++, dq += axes[4]) { \
4986 for (dim_t r = 0, dr = 0; r < eBatchDims[5]; \
4987 r++, dr += axes[5]) { \
4988 dim_t destIndices[] = {dx, dy, dz, dw, dq, dr}; \
4989 dim_t srcIndices[] = {x, y, z, w, q, r}; \
4990 eDestH.at(destIndices) = \
4991 compare(eDestH.at(destIndices), eBatchH.at(srcIndices)); \
4992 } \
4993 } \
4994 } \
4995 } \
4996 } \
4997 } \
4998 }
4999
5000/// Define fwdBatchedReduceMaxInstImpl.
5001DEFINE_REDUCEMINMAX_INST_IMPL(ReduceMax, std::max)
5002
5003/// Define fwdBatchedReduceMinInstImpl.
5004DEFINE_REDUCEMINMAX_INST_IMPL(ReduceMin, std::min)
5005
5006#undef DEFINE_REDUCEMINMAX_INST_IMPL
5007
5008/// Macro to define ReduceMin/Max instruction.
5009#define DEFINE_REDUCEMINMAX_INST(func, init_func) \
5010 void BoundInterpreterFunction::fwdBatched##func##Inst( \
5011 const glow::Batched##func##Inst *I) { \
5012 \
5013 auto *batch = I->getBatch(); \
5014 auto *dest = I->getDest(); \
5015 const auto axes = I->getAxes(); \
5016 \
5017 /* Initialize both expanded batch and dest dims to the expanded batch \
5018 dims. This allows us below to iterate over the tensor regardless of its \
5019 shape using max_tensor_dimensions loops below.*/ \
5020 ShapeVector eBatchDims = expandDimsToMax(batch->dims()); \
5021 ShapeVector eDestDims = eBatchDims; \
5022 /* Set the destination axes dimensions (the one we are reducing) to 1.*/ \
5023 for (dim_t i = 0; i < axes.size(); i++) { \
5024 eDestDims[axes[i]] = 1; \
5025 } \
5026 \
5027 if (batch->getElementType() == ElemKind::Int8QTy) { \
5028 dispatchQuantizedImpl( \
5029 fwdBatched##func##InstImpl, batch->getElementType(), batch, dest, \
5030 eBatchDims, eDestDims, std::numeric_limits<int8_t>::init_func()); \
5031 } else { \
5032 dispatchArithmeticImpl( \
5033 fwdBatched##func##InstImpl, batch->getElementType(), batch, dest, \
5034 eBatchDims, eDestDims, std::numeric_limits<int32_t>::init_func()); \
5035 } \
5036 }
5037
/// Define fwdBatchedReduceMinInst
5039DEFINE_REDUCEMINMAX_INST(ReduceMin, max)
5040
/// Define fwdBatchedReduceMaxInst
5042DEFINE_REDUCEMINMAX_INST(ReduceMax, min)
5043
5044#undef DEFINE_REDUCEMINMAX_INST
5045
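/// CumSum computes a running sum along dimension dim. For example, an input
/// of [1, 2, 3, 4] yields [1, 3, 6, 10]; with exclusive=true it yields
/// [0, 1, 3, 6], and with reverse=true it yields [10, 9, 7, 4].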
5046template <typename ElemTy>
5047void BoundInterpreterFunction::fwdCumSumInstImpl(Value *input, Value *dest,
5048 int64_t dim, bool exclusive,
5049 bool reverse) {
5050 auto eInputH = getTensor(input)->getHandle<ElemTy>();
5051 auto eDestH = getTensor(dest)->getHandle<ElemTy>();
5052 eDestH.clear();
5053
  // Handle a negative dim by wrapping it around (Python-style indexing).
5055 if (dim < 0) {
5056 dim += eInputH.dims().size();
5057 }
5058 assert(dim < eInputH.dims().size() &&
5059 "Dim must be less than the number of dimensions of input tensor");
5060
5061 std::vector<dim_t> accumDims(eInputH.dims());
5062 accumDims[dim] = 1;
5063
5064 Tensor accum(eInputH.getElementType(), accumDims);
5065 auto accumH = accum.getHandle<ElemTy>();
5066 accumH.clear();
5067
5068 sdim_t s = 0;
5069 sdim_t n = eInputH.dims()[dim];
5070 sdim_t dir = 1;
5071
5072 if (reverse) {
5073 s = n - 1;
5074 n = -1;
5075 dir = -1;
5076 }
5077
5078 for (sdim_t i = s; i != n; i += dir) {
5079 std::vector<dim_t> offset(eInputH.dims().size());
5080 offset[dim] = i;
5081
5082 Tensor temp(eInputH.getElementType(), accumDims);
5083 auto tempH = temp.getHandle<ElemTy>();
5084 eInputH.extractTensors(tempH, offset);
5085
5086 if (!exclusive) {
5087 for (auto accumIt = accumH.begin(), tempIt = tempH.begin();
5088 accumIt != accumH.end(); ++accumIt, ++tempIt) {
5089 *accumIt += *tempIt;
5090 }
5091 }
5092
5093 eDestH.insertTensors(accumH, offset, 1, dim);
5094
5095 if (exclusive) {
5096 for (auto accumIt = accumH.begin(), tempIt = tempH.begin();
5097 accumIt != accumH.end(); ++accumIt, ++tempIt) {
5098 *accumIt += *tempIt;
5099 }
5100 }
5101 }
5102}
5103
5104void BoundInterpreterFunction::fwdCumSumInst(glow::CumSumInst const *I) {
5105 dispatchArithmeticImpl(fwdCumSumInstImpl, I->getInput()->getElementType(),
5106 I->getInput(), I->getDest(), I->getDim(),
5107 I->getExclusive(), I->getReverse());
5108}
5109
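/// LengthsSum sums consecutive slices of Data into one output slice per
/// segment, where Lengths[i] is the number of slices in segment i. For
/// example, Data with 5 rows and Lengths = [2, 3] produces two output rows:
/// rows 0+1 and rows 2+3+4.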
5110template <typename ElemTy>
5111void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl(
5112 const LengthsSumInst *I) {
5113 staticAssertFloatingPointType(ElemTy);
5114
5115 auto out = getTensor(I->getDest());
5116 auto data = getTensor(I->getData());
5117 auto lengths = getTensor(I->getLengths());
5118
5119 out->zero();
5120
5121 auto LH = lengths->getHandle<int32_t>();
5122
5123 size_t segments = lengths->dims()[0];
5124 size_t sliceSize = data->size() / data->dims()[0];
5125
5126 auto DH = data->getHandle<ElemTy>();
5127 auto OH = out->getHandle<ElemTy>();
5128
5129 size_t offsetIn = 0;
5130 size_t offsetOut = 0;
5131 for (dim_t i = 0; i < segments; i++) {
5132 for (int32_t j = 0, e = LH.raw(i); j < e; j++) {
5133 for (dim_t k = 0; k < sliceSize; k++) {
5134 OH.raw(offsetOut + k) += DH.raw(offsetIn + k);
5135 }
5136 offsetIn += sliceSize;
5137 }
5138 offsetOut += sliceSize;
5139 }
5140
5141 assert(offsetIn == data->size() && "All values in Data should be consumed");
5142 assert(offsetOut == out->size() && "All values in Dest should be written to");
5143}
5144
5145void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) {
5146 dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl,
                            I->getData()->getElementType(), I);
5148}
5149
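/// SparseLengthsSum gathers rows of Data by Indices and sums them per
/// segment, where Lengths[i] gives the number of indices in segment i. For
/// example, Indices = [4, 0, 2] and Lengths = [2, 1] produce
/// out[0] = Data[4] + Data[0] and out[1] = Data[2]. The Int8 variant below
/// dequantizes each row, accumulates in float, and re-quantizes the result.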
5150template <typename TI>
5151void BoundInterpreterFunction::fwdSparseLengthsSumInstI8Impl(
5152 const SparseLengthsSumInst *I) {
5153
5154 auto out = getTensor(I->getDest());
5155 auto data = getTensor(I->getData());
5156 auto indices = getTensor(I->getIndices());
5157 auto lengths = getTensor(I->getLengths());
5158
5159 out->zero();
5160
5161 auto IH = indices->getHandle<TI>();
5162 auto LH = lengths->getHandle<int32_t>();
5163
5164 size_t segments = lengths->dims()[0];
5165 size_t totalLength = 0;
5166 for (size_t i = 0; i < segments; i++) {
5167 totalLength += LH.raw(i);
5168 }
5169 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5171
5172 size_t lineSize = data->size() / data->dims()[0];
5173
5174 auto DH = data->getHandle<int8_t>();
5175 auto OH = out->getHandle<int8_t>();
5176
5177 auto TQP = [](Tensor *T) {
5178 return TensorQuantizationParams{T->getType().getScale(),
5179 T->getType().getOffset()};
5180 };
5181
5182 size_t curIdx = 0;
5183 for (size_t i = 0; i < segments; i++) {
5184 std::vector<float> accum(lineSize, 0.0f);
5185 for (int32_t j = 0; j < LH.raw(i); j++) {
5186 size_t offsetIn = IH.raw(curIdx) * lineSize;
5187 for (size_t k = 0; k < lineSize; k++) {
5188 accum[k] += quantization::dequantize(DH.raw(offsetIn++), TQP(data));
5189 }
5190 curIdx++;
5191 }
5192 size_t offsetOut = i * lineSize;
5193 for (size_t k = 0; k < lineSize; k++) {
5194 OH.raw(offsetOut++) = quantization::quantize(accum[k], TQP(out));
5195 }
5196 }
5197}
5198
5199template <typename ElemTy, typename TI>
5200void BoundInterpreterFunction::fwdSparseLengthsSumInstFloatImpl(
5201 const SparseLengthsSumInst *I) {
5202 staticAssertFloatingPointType(ElemTy);
5203
5204 auto out = getTensor(I->getDest());
5205 auto data = getTensor(I->getData());
5206 auto indices = getTensor(I->getIndices());
5207 auto lengths = getTensor(I->getLengths());
5208
5209 out->zero();
5210
5211 auto IH = indices->getHandle<TI>();
5212 auto LH = lengths->getHandle<int32_t>();
5213
5214 size_t segments = lengths->dims()[0];
5215 size_t totalLength = 0;
5216 for (size_t i = 0; i < segments; i++) {
5217 totalLength += LH.raw(i);
5218 }
5219 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5221
5222 size_t lineSize = data->size() / data->dims()[0];
5223
5224 auto DH = data->getHandle<ElemTy>();
5225 auto OH = out->getHandle<ElemTy>();
5226
5227 size_t curIdx = 0;
5228 for (size_t i = 0; i < segments; i++) {
5229 for (size_t j = 0, e = LH.raw(i); j < e; j++) {
5230 size_t offsetIn = IH.raw(curIdx++) * lineSize;
5231 size_t offsetOut = i * lineSize;
5232 for (size_t k = 0; k < lineSize; k++)
5233 OH.raw(offsetOut++) += DH.raw(offsetIn++);
5234 }
5235 }
5236}
5237
5238void BoundInterpreterFunction::fwdSparseLengthsSumInst(
5239 const SparseLengthsSumInst *I) {
5240 if (I->getDest()->getType()->isQuantizedType()) {
5241 dispatchIndexTypeImpl(fwdSparseLengthsSumInstI8Impl,
5242 I->getIndices()->getElementType(), I);
5243 return;
5244 }
5245 dispatchFloatingPointAndIndexImpl(fwdSparseLengthsSumInstFloatImpl,
5246 I->getData()->getElementType(),
5247 I->getIndices()->getElementType(), I);
5248}
5249
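/// SparseLengthsWeightedSum is like SparseLengthsSum, except that each
/// gathered row is scaled by the matching weight:
/// out[i] = sum over the segment of Weights[c] * Data[Indices[c]].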
5250template <typename ElemTy, typename TI>
5251void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl(
5252 const SparseLengthsWeightedSumInst *I) {
5253 staticAssertFloatingPointType(ElemTy);
5254
5255 auto out = getTensor(I->getDest());
5256 auto data = getTensor(I->getData());
5257 auto weights = getTensor(I->getWeights());
5258 auto indices = getTensor(I->getIndices());
5259 auto lengths = getTensor(I->getLengths());
5260
5261 out->zero();
5262
5263 auto IH = indices->getHandle<TI>();
5264 auto LH = lengths->getHandle<int32_t>();
5265
5266 size_t segments = lengths->dims()[0];
5267 size_t totalLength = 0;
5268 for (dim_t i = 0; i < segments; i++) {
5269 totalLength += LH.raw(i);
5270 }
5271 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5273
5274 dim_t lineSize = data->size() / data->dims()[0];
5275
5276 auto DH = data->getHandle<ElemTy>();
5277 auto WH = weights->getHandle<ElemTy>();
5278 auto OH = out->getHandle<ElemTy>();
5279
5280 dim_t curIdx = 0;
5281 for (dim_t i = 0; i < segments; i++) {
5282 for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
5283 ElemTy weight = WH.raw(curIdx);
5284 size_t offsetIn = IH.raw(curIdx++) * lineSize;
5285 size_t offsetOut = i * lineSize;
5286 for (dim_t k = 0; k < lineSize; k++)
5287 OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
5288 }
5289 }
5290}
5291
5292template <typename TI>
5293void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl(
5294 const SparseLengthsWeightedSumInst *I) {
5295
5296 auto out = getTensor(I->getDest());
5297 auto data = getTensor(I->getData());
5298 auto weights = getTensor(I->getWeights());
5299 auto indices = getTensor(I->getIndices());
5300 auto lengths = getTensor(I->getLengths());
5301
5302 out->zero();
5303
5304 auto IH = indices->getHandle<TI>();
5305 auto LH = lengths->getHandle<int32_t>();
5306
5307 dim_t segments = lengths->dims()[0];
5308 dim_t totalLength = 0;
5309 for (dim_t i = 0; i < segments; i++) {
5310 totalLength += LH.raw(i);
5311 }
5312 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5314
5315 dim_t lineSize = data->size() / data->dims()[0];
5316
5317 auto DH = data->getHandle<int8_t>();
5318 auto WH = weights->getHandle<int8_t>();
5319 auto OH = out->getHandle<int8_t>();
5320
5321 auto TQP = [](Tensor *T) {
5322 return TensorQuantizationParams{T->getType().getScale(),
5323 T->getType().getOffset()};
5324 };
5325 using namespace quantization;
5326
5327 dim_t curIdx = 0;
5328 for (dim_t i = 0; i < segments; i++) {
5329 std::vector<float> accum(lineSize, 0.0f);
5330 for (int32_t j = 0; j < LH.raw(i); j++) {
5331 float weight = dequantize(WH.raw(curIdx), TQP(weights));
5332 size_t offsetIn = IH.raw(curIdx) * lineSize;
5333 for (dim_t k = 0; k < lineSize; k++) {
5334 accum[k] += weight * dequantize(DH.raw(offsetIn++), TQP(data));
5335 }
5336 curIdx++;
5337 }
5338 dim_t offsetOut = i * lineSize;
5339 for (dim_t k = 0; k < lineSize; k++) {
5340 OH.raw(offsetOut++) = quantize(accum[k], TQP(out));
5341 }
5342 }
5343}
5344
5345void BoundInterpreterFunction::fwdSparseLengthsSumGradInst(
5346 const SparseLengthsSumGradInst * /*I*/) {
5347 DCHECK(!"Found SparseLengthsSumGradInst but SparseLengthsSum is lowered on "
5348 "Interpreter");
5349}
5350
5351void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst(
5352 const SparseLengthsWeightedSumInst *I) {
5353 if (I->getDest()->getType()->isQuantizedType()) {
5354 dispatchIndexTypeImpl(fwdSparseLengthsWeightedSumInstI8Impl,
5355 I->getIndices()->getElementType(), I);
5356 return;
5357 }
5358 dispatchFloatingPointAndIndexImpl(fwdSparseLengthsWeightedSumInstFloatImpl,
5359 I->getData()->getElementType(),
5360 I->getIndices()->getElementType(), I);
5361}
5362
5363void BoundInterpreterFunction::fwdSparseLengthsWeightedSumGradInst(
5364 const SparseLengthsWeightedSumGradInst *I) {
5365 assert(I->getDataGrad()->getType()->getElementType() == ElemKind::FloatTy &&
5366 "Input type must be float");
5367
5368 auto destGrad = getTensor(I->getDestGrad());
5369 auto data = getTensor(I->getData());
5370 auto dataGrad = getTensor(I->getDataGrad());
5371 auto weightsGrad = getTensor(I->getWeightsGrad());
5372 auto weights = getTensor(I->getWeights());
5373 auto indices = getTensor(I->getIndices());
5374 auto lengths = getTensor(I->getLengths());
5375
5376 // The data gradients not touched by this operation should
5377 // be 0, so set the entire buffer to 0 to start with.
5378 dataGrad->zero();
5379
5380 auto LH = lengths->getHandle<int32_t>();
5381 auto IH = indices->getHandle<int64_t>();
5382
5383 size_t segments = lengths->dims()[0];
5384 size_t totalLength = 0;
5385 for (size_t i = 0; i < segments; ++i) {
5386 totalLength += LH.raw(i);
5387 }
5388 assert(totalLength == indices->dims()[0] &&
5389 "sum(Lengths) must be equal to len(Indices)");
5390
5391 size_t lineSize = dataGrad->size() / dataGrad->dims()[0];
5392
5393 auto IGH = destGrad->getHandle();
5394 auto WH = weights->getHandle();
5395 auto WGH = weightsGrad->getHandle();
5396 auto DH = data->getHandle();
5397 auto OGH = dataGrad->getHandle();
5398
5399 // For each index in each segment:
5400 // 1) accumulate into the corresponding data gradient the product of the
5401 // gradient of the result it was added to and the weight that it was
5402 // multiplied by during the SparseLengthsWeightedSum operation.
5403 //
5404 // 2) accumulate into each weight gradient the reduced sum of the
5405 // elementwise product of the result slice that the corresponding weight
5406 // produced and the input slice that the weight was multiplied with.
5407 for (size_t i = 0, curIdx = 0; i < segments; ++i) {
5408 size_t destOffset = i * lineSize;
5409 for (size_t j = 0, e = LH.raw(i); j < e; ++j, ++curIdx) {
5410 float weightGrad = 0.0f;
5411 float weight = WH.raw(curIdx);
5412 size_t dataOffset = IH.raw(curIdx) * lineSize;
5413
5414 for (size_t k = 0; k < lineSize; ++k) {
5415 OGH.raw(dataOffset + k) += IGH.raw(destOffset + k) * weight;
5416 weightGrad += IGH.raw(destOffset + k) * DH.raw(dataOffset + k);
5417 }
5418
5419 WGH.raw(curIdx) = weightGrad;
5420 }
5421 }
5422}
5423
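/// EmbeddingBag computes the same weighted sums as SparseLengthsWeightedSum,
/// but segments are described by Offsets (the start position of each segment
/// in Indices) instead of Lengths. With hasEndOffset, Offsets = [0, 2, 5]
/// describes two segments covering indices [0, 2) and [2, 5).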
5424template <typename ElemTy, typename IndexType>
5425void BoundInterpreterFunction::fwdEmbeddingBagInstFloatImpl(
5426 const EmbeddingBagInst *I) {
5427 staticAssertFloatingPointType(ElemTy);
5428
5429 auto out = getTensor(I->getDest());
5430 auto data = getTensor(I->getData());
5431 auto weights = getTensor(I->getWeights());
5432 auto indices = getTensor(I->getIndices());
5433 auto offsets = getTensor(I->getOffsets());
5434 bool hasEndOffset = I->getHasEndOffset();
5435
5436 out->zero();
5437
5438 auto IH = indices->getHandle<IndexType>();
5439 auto OFFH = offsets->getHandle<IndexType>();
5440
  // If an end offset is present to mark the end of the last segment, then the
  // number of segments is one less than the size of the Offsets tensor.
5443 size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
5444 size_t numIndices = indices->dims()[0];
5445
5446 size_t lineSize = data->size() / data->dims()[0];
5447
5448 auto DH = data->getHandle<ElemTy>();
5449 auto WH = weights->getHandle<ElemTy>();
5450 auto OH = out->getHandle<ElemTy>();
5451
5452 dim_t curIdx = 0;
5453 for (dim_t i = 0; i < segments; i++) {
5454 dim_t start = OFFH.raw(i);
5455 dim_t end;
5456 if (!hasEndOffset) {
      // Note that in this case we have to use numIndices to find the end of
      // the last segment. This is fragile because it relies on knowing the
      // total length of the indices tensor, which may not always be
      // available. Future implementations of this operator should always
      // provide an end offset, so eventually this case should be removed.
5462 end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
5463 } else {
5464 end = OFFH.raw(i + 1);
5465 }
5466 if (start == end) {
5467 continue;
5468 } else if (start > end) {
5469 break;
5470 }
5471 for (dim_t j = start; j < end; j++) {
5472 ElemTy weight = WH.raw(curIdx);
5473 dim_t offsetIn = IH.raw(curIdx++) * lineSize;
5474 dim_t offsetOut = i * lineSize;
5475 for (dim_t k = 0; k < lineSize; k++) {
5476 OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
5477 }
5478 }
5479 }
5480}
5481
5482void BoundInterpreterFunction::fwdEmbeddingBagInst(const EmbeddingBagInst *I) {
5483 dispatchFloatingPointAndIndexImpl(fwdEmbeddingBagInstFloatImpl,
5484 I->getData()->getElementType(),
5485 I->getIndices()->getElementType(), I);
5486}
5487
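/// Embedding is a plain gather: output row i is the row of the weight table
/// selected by Indices[i]. Rows whose index equals padIdx are left as zeros.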
5488template <typename ElemTy>
5489void BoundInterpreterFunction::fwdEmbeddingInstImpl(Tensor *wtT, Tensor *indT,
5490 Tensor *outT,
5491 int64_t padIdx, bool sparse,
5492 bool scale,
5493 dim_t embedding_dim) {
5494
5495 staticAssertFloatingPointType(ElemTy);
5496
5497 assert(!scale && "Currently only support scale_grad_by_freq == 'false'");
5498 assert(!sparse && "Currently only support sparse == 'false'");
5499
  // The Indices tensor can have an arbitrary shape; flatten it to a 1D
  // vector of size indLen. The Output tensor can also have an arbitrary
  // shape; reshape it to a 2D tensor of size (indLen, embedding_dim).
5504 dim_t indLen = 1;
5505 for (dim_t idx = 0; idx < indT->dims().size(); ++idx) {
5506 indLen *= indT->dims()[idx];
5507 }
5508 auto fIndT = indT->getUnowned({indLen});
5509 auto fOutT = outT->getUnowned({indLen, embedding_dim});
5510
5511 fOutT.zero();
5512
5513 auto WH = wtT->getHandle<ElemTy>();
5514 auto OH = fOutT.getHandle<ElemTy>();
5515 auto IH = fIndT.getHandle<int32_t>();
5516
5517 for (dim_t i = 0; i < indLen; i++) {
5518 dim_t index = IH.at(i);
5519 if (index != padIdx) {
5520 for (dim_t j = 0; j < embedding_dim; j++) {
5521 OH.at({i, j}) = WH.at({index, j});
5522 }
5523 }
5524 }
5525}
5526
5527void BoundInterpreterFunction::fwdEmbeddingInst(const EmbeddingInst *I) {
5528 auto wtT = getTensor(I->getWeights());
5529 auto indT = getTensor(I->getIndices());
5530 auto outT = getTensor(I->getDest());
5531 auto padIdx = I->getPadIdx();
5532 bool sparse = I->getSparse();
5533 bool scale = I->getScale();
5534 dim_t embedding_dim = wtT->dims()[1];
5535 auto elemTy = wtT->getElementType();
5536
5537 if (padIdx > -1) {
5538 assert(static_cast<dim_t>(padIdx) < wtT->dims()[0] &&
5539 "padIdx should be within num_embeddings");
5540 }
5541
5542 dispatchFloatingPointImpl(fwdEmbeddingInstImpl, elemTy, wtT, indT, outT,
5543 padIdx, sparse, scale, embedding_dim);
5544}
5545
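/// Row-wise quantized variant of SparseLengthsWeightedSum: Data holds uint8
/// rows, and each row's scale and offset are stored at the same row index of
/// the separate Scales and Offsets tensors. Each gathered row is dequantized,
/// scaled by its weight, accumulated in AccumT, and written back as T.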
5546template <typename T, typename AccumT, typename TI>
5547void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(
5548 const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
5549 auto *out = getTensor(I->getDest());
5550 auto *data = getTensor(I->getData());
5551 auto *dataScales = getTensor(I->getScales());
5552 auto *dataOffsets = getTensor(I->getOffsets());
5553 auto *weights = getTensor(I->getWeights());
5554 auto *indices = getTensor(I->getIndices());
5555 auto *lengths = getTensor(I->getLengths());
5556
5557 out->zero();
5558
5559 auto IH = indices->getHandle<TI>();
5560 auto LH = lengths->getHandle<int32_t>();
5561
5562 dim_t segments = lengths->dims()[0];
5563 dim_t totalLength = 0;
5564 for (dim_t i = 0; i < segments; i++) {
5565 totalLength += LH.raw(i);
5566 }
5567 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5569
5570 dim_t lineSize = data->size() / data->dims()[0];
5571
5572 auto DH = data->getHandle<uint8_t>();
5573 auto DSH = dataScales->getHandle<T>();
5574 auto DOH = dataOffsets->getHandle<T>();
5575 auto WH = weights->getHandle<T>();
5576 auto OH = out->getHandle<T>();
5577
5578 dim_t curIdx = 0;
5579 for (dim_t i = 0; i < segments; i++) {
5580 std::vector<AccumT> accum(lineSize, 0.0f);
5581 for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
5582 const float weight = static_cast<float>(WH.raw(curIdx));
5583 const dim_t rowIdx = IH.raw(curIdx++);
5584 const float scale = static_cast<float>(DSH.at({rowIdx}));
5585 const float offset = static_cast<float>(DOH.at({rowIdx}));
5586 size_t offsetIn = rowIdx * lineSize;
5587 for (dim_t k = 0; k < lineSize; k++) {
5588 float d = quantization::dequantizeWithFloatOffset(DH.raw(offsetIn++),
5589 scale, offset);
5590 accum[k] += d * weight;
5591 }
5592 }
    // Accumulation complete; copy back to the output with a cast to T.
5594 size_t offsetOut = i * lineSize;
5595 for (size_t k = 0; k < lineSize; k++) {
5596 OH.raw(offsetOut++) = static_cast<T>(accum[k]);
5597 }
5598 }
5599}
5600
5601void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst(
5602 const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
5603 const auto ity = I->getIndices()->getElementType();
5604 switch (I->getDest()->getElementType()) {
5605 case ElemKind::FloatTy:
5606 if (ity == ElemKind::Int32ITy) {
5607 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int32_t>(I);
5608 } else if (ity == ElemKind::Int64ITy) {
5609 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int64_t>(I);
5610 } else {
5611 llvm_unreachable("Index type is not supported");
5612 }
5613 break;
5614 case ElemKind::Float16Ty:
5615 if (I->getUseFP16Accumulation()) {
5616 if (ity == ElemKind::Int32ITy) {
5617 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
5618 int32_t>(I);
5619 } else if (ity == ElemKind::Int64ITy) {
5620 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
5621 int64_t>(I);
5622 } else {
5623 llvm_unreachable("Index type is not supported");
5624 }
5625 } else {
5626 if (ity == ElemKind::Int32ITy) {
5627 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
5628 int32_t>(I);
5629 } else if (ity == ElemKind::Int64ITy) {
5630 fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
5631 int64_t>(I);
5632 } else {
5633 llvm_unreachable("Index type is not supported");
5634 }
5635 }
5636 break;
5637 default:
5638 llvm_unreachable("Type is not supported");
5639 }
5640}
5641
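/// Fused row-wise quantized variant: the per-row scale and offset are stored
/// inline at the end of each Data row rather than in separate tensors, and
/// 4-bit fused types pack two elements per byte (hence the k / 2 and isMSB
/// logic below).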
5642template <typename T, typename AccumT, typename TI>
5643void BoundInterpreterFunction::
5644 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
5645 const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
5646 auto *out = getTensor(I->getDest());
5647 auto *data = getTensor(I->getData());
5648 auto *weights = getTensor(I->getWeights());
5649 auto *indices = getTensor(I->getIndices());
5650 auto *lengths = getTensor(I->getLengths());
5651
5652 out->zero();
5653
5654 auto IH = indices->getHandle<TI>();
5655 auto LH = lengths->getHandle<int32_t>();
5656
5657 size_t segments = lengths->dims()[0];
5658 size_t totalLength = 0;
5659 for (size_t i = 0; i < segments; i++) {
5660 totalLength += LH.raw(i);
5661 }
5662 assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5664
5665 const bool using4BitQuantization =
5666 data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy ||
5667 data->getType().getElementType() == ElemKind::UInt4FusedQTy;
5668
5669 const size_t outLineSize = out->size() / out->dims()[0];
5670
5671 auto DH = data->getHandle<uint8_t>();
5672 auto WH = weights->getHandle<T>();
5673 auto OH = out->getHandle<T>();
5674
5675 dim_t curIdx = 0;
5676 for (dim_t i = 0; i < segments; i++) {
5677 std::vector<AccumT> accum(outLineSize, 0.0f);
5678 for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
5679 const float weight = static_cast<float>(WH.raw(curIdx));
5680 const dim_t rowIdx = IH.raw(curIdx++);
      // The data type of the scale and offset for fused types need not match
      // the output tensor type T.
5683 float scale, offset;
5684 switch (
5685 getScaleOffsetElemKindFromFused(data->getType().getElementType())) {
5686 case ElemKind::FloatTy:
5687 std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<float>(rowIdx);
5688 break;
5689 case ElemKind::Float16Ty:
5690 std::tie(scale, offset) =
5691 DH.getFusedScaleOffsetFromRow<float16_t>(rowIdx);
5692 break;
5693 default:
5694 llvm_unreachable("Type is not supported");
5695 break;
5696 }
5697
5698 for (dim_t k = 0; k < outLineSize; k++) {
5699 float d = 0.0f;
5700 if (!using4BitQuantization) {
5701 d = quantization::dequantizeWithFloatOffset(
5702 DH.at({rowIdx, k}), static_cast<float>(scale),
5703 static_cast<float>(offset));
5704 } else {
5705 const bool isMSB = (k % 2 == 1);
5706 d = quantization::dequantize4BitWithFloatOffset(
5707 DH.at({rowIdx, k / 2}), static_cast<float>(scale),
5708 static_cast<float>(offset), isMSB);
5709 }
5710 accum[k] += d * weight;
5711 }
5712 }
    // Accumulation complete; copy back to the output with a cast to T.
5714 dim_t offsetOut = i * outLineSize;
5715 for (dim_t k = 0; k < outLineSize; k++) {
5716 OH.raw(offsetOut++) = static_cast<T>(accum[k]);
5717 }
5718 }
5719}
5720
5721void BoundInterpreterFunction::
5722 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumInst(
5723 const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
5724 const auto ity = I->getIndices()->getElementType();
5725 const bool fp32FusedScaleOffset =
5726 (I->getData()->getElementType() == ElemKind::UInt4FusedQTy) ||
5727 (I->getData()->getElementType() == ElemKind::UInt8FusedQTy);
5728
5729 switch (I->getDest()->getElementType()) {
5730 case ElemKind::FloatTy:
5731 if (ity == ElemKind::Int32ITy) {
5732 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
5733 int32_t>(I);
5734 } else if (ity == ElemKind::Int64ITy) {
5735 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
5736 int64_t>(I);
5737 } else {
5738 llvm_unreachable("Index type is not supported");
5739 }
5740 break;
5741 case ElemKind::Float16Ty:
5742 if (I->getUseFP16Accumulation() && !fp32FusedScaleOffset) {
5743 if (ity == ElemKind::Int32ITy) {
5744 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
5745 float16_t, float16_t, int32_t>(I);
5746 } else if (ity == ElemKind::Int64ITy) {
5747 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
5748 float16_t, float16_t, int64_t>(I);
5749 } else {
5750 llvm_unreachable("Index type is not supported");
5751 }
5752 } else {
5753 if (ity == ElemKind::Int32ITy) {
5754 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
5755 int32_t>(I);
5756 } else if (ity == ElemKind::Int64ITy) {
5757 fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
5758 int64_t>(I);
5759 } else {
5760 llvm_unreachable("Index type is not supported");
5761 }
5762 }
5763 break;
5764 default:
5765 llvm_unreachable("Type is not supported");
5766 }
5767}
5768
5769void BoundInterpreterFunction::fwdFusedRowwiseQuantizedSparseLengthsSumInst(
5770 const FusedRowwiseQuantizedSparseLengthsSumInst *I) {
5771 llvm_unreachable("Not supported");
5772}
5773
5774template <typename T, typename AccumT, typename IndexT>
5775void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl(
5776 const EmbeddingBagByteRowwiseOffsetsInst *I) {
5777 auto *out = getTensor(I->getDest());
5778 auto *data = getTensor(I->getData());
5779 auto *weights = getTensor(I->getWeights());
5780 auto *indices = getTensor(I->getIndices());
5781 auto *offsets = getTensor(I->getOffsets());
5782 bool hasEndOffset = I->getHasEndOffset();
5783
5784 out->zero();
5785
5786 auto IH = indices->getHandle<IndexT>();
5787 auto OFFH = offsets->getHandle<IndexT>();
5788
  // If an end offset is present to mark the end of the last segment, then the
  // number of segments is one less than the size of the Offsets tensor.
5791 size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
5792 dim_t numIndices = indices->dims()[0];
5793
5794 const bool using4BitQuantization =
5795 data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;
5796
5797 const size_t outLineSize = out->size() / out->dims()[0];
5798
5799 auto DH = data->getHandle<uint8_t>();
5800 auto WH = weights->getHandle<T>();
5801 auto OH = out->getHandle<T>();
5802
5803 for (dim_t i = 0; i < segments; i++) {
5804 std::vector<AccumT> accum(outLineSize, 0.0f);
5805 size_t start = OFFH.raw(i);
5806 dim_t end;
5807 if (!hasEndOffset) {
      // Note that in this case we have to use numIndices to find the end of
      // the last segment. This is fragile because it relies on knowing the
      // total length of the indices tensor, which may not always be
      // available. Future implementations of this operator should always
      // provide an end offset, so eventually this case should be removed.
5813 end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
5814 } else {
5815 end = OFFH.raw(i + 1);
5816 }
5817 if (start == end) {
5818 continue;
5819 } else if (start > end) {
5820 break;
5821 }
5822
5823 for (dim_t j = start; j < end; j++) {
5824 const float weight = static_cast<float>(WH.raw(j));
5825 const dim_t rowIdx = IH.raw(j);
5826 T scale, offset;
5827 std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx);
5828 for (dim_t k = 0; k < outLineSize; k++) {
5829 float d = 0.0f;
5830 if (!using4BitQuantization) {
5831 d = quantization::dequantizeWithFloatOffset(
5832 DH.at({rowIdx, k}), static_cast<float>(scale),
5833 static_cast<float>(offset));
5834 } else {
5835 const bool isMSB = (k % 2 == 1);
5836 d = quantization::dequantize4BitWithFloatOffset(
5837 DH.at({rowIdx, k / 2}), static_cast<float>(scale),
5838 static_cast<float>(offset), isMSB);
5839 }
5840 accum[k] += d * weight;
5841 }
5842 }
    // Accumulation complete; copy back to the output with a cast to T.
5844 dim_t offsetOut = i * outLineSize;
5845 for (dim_t k = 0; k < outLineSize; k++) {
5846 OH.raw(offsetOut++) = static_cast<T>(accum[k]);
5847 }
5848 }
5849}
5850
5851void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst(
5852 const EmbeddingBagByteRowwiseOffsetsInst *I) {
5853 const auto ity = I->getIndices()->getElementType();
5854 const bool fp32FusedScaleOffset =
5855 (I->getData()->getElementType() == ElemKind::UInt4FusedQTy) ||
5856 (I->getData()->getElementType() == ElemKind::UInt8FusedQTy);
5857
5858 switch (I->getDest()->getElementType()) {
5859 case ElemKind::FloatTy:
5860 if (ity == ElemKind::Int32ITy) {
5861 fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float, int32_t>(I);
5862 } else if (ity == ElemKind::Int64ITy) {
5863 fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float, int64_t>(I);
5864 } else {
5865 llvm_unreachable("Index type is not supported");
5866 }
5867 break;
5868 case ElemKind::Float16Ty:
5869 if (I->getUseFP16Accumulation() && !fp32FusedScaleOffset) {
5870 if (ity == ElemKind::Int32ITy) {
5871 fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t, int32_t>(I);
5872 } else if (ity == ElemKind::Int64ITy) {
5873 fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t, int64_t>(I);
5874 } else {
5875 llvm_unreachable("Index type is not supported");
5876 }
5877 } else {
5878 if (ity == ElemKind::Int32ITy) {
5879 fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float, int32_t>(I);
5880 } else if (ity == ElemKind::Int64ITy) {
5881 fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float, int64_t>(I);
5882 } else {
5883 llvm_unreachable("Index type is not supported");
5884 }
5885 }
5886 break;
5887 default:
5888 llvm_unreachable("Type is not supported");
5889 }
5890}
5891
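/// LengthsToRanges converts a vector of lengths into (offset, length) pairs,
/// e.g. Lengths = [2, 3, 1] becomes Ranges = [[0, 2], [2, 3], [5, 1]].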
5892void BoundInterpreterFunction::fwdLengthsToRangesInst(
5893 const LengthsToRangesInst *I) {
5894 auto ranges = getTensor(I->getDest())->getHandle<int32_t>();
5895 auto lengths = getTensor(I->getLengths())->getHandle<int32_t>();
5896 int32_t offset = 0;
5897 for (dim_t i = 0; i < lengths.dims()[0]; i++) {
5898 auto length = lengths.at({i});
5899 ranges.at({i, 0}) = offset;
5900 ranges.at({i, 1}) = length;
5901 offset += length;
5902 }
5903}
5904
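/// LengthsRangeFill expands each length into the sequence 0..length-1,
/// e.g. Lengths = [2, 3] produces [0, 1, 0, 1, 2].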
5905void BoundInterpreterFunction::fwdLengthsRangeFillInst(
5906 const LengthsRangeFillInst *I) {
5907 auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
5908 auto resultH = getTensor(I->getDest())->getHandle<int32_t>();
5909 dim_t curIdx = 0;
5910 for (dim_t i = 0, e = lengthsH.dims()[0]; i < e; i++) {
5911 for (int32_t j = 0, f = lengthsH.at({i}); j < f; j++) {
5912 resultH.at({curIdx++}) = j;
5913 }
5914 }
5915}
5916
5917void BoundInterpreterFunction::fwdGaussianFillInst(const GaussianFillInst *I) {
5918 std::mt19937 rnd(I->getSeed());
5919 std::normal_distribution<float> dist(I->getMean(), I->getScale());
5920 auto outT = getTensor(I->getDest());
5921 auto outH = outT->getHandle<float16_t>();
5922 for (auto &elem : outH) {
5923 elem = dist(rnd);
5924 }
5925}
5926
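/// BatchSparseToDense scatters sparse (index, value) pairs into a dense
/// output of width denseLastDim, one row per batch entry, with untouched
/// positions set to the default value. For example, Lengths = [2, 1],
/// Indices = [0, 3, 2], Values = [a, b, c] and denseLastDim = 4 produce
/// row 0 = [a, default, default, b] and row 1 = [default, default, c,
/// default].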
5927template <typename ElemTy, typename LengthsTy, typename IndicesTy>
5928void BoundInterpreterFunction::fwdBatchSparseToDenseInstImpl2(
5929 const BatchSparseToDenseInst *I) {
5930 auto outH = getWeightHandle<ElemTy>(I->getDest());
5931 auto lengthsH = getWeightHandle<LengthsTy>(I->getLengths());
5932 auto valuesH = getWeightHandle<ElemTy>(I->getValues());
5933 auto indicesH = getWeightHandle<IndicesTy>(I->getIndices());
5934 auto denseLastDim = I->getDenseLastDim();
5935 auto defaultValue = I->getDefaultValue();
5936 outH.clear(defaultValue);
5937
5938 // Verifying input sizes.
5939 size_t lengthsSum = 0;
5940 auto batchSize = lengthsH.size();
5941 for (dim_t i = 0; i < batchSize; ++i) {
5942 lengthsSum += lengthsH.at(i);
5943 }
5944 CHECK_EQ(lengthsSum, indicesH.size());
5945
5946 dim_t k = 0;
5947 for (dim_t i = 0; i < batchSize; ++i) {
5948 for (dim_t j = 0; j < lengthsH.at(i); ++j) {
      CHECK_LT(indicesH.at(k), denseLastDim);
5950 outH.at({static_cast<dim_t>(i), static_cast<dim_t>(indicesH.at(k))}) =
5951 valuesH.at(k);
5952 k++;
5953 }
5954 }
5955}
5956
5957template <typename ElemTy, typename LengthsTy>
5958void BoundInterpreterFunction::fwdBatchSparseToDenseInstImpl1(
5959 const BatchSparseToDenseInst *I) {
5960 switch (I->getLengths()->getElementType()) {
5961 case ElemKind::Int32ITy:
5962 fwdBatchSparseToDenseInstImpl2<ElemTy, LengthsTy, int32_t>(I);
5963 break;
5964 case ElemKind::Int64ITy:
5965 fwdBatchSparseToDenseInstImpl2<ElemTy, LengthsTy, int64_t>(I);
5966 break;
5967 default:
5968 llvm_unreachable("Index type is not supported");
5969 }
5970}
5971
5972void BoundInterpreterFunction::fwdBatchSparseToDenseInst(
5973 const BatchSparseToDenseInst *I) {
5974 dispatchFloatingPointAndIndexImpl(fwdBatchSparseToDenseInstImpl1,
5975 I->getDest()->getElementType(),
5976 I->getLengths()->getElementType(), I);
5977}
5978
5979template <typename ElemTy, typename IndicatorTy>
5980void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInstImpl2(
5981 const FillExamplesWithIndicatorInst *I) {
5982 auto outT = getTensor(I->getDest());
5983 auto dataT = getTensor(I->getData());
5984 auto elemSize = dataT->getType().getElementSize();
5985 auto indicatorH = getWeightHandle<IndicatorTy>(I->getIndicator());
5986
5987 size_t numBatches = indicatorH.dims()[0];
5988 outT->zero();
5989
5990 size_t nonzero = 0;
5991 for (size_t i = 0; i < numBatches; ++i) {
5992 if (static_cast<bool>(indicatorH.at(i))) {
5993 nonzero++;
5994 }
5995 }
5996 CHECK_EQ(dataT->dims()[0], nonzero);
5997
5998 // Calculate size of last n-1 data dims
5999 size_t blockSize = 1;
6000 for (size_t i = 1; i < dataT->dims().size(); i++) {
6001 blockSize *= dataT->dims()[i];
6002 }
6003 size_t blockByteSize = blockSize * elemSize;
6004 size_t dataP = 0;
6005 for (size_t i = 0; i < numBatches; i++) {
6006 if (static_cast<bool>(indicatorH.at(i))) {
6007 std::copy(&dataT->getUnsafePtr()[dataP],
6008 &dataT->getUnsafePtr()[dataP + blockByteSize],
6009 &outT->getUnsafePtr()[i * blockByteSize]);
6010 dataP += blockByteSize;
6011 }
6012 }
6013}
6014
6015template <typename ElemTy>
6016void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInstImpl1(
6017 const FillExamplesWithIndicatorInst *I) {
6018 switch (I->getIndicator()->getElementType()) {
6019 case ElemKind::Int32ITy:
6020 fwdFillExamplesWithIndicatorInstImpl2<ElemTy, int32_t>(I);
6021 break;
6022 case ElemKind::Int64ITy:
6023 fwdFillExamplesWithIndicatorInstImpl2<ElemTy, int64_t>(I);
6024 break;
6025 case ElemKind::BoolTy:
6026 fwdFillExamplesWithIndicatorInstImpl2<ElemTy, bool>(I);
6027 break;
6028 default:
6029 llvm_unreachable("Indicator type is not supported");
6030 }
6031}
6032
6033void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInst(
6034 const FillExamplesWithIndicatorInst *I) {
6035 dispatchArithmeticImpl(fwdFillExamplesWithIndicatorInstImpl1,
6036 I->getDest()->getElementType(), I);
6037}

void BoundInterpreterFunction::fwdSparseToDenseMaskInst(
6039 const SparseToDenseMaskInst *I) {
6040 auto out = getTensor(I->getDest());
6041 auto values = getTensor(I->getValues());
6042 auto defaultValue = getTensor(I->getDefaultValue());
6043
6044 auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>();
6045 auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
6046
6047 const std::vector<dim_t> &mask = I->getMask();
6048 size_t maskSize = mask.size();
6049 // Create a reverse map from ID to its position in the mask.
6050 std::unordered_map<int64_t, size_t> reverseMap;
6051 for (size_t i = 0; i < maskSize; i++) {
6052 assert(reverseMap.find(mask[i]) == reverseMap.end() &&
6053 "duplicate IDs in the mask");
6054 reverseMap[mask[i]] = i;
6055 }
6056
6057 auto valueSize = defaultValue->getSizeInBytes();
6058
6059 // First un-processed index-value pair.
6060 size_t posIn = 0;
6061 // Beginning of output block for first unprocessed batch.
6062 size_t byteOffsetOut = 0;
6063 // Lengths can be scalar, which means that all pairs belong to one batch.
6064 size_t numBatches = lengthsH.dims().empty() ? 1 : lengthsH.dims()[0];
6065 for (size_t batch = 0; batch < numBatches; batch++) {
6066 // Fill everything with maskSize copies of defaultValue.
6067 for (size_t i = 0; i < maskSize; i++) {
6068 std::copy(defaultValue->getUnsafePtr(),
6069 &defaultValue->getUnsafePtr()[valueSize],
6070 &out->getUnsafePtr()[byteOffsetOut + valueSize * i]);
6071 }
6072 // Go through input pairs and find matches.
6073 for (size_t i = 0, batchLen = lengthsH.raw(batch); i < batchLen;
6074 i++, posIn++) {
6075 int64_t idx = indicesH.raw(posIn);
6076 auto it = reverseMap.find(idx);
6077 // Skip if ID is not present in the mask.
6078 if (it == reverseMap.end())
6079 continue;
6080 size_t to = it->second;
6081
6082 std::copy(&values->getUnsafePtr()[posIn * valueSize],
6083 &values->getUnsafePtr()[(posIn + 1) * valueSize],
6084 &out->getUnsafePtr()[byteOffsetOut + valueSize * to]);
6085 }
6086
6087 byteOffsetOut += maskSize * valueSize;
6088 }
6089
6090 assert(posIn == indicesH.dims()[0] &&
6091 "Sum of Lengths must be equal to size of indices.");
6092}
6093
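/// SparseLabelSplit distributes the (index, value) pairs of each example into
/// per-label outputs: Values[p] is written to LabelValues at row Indices[p],
/// the originating example id (the segment index) is recorded in ExampleIds,
/// and GradientOffsetMap records each pair's position within its label row.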
6094void BoundInterpreterFunction::fwdSparseLabelSplitInst(
6095 const SparseLabelSplitInst *I) {
6096 auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
6097 auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>();
6098 auto valuesH = getTensor(I->getValues())->getHandle();
6099
6100 const auto numLabels = I->getNumLabels();
6101
6102 auto labelValuesH = getTensor(I->getLabelValues())->getHandle();
6103 auto exampleIdsH = getTensor(I->getExampleIds())->getHandle<int32_t>();
6104 auto gradientOffsetMapH =
6105 getTensor(I->getGradientOffsetMap())->getHandle<int32_t>();
6106
6107 // Verifying input sizes.
6108 size_t lengthsSum = 0;
6109 for (size_t i = 0; i < lengthsH.size(); ++i) {
6110 lengthsSum += lengthsH.at(i);
6111 }
6112 CHECK_EQ(lengthsSum, indicesH.size());
6113 CHECK_EQ(indicesH.size(), valuesH.size());
6114
  // Verify that every label receives the same number of examples.
6116 const auto numValuesPerRow = indicesH.size() / numLabels;
6117 std::vector<size_t> numExamplesPerTask(numLabels, 0);
6118 for (size_t i = 0; i < indicesH.size(); ++i) {
6119 numExamplesPerTask[indicesH.at(i)] += 1;
6120 }
6121 for (size_t i = 0; i < numLabels; ++i) {
6122 CHECK_EQ(numValuesPerRow, numExamplesPerTask[i])
6123 << "Unexpected number of values at " << i;
6124 }
6125
6126 // Populating outputs
6127 size_t pos = 0;
6128 std::fill(numExamplesPerTask.begin(), numExamplesPerTask.end(), 0);
6129 for (size_t i = 0; i < lengthsH.size(); ++i) {
6130 for (size_t l = 0; l < lengthsH.at(i); ++l) {
6131 auto ind = indicesH.at(pos);
6132 auto val = valuesH.at(pos);
6133
6134 auto posOutput = numExamplesPerTask[ind]++;
6135 gradientOffsetMapH.at(pos) = posOutput;
6136
6137 labelValuesH.at(
6138 {static_cast<dim_t>(ind), static_cast<dim_t>(posOutput)}) = val;
6139 exampleIdsH.at({static_cast<dim_t>(ind), static_cast<dim_t>(posOutput)}) =
6140 i;
6141 pos++;
6142 }
6143 }
6144}
6145
6146//===----------------------------------------------------------------------===//
6147// Instructions used by RNN
6148//===----------------------------------------------------------------------===//
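/// TopK returns the k largest values along the last dimension, together with
/// their indices, in descending order; ties are broken towards the smaller
/// index. For example, k = 2 over [1, 5, 3, 5] yields values [5, 5] and
/// indices [1, 3].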
6149template <typename T, typename TI>
6150static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) {
6151 auto values = outW->getHandle<T>();
6152 auto indices = indW->getHandle<TI>();
6153 auto in = inW->getHandle<T>();
6154 size_t n = in.dims().back();
6155
6156 size_t in_p = 0, out_p = 0;
6157 size_t tensor_end = in.size();
6158 using pairType = std::pair<float, size_t>;
6159 std::vector<pairType> buf(n);
6160
6161 while (in_p < tensor_end) {
6162 for (size_t i = 0; i < n; i++) {
6163 buf[i].first = in.raw(in_p++);
6164 buf[i].second = i;
6165 }
    // NOTE: a partial-selection approach would be O(N + K*logK); this
    // version is O(N*logN).
6167 std::sort(buf.begin(), buf.end(), [](const pairType &a, const pairType &b) {
6168 if (a.first != b.first)
6169 return a.first > b.first;
6170 return a.second < b.second;
6171 });
6172 for (size_t i = 0; i < k; i++) {
6173 values.raw(out_p) = buf[i].first;
6174 indices.raw(out_p) = buf[i].second;
6175 out_p++;
6176 }
6177 }
6178}
6179
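/// ArgMax/ArgMin write, for every position outside the reduction axis, the
/// index of the largest (respectively smallest) element along that axis;
/// when several elements are equal, the first such index is kept.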
6180template <typename inpType, typename outType>
6181static void fwdArgMax(Tensor *inpT, Tensor *outT, size_t axis) {
6182
6183 // Get input/output handles with dimensions expanded to maximum.
6184 ShapeVector inpDims = expandDimsToMax(inpT->dims());
6185 ShapeVector outDims = inpDims;
6186 outDims[axis] = 1;
6187 auto eInpT = inpT->getUnowned(inpDims);
6188 auto eOutT = outT->getUnowned(outDims);
6189 auto inpH = eInpT.getHandle<inpType>();
6190 auto outH = eOutT.getHandle<outType>();
6191
6192 static_assert(max_tensor_dimensions == 6,
6193 "Loops below assume max_tensor_dimensions = 6.");
6194
6195 for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
6196 for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
6197 for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
6198 for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
6199 for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
6200 for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {
6201
6202 // Initialize maximum value/index.
6203 inpType maxVal = std::numeric_limits<inpType>::lowest();
6204 outType maxIdx = 0;
6205
6206 // Iterate input axis dimension.
6207 for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
6208 std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
6209 idx3, idx4, idx5};
6210 inpIdx[axis] = axisIdx;
6211 inpType inpVal = inpH.at(inpIdx);
6212 if (inpVal > maxVal) {
6213 maxVal = inpVal;
6214 maxIdx = axisIdx;
6215 }
6216 }
6217
6218 // Store maximum index.
6219 outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = maxIdx;
6220 }
6221 }
6222 }
6223 }
6224 }
6225 }
6226}
6227
6228template <typename inpType, typename outType>
6229static void fwdArgMin(Tensor *inpT, Tensor *outT, size_t axis) {
6230
6231 // Get input/output handles with dimensions expanded to maximum.
6232 ShapeVector inpDims = expandDimsToMax(inpT->dims());
6233 ShapeVector outDims = inpDims;
6234 outDims[axis] = 1;
6235 auto eInpT = inpT->getUnowned(inpDims);
6236 auto eOutT = outT->getUnowned(outDims);
6237 auto inpH = eInpT.getHandle<inpType>();
6238 auto outH = eOutT.getHandle<outType>();
6239
6240 static_assert(max_tensor_dimensions == 6,
6241 "Loops below assume max_tensor_dimensions = 6.");
6242
6243 for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
6244 for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
6245 for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
6246 for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
6247 for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
6248 for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {
6249
6250 // Initialize minimum value/index.
6251 inpType minVal = std::numeric_limits<inpType>::max();
6252 outType minIdx = 0;
6253
6254 // Iterate input axis dimension.
6255 for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
6256 std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
6257 idx3, idx4, idx5};
6258 inpIdx[axis] = axisIdx;
6259 inpType inpVal = inpH.at(inpIdx);
6260 if (inpVal < minVal) {
6261 minVal = inpVal;
6262 minIdx = axisIdx;
6263 }
6264 }
6265
6266 // Store minimum index.
6267 outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = minIdx;
6268 }
6269 }
6270 }
6271 }
6272 }
6273 }
6274}
6275
6276//===----------------------------------------------------------------------===//
6277// Sorting operators
6278//===----------------------------------------------------------------------===//
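/// Selects the rpnPostNmsTopN highest-scoring rois: std::nth_element
/// partitions the candidate indices in O(N) so that only the top K need the
/// final O(K*logK) sort.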
6279template <typename ElemTy>
6280static void
6281CollectRpnProposalsHelper(const std::vector<std::vector<ElemTy>> &roisIn,
6282 const std::vector<ElemTy> &scores,
6283 std::vector<std::vector<ElemTy>> &rois,
6284 dim_t rpnPostNmsTopN) {
6285 // N + KlogK implementation, K is rpnPostNmsTopN
6286 dim_t roisSecondDim = roisIn[0].size();
6287
6288 std::vector<dim_t> idx(scores.size());
6289 std::iota(idx.begin(), idx.end(), 0);
6290
6291 auto comp = [&](dim_t leftIdx, dim_t rightIdx) -> bool {
6292 if (scores[leftIdx] > scores[rightIdx]) {
6293 return true;
6294 }
6295 if (scores[leftIdx] < scores[rightIdx]) {
6296 return false;
6297 }
6298 // Sort on indices if two scores are equal
6299 return leftIdx < rightIdx;
6300 };
6301
6302 dim_t k = rpnPostNmsTopN;
6303 // Getting the indices in order according to scores
6304 // Arrange Kth element at its proper location of sorted array, O(N)
6305 if (k < roisIn.size()) {
6306 std::nth_element(idx.begin(), idx.begin() + k, idx.end(), comp);
6307 } else {
6308 k = roisIn.size();
6309 }
6310
6311 rois.resize(k, std::vector<ElemTy>(roisSecondDim));
6312
6313 // KlogK, K is rpnPostNmsTopN
6314 std::sort(idx.begin(), idx.begin() + k, comp);
6315
6316 // Output rois according to new order of indices
6317 for (dim_t i = 0; i < k; i++) {
6318 rois[i] = roisIn[idx[i]];
6319 }
6320}
6321
6322template <typename ElemTy>
6323void BoundInterpreterFunction::fwdCollectRpnProposalsInstImpl(
6324 const glow::CollectRpnProposalsInst *I) {
6325
6326 // Get params
6327 dim_t rpnPostNmsTopN = I->getRpnPostNmsTopN();
6328 int64_t rpnMaxLevel = I->getRpnMaxLevel();
6329 int64_t rpnMinLevel = I->getRpnMinLevel();
6330 int64_t rpnLevels = rpnMaxLevel - rpnMinLevel + 1;
6331
6332 std::vector<std::vector<ElemTy>> roisIn;
6333 std::vector<ElemTy> scores;
6334 std::vector<std::vector<ElemTy>> roisOut;
6335
6336 auto getRoisAndScores = [&](Tensor *input, bool isroi) {
6337 auto inputHndl = input->getHandle<ElemTy>();
6338
6339 if (isroi) {
6340 for (dim_t i = 0; i < input->dims()[0]; i++) {
6341 std::vector<ElemTy> roi;
6342
6343 for (dim_t j = 0; j < input->dims()[1]; j++) {
6344 roi.push_back(inputHndl.at({i, j}));
6345 }
6346
6347 roisIn.push_back(roi);
6348 }
6349 } else {
6350 for (dim_t i = 0; i < input->dims()[0]; i++) {
6351 scores.push_back(inputHndl.at({i}));
6352 }
6353 }
6354 };
6355
  // Inputs start at operand index 1: the per-level roi tensors come first,
  // followed by the corresponding score tensors.
6357 for (dim_t idx = 1; idx <= rpnLevels; idx++) {
6358 getRoisAndScores(getTensor(I->getOperand(idx).first), true);
6359 getRoisAndScores(getTensor(I->getOperand(idx + rpnLevels).first), false);
6360 }
6361
6362 // Sorting the roisIn according to scores limited by rpnPostNmsTopN
6363 CollectRpnProposalsHelper<ElemTy>(roisIn, scores, roisOut, rpnPostNmsTopN);
6364
6365 // Storing roisOut in result tensor
6366 dim_t roisSecondDim = roisIn[0].size();
6367 Tensor *result = getTensor(I->getResult());
6368
  for (dim_t i = 0; i < rpnPostNmsTopN; i++) {
    for (dim_t j = 0; j < roisSecondDim; j++) {
      // Rows beyond the number of collected rois are filled with 0.
      result->getHandle<ElemTy>().at({i, j}) =
          i < roisOut.size() ? roisOut[i][j] : ElemTy(0);
    }
  }
6378}
6379
6380void BoundInterpreterFunction::fwdCollectRpnProposalsInst(
6381 const glow::CollectRpnProposalsInst *I) {
6382 dispatchFloatingPointImpl(fwdCollectRpnProposalsInstImpl,
6383 I->getOperand(1).first->getElementType(), I);
6384}
6385
6386void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) {
6387 auto outW = getTensor(I->getValues());
6388 auto indW = getTensor(I->getIndices());
6389 auto inW = getTensor(I->getInput());
6390 size_t k = I->getK();
6391
6392 if (inW->getType().isQuantizedType()) {
    if (indW->getElementType() == ElemKind::Int64ITy) {
      fwdTopK<int8_t, int64_t>(outW, indW, inW, k);
6396 } else if (indW->getElementType() == ElemKind::Int32ITy) {
6397 fwdTopK<int8_t, int32_t>(outW, indW, inW, k);
6398 }
6399 return;
6400 }
6401
6402 dispatchFloatingPointAndIndexImpl(fwdTopK, inW->getElementType(),
6403 indW->getElementType(), outW, indW, inW, k);
6404}
6405
6406void BoundInterpreterFunction::fwdBatchedUnaryEmbeddingsBagsInst(
6407 const BatchedUnaryEmbeddingsBagsInst *I) {
6408 dispatchFloatingPointAndIndexImpl(fwdBatchedUnaryEmbeddingsBagsInstImpl,
6409 I->getWeights()->getElementType(),
6410 I->getIndices()->getElementType(), I);
6411}
6412
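/// BatchedUnaryEmbeddingsBags pools rows of a stacked weights tensor: for
/// every task n, batch b, and table t it sums
/// Weights[n * sumTable + TableOffsets[t] + Indices[i]] over the index range
/// [Offsets[t * numBatches + b], Offsets[t * numBatches + b + 1]), where
/// sumTable is the total row count over all tables.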
6413template <typename ElemTy, typename IndexType>
6414void BoundInterpreterFunction::fwdBatchedUnaryEmbeddingsBagsInstImpl(
6415 const BatchedUnaryEmbeddingsBagsInst *I) {
6416 staticAssertFloatingPointType(ElemTy);
6417
6418 auto out = getTensor(I->getDest());
6419 auto weights = getTensor(I->getWeights());
6420 auto tableOffsets = getTensor(I->getTableOffsets());
6421 auto indices = getTensor(I->getIndices());
6422 auto offsets = getTensor(I->getOffsets());
6423
6424 out->zero();
6425
6426 auto indicesH = indices->getHandle<IndexType>();
6427 auto offsetsH = offsets->getHandle<IndexType>();
6428 auto weightsH = weights->getHandle<ElemTy>();
6429 auto tableOffsetsH = tableOffsets->getHandle<IndexType>();
6430 auto outH = out->getHandle<ElemTy>();
6431
6432 size_t numTasks = weightsH.dims()[0];
6433 size_t numTables = tableOffsets->size() - 1;
6434 size_t numBatches = (offsets->size() - 1) / numTables;
6435
6436 IndexType sumTable = tableOffsetsH.raw(numTables);
6437
6438 for (size_t n = 0; n < numTasks; n++) {
6439 for (size_t b = 0; b < numBatches; b++) {
6440 for (size_t t = 0; t < numTables; t++) {
6441 IndexType indicesStart = offsetsH.raw(t * numBatches + b);
6442 IndexType indicesEnd = offsetsH.raw(t * numBatches + b + 1);
6443 ElemTy sum = 0;
6444 for (IndexType i = indicesStart; i < indicesEnd; i++) {
6445 IndexType idx = n * sumTable + tableOffsetsH.raw(t) + indicesH.raw(i);
6446 assert(idx < weightsH.size() &&
6447 "Index shall be within weights boundary.");
6448 sum += weightsH.raw(idx);
6449 }
6450 outH.raw((n * numBatches + b) * numTables + t) = sum;
6451 }
6452 }
6453 }
6454}
6455
6456void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingBagsInst(
6457 const IntNBitSplitEmbeddingBagsInst *I) {
6458 dispatchIndexAndOutputTypeImpl(fwdIntNBitSplitEmbeddingBagsInstImpl,
6459 I->getIndices()->getElementType(),
6460 I->getOutputDType(), I);
6461}
6462
6463void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsInst(
6464 const IntNBitSplitEmbeddingWeightedBagsInst *I) {
6465 dispatchIndexAndOutputTypeImpl(fwdIntNBitSplitEmbeddingWeightedBagsInstImpl,
6466 I->getIndices()->getElementType(),
6467 I->getOutputDType(), I);
6468}
6469
6470template <typename IndexTy, typename OutputTy>
6471void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingBagsInstImpl(
6472 const IntNBitSplitEmbeddingBagsInst *I) {
6473 auto out = getTensor(I->getDest());
6474 auto devWeights = getTensor(I->getDevWeights());
6475 auto uvmWeights = getTensor(I->getUvmWeights());
6476 auto weightsPlacements = getTensor(I->getWeightsPlacements());
6477 auto weightsTys = getTensor(I->getWeightsTys());
6478 auto dimOffsets = getTensor(I->getDimOffsets());
6479 auto indices = getTensor(I->getIndices());
6480 auto offsets = getTensor(I->getOffsets());
6481 auto weightsOffsets = getTensor(I->getWeightsOffsets());
6482 auto poolingMode = I->getPoolingMode();
6483 auto totalDims = I->getTotalDims();
6484 auto outputDType = I->getOutputDType();
6485
6486 fwdIntNBitSplitEmbeddingWeightedBagsImpl<IndexTy, OutputTy>(
6487 out, devWeights, uvmWeights, weightsPlacements, weightsTys, dimOffsets,
6488 indices, offsets, weightsOffsets, poolingMode,
6489 /* indiceWeights */ nullptr, totalDims, outputDType);
6490}
6491
6492template <typename IndexTy, typename OutputTy>
6493void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsInstImpl(
6494 const IntNBitSplitEmbeddingWeightedBagsInst *I) {
6495 auto out = getTensor(I->getDest());
6496 auto devWeights = getTensor(I->getDevWeights());
6497 auto uvmWeights = getTensor(I->getUvmWeights());
6498 auto weightsPlacements = getTensor(I->getWeightsPlacements());
6499 auto weightsTys = getTensor(I->getWeightsTys());
6500 auto dimOffsets = getTensor(I->getDimOffsets());
6501 auto indices = getTensor(I->getIndices());
6502 auto offsets = getTensor(I->getOffsets());
6503 auto weightsOffsets = getTensor(I->getWeightsOffsets());
6504 auto poolingMode = I->getPoolingMode();
6505 auto totalDims = I->getTotalDims();
6506 auto outputDType = I->getOutputDType();
6507 auto indiceWeights = getTensor(I->getIndiceWeight());
6508
6509 fwdIntNBitSplitEmbeddingWeightedBagsImpl<IndexTy, OutputTy>(
6510 out, devWeights, uvmWeights, weightsPlacements, weightsTys, dimOffsets,
6511 indices, offsets, weightsOffsets, poolingMode, indiceWeights, totalDims,
6512 outputDType);
6513}
6514
6515template <typename IndexTy, typename OutputTy>
6516void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsImpl(
6517 Tensor *out, Tensor *devWeights, Tensor *uvmWeights,
6518 Tensor *weightsPlacements, Tensor *weightsTys, Tensor *dimOffsets,
6519 Tensor *indices, Tensor *offsets, Tensor *weightsOffsets,
6520 int64_t poolingMode, Tensor *indiceWeights, int64_t totalDims,
6521 int64_t outputDType) {
6522 out->zero();
6523
6524 auto indicesH = indices->getHandle<IndexTy>();
6525 auto offsetsH = offsets->getHandle<IndexTy>();
6526 auto devWeightsH = devWeights->getHandle<uint8_t>();
6527 auto uvmWeightsH = uvmWeights->getHandle<uint8_t>();
6528 auto weightsPlacementH = weightsPlacements->getHandle<int32_t>();
6529 auto weightsTysH = weightsTys->getHandle<uint8_t>();
6530 auto dimOffsetsH = dimOffsets->getHandle<int32_t>();
6531 auto weightsOffsetsH = weightsOffsets->getHandle<int32_t>();
6532 llvm::Optional<Handle<float>> indiceWeightsH;
6533 if (indiceWeights) {
6534 indiceWeightsH = indiceWeights->getHandle<float>();
6535 }
6536 auto outH = out->getHandle<uint8_t>();
6537
6538 size_t numTables = dimOffsets->size() - 1;
6539 size_t numBatches = (offsets->size() - 1) / numTables;
6540 auto outputSparseType = static_cast<SplitEmbeddingSparseType>(outputDType);
6541 auto numTotalBytes = IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes(
6542 totalDims, outputSparseType);
6543 int64_t maxDevWeightsBytes = devWeights->getSizeInBytes();
6544 int64_t maxUvmWeightsBytes = uvmWeights->getSizeInBytes();
6545
6546 for (int32_t t = 0; t < numTables; t++) {
6547 const int32_t dimStart = dimOffsetsH.raw(t);
6548 const int32_t numDims = dimOffsetsH.raw(t + 1) - dimOffsetsH.raw(t);
6549 const auto placement = weightsPlacementH.raw(t);
    assert(placement != WeightsPlacement::DEVICE &&
           "DEVICE weights placement is not supported here.");
6551 auto weightsH = uvmWeightsH;
6552 auto maxWeightsBytes = maxUvmWeightsBytes;
6553 if (placement == WeightsPlacement::HOST) {
6554 weightsH = devWeightsH;
6555 maxWeightsBytes = maxDevWeightsBytes;
6556 }
6557 auto weightTy = static_cast<SplitEmbeddingSparseType>(weightsTysH.raw(t));
6558 assert(weightTy != SplitEmbeddingSparseType::EST_INT2 &&
6559 "Int2 sparse type isn't supported yet.");
6560 auto numDimBytes = IntNBitSplitEmbeddingBagsHelper::paddedRowSizeInBytes(
6561 numDims, weightTy);
6562
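    // Row r of table t starts at byte weightsOffsetsH[t] + r * numDimBytes;
    // within a row, dimension d is found at the unpadded byte offset
    // unpaddedRowSizeInBytes(d, weightTy).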
6563 for (int32_t b = 0; b < numBatches; b++) {
6564 IndexTy indicesStart = offsetsH.raw(t * numBatches + b);
6565 IndexTy indicesEnd = offsetsH.raw(t * numBatches + b + 1);
6566 for (int32_t d = 0; d < numDims; d++) {
6567 OutputTy sum = 0;
6568 for (IndexTy i = indicesStart; i < indicesEnd; i++) {
6569 int64_t idxRow =
6570 weightsOffsetsH.raw(t) + indicesH.raw(i) * numDimBytes;
6571 int64_t idxData =
6572 idxRow + IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes(
6573 d, weightTy);
6574 assert(idxData < maxWeightsBytes &&
6575 "Index shall be within weights boundary.");
6576 sum = IntNBitSplitEmbeddingBagsHelper::add(
6577 sum, &weightsH.raw(idxRow), &weightsH.raw(idxData), weightTy,
6578 d % 2, indiceWeights ? indiceWeightsH->raw(i) : 1.0);
6579 }
6580 int64_t idxOut =
6581 b * numTotalBytes +
6582 IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes(
6583 dimStart + d, outputSparseType);
6584 if (poolingMode == SplitEmbeddingPoolingMode::EP_MEAN &&
6585 indicesEnd > indicesStart) {
6586 OutputTy scale = (indicesEnd - indicesStart);
6587 IntNBitSplitEmbeddingBagsHelper::save(sum / scale, &outH.raw(idxOut),
6588 outputSparseType);
6589 } else {
6590 IntNBitSplitEmbeddingBagsHelper::save(sum, &outH.raw(idxOut),
6591 outputSparseType);
6592 }
6593 }
6594 }
6595 }
6596}
6597
6598#define DISPATCH_ARG_MIN_MAX(functionName, elemTy, elemTyIndex, ...) \
6599 switch (elemTy) { \
6600 case ElemKind::FloatTy: \
6601 if (elemTyIndex == ElemKind::Int64ITy) { \
6602 functionName<float, int64_t>(__VA_ARGS__); \
6603 } else if (elemTyIndex == ElemKind::Int32ITy) { \
6604 functionName<float, int32_t>(__VA_ARGS__); \
6605 } \
6606 break; \
6607 case ElemKind::Float16Ty: \
6608 if (elemTyIndex == ElemKind::Int64ITy) { \
6609 functionName<float16_t, int64_t>(__VA_ARGS__); \
6610 } else if (elemTyIndex == ElemKind::Int32ITy) { \
6611 functionName<float16_t, int32_t>(__VA_ARGS__); \
6612 } \
6613 break; \
6614 case ElemKind::Int8QTy: \
6615 if (elemTyIndex == ElemKind::Int64ITy) { \
6616 functionName<int8_t, int64_t>(__VA_ARGS__); \
6617 } else if (elemTyIndex == ElemKind::Int32ITy) { \
6618 functionName<int8_t, int32_t>(__VA_ARGS__); \
6619 } \
6620 break; \
6621 default: \
6622 llvm_unreachable("Type is not supported"); \
6623 }
6624
6625void BoundInterpreterFunction::fwdArgMaxInst(const ArgMaxInst *I) {
6626 auto inpT = getTensor(I->getSrc());
6627 auto outT = getTensor(I->getDest());
6628 size_t axis = I->getAxis();
6629 auto inpElemType = inpT->getElementType();
6630 auto outElemType = outT->getElementType();
6631 DISPATCH_ARG_MIN_MAX(fwdArgMax, inpElemType, outElemType, inpT, outT, axis);
6632}
6633
6634void BoundInterpreterFunction::fwdArgMinInst(const ArgMinInst *I) {
6635 auto inpT = getTensor(I->getSrc());
6636 auto outT = getTensor(I->getDest());
6637 size_t axis = I->getAxis();
6638 auto inpElemType = inpT->getElementType();
6639 auto outElemType = outT->getElementType();
6640 DISPATCH_ARG_MIN_MAX(fwdArgMin, inpElemType, outElemType, inpT, outT, axis);
6641}
6642#undef DISPATCH_ARG_MIN_MAX
6643
6644//===----------------------------------------------------------------------===//
6645// Tensor allocation operations
6646//===----------------------------------------------------------------------===//
6647
6648void BoundInterpreterFunction::fwdAllocActivationInst(
6649 const AllocActivationInst *I) {
6650 getOrCreateTensor(I);
6651}
6652
6653void BoundInterpreterFunction::fwdDeallocActivationInst(
6654 const DeallocActivationInst *I) {
6655 deleteTensor(I->getSrc());
6656}
6657
6658//===----------------------------------------------------------------------===//
6659// Debug instructions
6660//===----------------------------------------------------------------------===//
6661/// Prints a value of the instruction's operand.
6662/// In most cases it will be the name of the variable and the value of the
6663/// tensor.
6664void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) {
6665 auto *V = I->getSrc();
6666 auto *T = getTensor(V);
6667 std::string format = I->getFormat();
6668 std::string filename = I->getFileName();
6669
6670 if (format == "console") {
6671 // Dump tensor in console.
6672 llvm::outs() << I->getName() << ": ";
6673 V->dump();
6674 llvm::outs() << "\n";
6675 dumpImpl(T);
6676 llvm::outs() << "\n";
6677 } else if (format == "bin") {
6678 TensorSerializationOptions opts;
6679 opts.withType = true;
6680 glow::dumpTensorToBinaryFile(*T, filename, opts);
6681 } else if (format == "txt") {
6682 TensorSerializationOptions opts;
6683 opts.withType = true;
6684 glow::dumpTensorToTextFile(*T, filename, opts);
6685 } else if (format == "rawbin") {
6686 TensorSerializationOptions opts;
6687 opts.withType = false;
6688 glow::dumpTensorToBinaryFile(*T, filename, opts);
6689 } else if (format == "rawtxt") {
6690 TensorSerializationOptions opts;
6691 opts.withType = false;
6692 glow::dumpTensorToTextFile(*T, filename, opts);
6693 } else {
6694 llvm_unreachable("DebugPrint format not supported!");
6695 }
6696}
6697
6698void BoundInterpreterFunction::fwdTraceEventInst(const TraceEventInst *I) {
6699 auto T = getTensor(I->getData());
6700 auto IH = T->getHandle<int64_t>();
6701 size_t index = I->getIndex();
6702 IH.raw(index) = std::chrono::duration_cast<std::chrono::microseconds>(
6703 std::chrono::steady_clock::now().time_since_epoch())
6704 .count();
6705}
6706
6707void BoundInterpreterFunction::fwdInstrumentInst(const InstrumentInst *I) {
6708 // The instrument instruction is not implemented on the Interpreter backend.
  // We cannot throw an error, though, because the Interpreter can potentially
  // be used to constant fold parts of the graph while compiling for the
  // CPU backend with IR instrumentation.
6712}
6713
6714//===----------------------------------------------------------------------===//
6715// Instructions used by Quantization
6716//===----------------------------------------------------------------------===//
6717void BoundInterpreterFunction::fwdQuantizationProfileInst(
6718 const glow::QuantizationProfileInst *I) {
6719 auto inputTensor = getWeightHandle(I->getInputTensor());
6720 auto currentHistogram = getWeightHandle(I->getHistogram());
6721 auto computationInfo = getWeightHandle(I->getComputationInfo());
6722
6723 float &min = computationInfo.raw(0);
6724 float &max = computationInfo.raw(1);
6725
6726 // Update current histogram, min and max based on the inputTensor data.
6727 quantization::generateTensorHistogram(inputTensor, currentHistogram, min,
6728 max);
6729}
6730
6731/// Quantize floating point tensor. Scale and Offset are based on return type
6732/// of the instruction \p I.
6733void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
6734 auto *srcTensor = getTensor(I->getSrc());
6735 auto *destTensor = getTensor(I->getDest());
6736 auto destTy = destTensor->getType();
6737 Tensor qTensor = quantization::quantizeTensor(
6738 *srcTensor, {destTy.getScale(), destTy.getOffset()},
6739 destTy.getElementType());
6740 destTensor->assign(&qTensor);
6741}
6742
6743/// Dequantize integer tensor. Scale and Offset are based
6744/// on the source tensor type.
6745void BoundInterpreterFunction::fwdDequantizeInst(
6746 const glow::DequantizeInst *I) {
6747 auto *srcTensor = getTensor(I->getSrc());
6748 auto *destTensor = getTensor(I->getDest());
6749 auto destTy = destTensor->getType();
6750 Tensor fTensor =
6751 quantization::dequantizeTensor(*srcTensor, destTy.getElementType());
6752 destTensor->assign(&fTensor);
6753}
6754
6755template <class eTy>
6756void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl(
6757 Value *src, Value *dest, TensorQuantizationParams &srcQ,
6758 TensorQuantizationParams &destQ) {
6759
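  // Rescaling dequantizes each element with the source quantization
  // parameters and re-quantizes it with the destination parameters.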
6760 auto srcH = getWeightHandle<eTy>(src);
6761 auto destH = getWeightHandle<eTy>(dest);
6762
6763 for (size_t i = 0, e = destH.size(); i < e; ++i) {
6764 float val = quantization::dequantize(srcH.raw(i), srcQ);
6765 destH.raw(i) = quantization::quantize(val, destQ);
6766 }
6767}
6768
6769void BoundInterpreterFunction::fwdRescaleQuantizedInst(
6770 const glow::RescaleQuantizedInst *I) {
6771 auto src = I->getSrc();
6772 auto dest = I->getDest();
6773 auto srcTy = src->getType();
6774 auto destTy = dest->getType();
6775
6776 TensorQuantizationParams srcQ{srcTy->getScale(), srcTy->getOffset()};
6777 TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
6778
6779 dispatchQuantizedImpl(fwdRescaleQuantizedInstImpl, destTy->getElementType(),
6780 src, dest, srcQ, destQ);
6781}
6782
6783void BoundInterpreterFunction::fwdIntLookupTableInst(
6784 const IntLookupTableInst *I) {
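  // The mapping table is indexed by the source value shifted into the
  // unsigned range: [-128, 127] -> [0, 255] for Int8 and
  // [-32768, 32767] -> [0, 65535] for Int16.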
6785 if (I->getSrc()->getElementType() == ElemKind::Int8QTy) {
6786 auto srcH = getWeightHandle<int8_t>(I->getSrc());
6787 auto destH = getWeightHandle<int8_t>(I->getDest());
6788 auto mappingH = getWeightHandle<int8_t>(I->getMapping());
6789 for (size_t i = 0, e = destH.size(); i < e; i++) {
6790 destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 128);
6791 }
6792 } else if (I->getSrc()->getElementType() == ElemKind::Int16QTy) {
6793 auto srcH = getWeightHandle<int16_t>(I->getSrc());
6794 auto destH = getWeightHandle<int16_t>(I->getDest());
6795 auto mappingH = getWeightHandle<int16_t>(I->getMapping());
6796 for (size_t i = 0, e = destH.size(); i < e; i++) {
6797 destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 32768);
6798 }
6799 } else {
6800 llvm_unreachable("Type not supported for IntLookupTable!");
6801 }
6802}
6803
6804void BoundInterpreterFunction::fwdLookupTableInst(const LookupTableInst *I) {
6805 llvm_unreachable("LookupTable instruction is not supported yet");
6806}
6807
6808void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) {
6809 Tensor *source = getTensor(I->getInput());
6810 Tensor *dest = getTensor(I->getResult());
6811 auto srcElType = source->getType().getElementType();
6812 auto destElType = dest->getType().getElementType();
6813 if (srcElType == destElType) {
6814 // This is a noop conversion.
6815 dest->copyRawFrom(source);
6816 return;
6817 }
6818
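  // Handle simple element-wise casts: each CONVERT line below covers one
  // (source, destination) element-type pair via Tensor::copyWithCast.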
6819#define CONVERT(T_FROM, T_TO, DTY_FROM, DTY_TO) \
6820 if (srcElType == DTY_FROM && destElType == DTY_TO) { \
6821 dest->copyWithCast<T_TO, T_FROM>(source); \
6822 return; \
6823 }
6824 CONVERT(float, float16_t, ElemKind::FloatTy, ElemKind::Float16Ty)
6825 CONVERT(float, bfloat16_t, ElemKind::FloatTy, ElemKind::BFloat16Ty)
6826 CONVERT(float, bool, ElemKind::FloatTy, ElemKind::BoolTy)
6827 CONVERT(float, int32_t, ElemKind::FloatTy, ElemKind::Int32ITy)
6828 CONVERT(float, int64_t, ElemKind::FloatTy, ElemKind::Int64ITy)
6829 CONVERT(float16_t, float, ElemKind::Float16Ty, ElemKind::FloatTy)
6830 CONVERT(float16_t, bfloat16_t, ElemKind::Float16Ty, ElemKind::BFloat16Ty)
6831 CONVERT(float16_t, int32_t, ElemKind::Float16Ty, ElemKind::Int32ITy)
6832 CONVERT(float16_t, int64_t, ElemKind::Float16Ty, ElemKind::Int64ITy)
6833 CONVERT(bfloat16_t, float, ElemKind::BFloat16Ty, ElemKind::FloatTy)
6834 CONVERT(bfloat16_t, float16_t, ElemKind::BFloat16Ty, ElemKind::Float16Ty)
6835 CONVERT(bfloat16_t, int32_t, ElemKind::BFloat16Ty, ElemKind::Int32ITy)
6836 CONVERT(bfloat16_t, int64_t, ElemKind::BFloat16Ty, ElemKind::Int64ITy)
6837 CONVERT(bool, float, ElemKind::BoolTy, ElemKind::FloatTy)
6838 CONVERT(bool, bfloat16_t, ElemKind::BoolTy, ElemKind::BFloat16Ty)
6839 CONVERT(int32_t, float, ElemKind::Int32ITy, ElemKind::FloatTy)
6840 CONVERT(int32_t, float16_t, ElemKind::Int32ITy, ElemKind::Float16Ty)
6841 CONVERT(int32_t, bfloat16_t, ElemKind::Int32ITy, ElemKind::BFloat16Ty)
6842 CONVERT(int32_t, int64_t, ElemKind::Int32ITy, ElemKind::Int64ITy)
6843 CONVERT(int64_t, float, ElemKind::Int64ITy, ElemKind::FloatTy)
6844 CONVERT(int64_t, float16_t, ElemKind::Int64ITy, ElemKind::Float16Ty)
6845 CONVERT(int64_t, bfloat16_t, ElemKind::Int64ITy, ElemKind::BFloat16Ty)
6846 CONVERT(int64_t, int32_t, ElemKind::Int64ITy, ElemKind::Int32ITy)
6847 CONVERT(bool, int32_t, ElemKind::BoolTy, ElemKind::Int32ITy)
6848#undef CONVERT
6849
6850 if (srcElType == ElemKind::UInt8FusedQTy &&
6851 destElType == ElemKind::UInt8FusedFP16QTy) {
6852 Tensor result = source->getCopyConvertedToType(ElemKind::UInt8FusedFP16QTy);
6853 dest->assign(&result);
6854 return;
6855 }
6856
6857 if ((srcElType == ElemKind::UInt8FusedFP16QTy ||
6858 srcElType == ElemKind::UInt4FusedFP16QTy) &&
6859 destElType == ElemKind::UInt8FusedQTy) {
6860 Tensor result = source->getCopyConvertedToType(ElemKind::UInt8FusedQTy);
6861 dest->assign(&result);
6862 return;
6863 }
6864
6865 if (srcElType == ElemKind::UInt4FusedFP16QTy &&
6866 destElType == ElemKind::UInt4FusedQTy) {
6867 Tensor result = source->getCopyConvertedToType(ElemKind::UInt4FusedQTy);
6868 dest->assign(&result);
6869 return;
6870 }
6871
6872 llvm_unreachable("Type not supported");
6873}
6874
6875template <typename ElemTy>
6876void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInstImpl(
6877 const BatchedPairwiseDotProductInst *I) {
6878 auto destT = getTensor(I->getDest());
6879 auto destH = destT->getHandle<ElemTy>();
6880
6881 dim_t batchCount = destT->getType().dims()[0];
6882
6883 // Gather all batched vector operands into an array so that they can be
6884 // indexed easily.
6885 std::vector<Value *> srcs;
6886 for (unsigned i = 1, e = I->getNumOperands(); i < e; ++i) {
6887 auto op = I->getOperand(i);
6888 srcs.emplace_back(op.first);
6889 }
6890
6891 // pairIdx is the total number of pairs (i, j) that have been processed.
6892 unsigned pairIdx = 0;
6893
6894 // For each src operand:
6895 for (unsigned i = 1, e = I->getNumInputs(); i < e; ++i) {
6896 auto vAH = getTensor(srcs[i])->getHandle<ElemTy>();
6897 dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];
6898
6899 // Compute the dot product of src[i] with every other vector with a
6900 // smaller index.
6901 for (unsigned j = 0; j < i; ++j) {
6902 auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();
6903
6904 // Process all batches for a given pair (i, j).
6905 for (dim_t b = 0; b < batchCount; ++b) {
6906 ElemTy accum = 0;
6907
6908 for (dim_t k = 0; k < vectorSize; ++k) {
6909 accum += vAH.at({b, k}) * vBH.at({b, k});
6910 }
6911
6912 destH.at({b, pairIdx}) = accum;
6913 }
6914
6915 ++pairIdx;
6916 }
6917 }
6918}
6919
6920void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInst(
6921 const BatchedPairwiseDotProductInst *I) {
6922 dispatchImpl(fwdBatchedPairwiseDotProductInstImpl,
6923 I->getDest()->getElementType(), I);
6924}
6925
6926template <typename ElemTy>
6927void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInstImpl(
6928 const BatchedPairwiseDotProductGradInst *I) {
6929 auto destGradT = getTensor(I->getDestGrad());
6930 auto destGradH = destGradT->getHandle<ElemTy>();
6931
6932 dim_t batchCount = destGradT->getType().dims()[0];
6933
6934 // Gather all batched vector operands into arrays so that they can be
6935 // indexed easily. Operands 1 -> numInputs are gradients of inputs, and
6936 // operands numInputs + 1 -> numOperands - 1 are the corresponding original
6937 // inputs.
6938 std::vector<Value *> srcs, srcGrads;
6939 for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
6940 auto gradOp = I->getOperand(i + 1);
6941 auto inputOp = I->getOperand(i + 1 + e);
6942
6943 srcGrads.emplace_back(gradOp.first);
6944 srcs.emplace_back(inputOp.first);
6945 }
6946
6947 // Zero initialize all srcGrad tensors.
6948 for (auto &s : srcGrads) {
6949 getTensor(s)->zero();
6950 }
6951
6952 // pairIdx is the total number of pairs (i, j) that have been processed.
6953 unsigned pairIdx = 0;
6954
6955 // For each srcGrad operand:
6956 for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
6957 auto dvAH = getTensor(srcGrads[i])->getHandle<ElemTy>();
6958 dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];
6959
    // Accumulate into it the product of the gradient of all dot products that
    // src[i] contributed to and the corresponding vectors that src[i] was
    // dotted with.
6963 for (unsigned j = i + 1; j < e; ++j) {
6964 auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();
6965
6966 // Process all batches for a given pair (i, j).
6967 for (dim_t b = 0; b < batchCount; ++b) {
6968 ElemTy grad = destGradH.at({b, pairIdx});
6969
6970 for (dim_t k = 0; k < vectorSize; ++k) {
6971 dvAH.at({b, k}) += grad * vBH.at({b, k});
6972 }
6973 }
6974
6975 ++pairIdx;
6976 }
6977 }
6978}
6979
6980void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInst(
6981 const BatchedPairwiseDotProductGradInst *I) {
6982 dispatchImpl(fwdBatchedPairwiseDotProductGradInstImpl,
6983 I->getDestGrad()->getElementType(), I);
6984}
6985
6986template <typename ElemTy>
6987void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) {
6988
6989 static_assert(max_tensor_dimensions == 6,
6990 "Loops below assume max_tensor_dimensions = 6.");
6991
6992 auto *src = I->getSrc();
6993 auto *dest = I->getDest();
6994
6995 // Get unowned handles of src and dest with dims expanded to maximum.
6996 ShapeVector eDims = expandDimsToMax(src->dims());
6997 auto eSrc = getTensor(src)->getUnowned(eDims);
6998 auto eDest = getTensor(dest)->getUnowned(eDims);
6999 auto srcH = eSrc.getHandle<ElemTy>();
7000 auto destH = eDest.getHandle<ElemTy>();
7001
7002#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5) \
7003 for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++) \
7004 for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++) \
7005 for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++) \
7006 for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++) \
7007 for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++) \
7008 for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) { \
7009 destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) = \
7010 srcH.at({idx0, idx1, idx2, idx3, idx4, idx5}); \
7011 } \
7012 return;
7013
7014 switch (I->getAxis()) {
7015 case 0:
7016 LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5);
7017 case 1:
7018 LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5);
7019 case 2:
7020 LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5);
7021 case 3:
7022 LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5);
7023 case 4:
7024 LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5);
7025 case 5:
7026 LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5);
7027 default:
7028 llvm_unreachable("Axis should be less than max_tensor_dimensions.");
7029 }
7030}
7031
7032void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) {
7033 dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I);
7034}
7035
7036//===----------------------------------------------------------------------===//
7037// Instructions used by ObjectDetection
7038//===----------------------------------------------------------------------===//
7039static void maxMin(float lhs, float rhs, float &min, float &max) {
7040 if (lhs >= rhs) {
7041 min = rhs;
7042 max = lhs;
7043 } else {
7044 min = lhs;
7045 max = rhs;
7046 }
7047}
7048
7049using ClassBox = std::pair<float, dim_t>;
7050
7051struct Box {
7052 float classValue{0.0f};
7053 dim_t batchIndex{0};
7054 dim_t classIndex{0};
7055 dim_t boxIndex{0};
7056};
7057
7058template <typename ElemTy>
7059static bool doIOU(Handle<ElemTy> &boxes, dim_t batchIndex,
7060 dim_t selectedBoxIndex, dim_t candidateBoxIndex,
7061 int centerPointBox, float iouThreshold, bool isV4) {
7062 float sx[] = {0.0f, 0.0f, 0.0f, 0.0f};
7063 float cx[] = {0.0f, 0.0f, 0.0f, 0.0f};
7064
7065 if (isV4) {
7066 for (dim_t i = 0; i < 4; ++i) {
7067 sx[i] = boxes.at({selectedBoxIndex, i});
7068 cx[i] = boxes.at({candidateBoxIndex, i});
7069 }
7070 } else {
7071 for (dim_t i = 0; i < 4; ++i) {
7072 sx[i] = boxes.at({batchIndex, selectedBoxIndex, i});
7073 cx[i] = boxes.at({batchIndex, candidateBoxIndex, i});
7074 }
7075 }
7076
7077 float xSMin = 0.0f;
7078 float ySMin = 0.0f;
7079 float xSMax = 0.0f;
7080 float ySMax = 0.0f;
7081
7082 float xCMin = 0.0f;
7083 float yCMin = 0.0f;
7084 float xCMax = 0.0f;
7085 float yCMax = 0.0f;
7086
7087 // Standardizing coordinates so that (xmin, ymin) is upper left corner of a
7088 // box and (xmax, ymax) is lower right corner of the box.
7089 if (!centerPointBox) {
7090 // 0 means coordinates for diagonal ends of a box.
7091 // Coordinates can either be absolute or normalized.
7092 maxMin(sx[0], sx[2], xSMin, xSMax);
7093 maxMin(sx[1], sx[3], ySMin, ySMax);
7094
7095 maxMin(cx[0], cx[2], xCMin, xCMax);
7096 maxMin(cx[1], cx[3], yCMin, yCMax);
7097 } else {
7098 float halfWidthS = sx[2] / 2.0f;
7099 float halfHeightS = sx[3] / 2.0f;
7100 float halfWidthC = cx[2] / 2.0f;
7101 float halfHeightC = cx[3] / 2.0f;
7102
7103 xSMin = sx[0] - halfWidthS;
7104 ySMin = sx[1] - halfHeightS;
7105 xSMax = sx[0] + halfWidthS;
7106 ySMax = sx[1] + halfHeightS;
7107
7108 xCMin = cx[0] - halfWidthC;
7109 yCMin = cx[1] - halfHeightC;
7110 xCMax = cx[0] + halfWidthC;
7111 yCMax = cx[1] + halfHeightC;
7112 }
7113
  // Find the upper left and lower right corners of the intersection box.
7116 float xMin = std::max(xSMin, xCMin);
7117 float yMin = std::max(ySMin, yCMin);
7118 float xMax = std::min(xSMax, xCMax);
7119 float yMax = std::min(ySMax, yCMax);
7120
7121 float intersectionArea =
7122 std::max(0.0f, xMax - xMin) * std::max(0.0f, yMax - yMin);
7123
7124 if (intersectionArea == 0.0f) {
7125 return false;
7126 }
7127
7128 float sArea = (xSMax - xSMin) * (ySMax - ySMin);
7129 float cArea = (xCMax - xCMin) * (yCMax - yCMin);
7130 float unionArea = sArea + cArea - intersectionArea;
7131
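  // Equivalent to (intersectionArea / unionArea) > iouThreshold, written as a
  // multiplication to avoid the division.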
7132 return intersectionArea > iouThreshold * unionArea;
7133}
7134
7135template <typename T>
7136void BoundInterpreterFunction::fwdNonMaxSuppressionInstImpl(
7137 glow::NonMaxSuppressionInst const *I) {
7138
7139 auto boxes = I->getBoxes();
7140 auto scores = I->getScores();
7141 auto indices = I->getIndices();
7142 auto numDetected = I->getNumberOfSelectedIndices();
7143 float iouThreshold = I->getIouThreshold();
7144 dim_t maxBoxesPerClass = I->getMaxOutputBoxesPerClass();
7145 float scoreThreshold = I->getScoreThreshold();
7146 unsigned centerPointBox = I->getCenterPointBox();
7147 bool isV4 = I->getIsTFVersion4();
7148
7149 auto boxesH = getTensor(boxes)->getHandle<float>();
7150 auto scoresH = getTensor(scores)->getHandle<float>();
7151 auto indicesH = getTensor(indices)->getHandle<T>();
7152 auto numDetectedH = getTensor(numDetected)->getHandle<T>();
7153
7154 int boxesBoxDim = boxes->dims().size() - 2;
7155
7156 dim_t numBatches = 1;
7157 dim_t numClasses = 1;
7158 dim_t numBoxes = boxes->dims()[boxesBoxDim];
7159
7160 size_t maxOutputPerBatch = 0;
7161
7162 if (!isV4) {
7163 int boxesBatchDim = boxes->dims().size() - 3;
7164
7165 int scoresBatchDim = scores->dims().size() - 3;
7166 int scoresBoxDim = scores->dims().size() - 1;
7167 int scoresClassDim = scores->dims().size() - 2;
7168 assert(scores->dims()[scoresBoxDim] == boxes->dims()[boxesBoxDim] &&
7169 "Mismatch between number of scores and number of boxes.");
7170 assert(scores->dims()[scoresBatchDim] == boxes->dims()[boxesBatchDim] &&
7171 "Mismatch in batch dimension.");
7172 (void)boxesBatchDim;
7173 (void)scoresBoxDim;
7174 numBatches = scores->dims()[scoresBatchDim];
7175 numClasses = scores->dims()[scoresClassDim];
7176 numBoxes = boxes->dims()[boxesBoxDim];
7177 maxOutputPerBatch =
7178 indices->dims()[indices->dims().size() - 2] / numBatches;
7179 } else {
7180 maxOutputPerBatch =
7181 indices->dims()[indices->dims().size() - 1] / numBatches;
7182 }
7183
7184 auto cmpFunc = [](const ClassBox &a, const ClassBox &b) {
7185 return a.first < b.first;
7186 };
7187
7188 std::vector<ClassBox> selectedIndices(numBoxes);
7189 dim_t outPutBoxIndex = 0;
7190
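  // For each (batch, class) pair: push every box above the score threshold
  // into a max-heap keyed by score, then greedily select boxes in decreasing
  // score order, skipping any box whose IOU with an already selected box
  // exceeds the threshold. Remaining output slots of a batch are padded with
  // the tracked lowest-scoring selection (minBox).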
7191 for (dim_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
7192 Box minBox{scoresH.raw(batchIndex * numClasses * numBoxes), batchIndex, 0,
7193 0};
7194 int32_t detectedPerBatch = 0;
7195 for (dim_t classIndex = 0; classIndex < numClasses; ++classIndex) {
7196 selectedIndices.clear();
7197 size_t detectedPerClass = 0;
7198 std::priority_queue<ClassBox, std::vector<ClassBox>, decltype(cmpFunc)>
7199 queue(cmpFunc);
7200
7201 for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
7202 float classValue = scoresH.raw(
7203 (batchIndex * numClasses + classIndex) * numBoxes + boxIndex);
7204 if (classValue > scoreThreshold) {
7205 queue.emplace(classValue, boxIndex);
7206 }
7207 }
7208
7209 float tScore = minBox.classValue;
7210 while (!queue.empty()) {
7211 auto priorBox = queue.top();
7212 queue.pop();
7213
7214 bool selected = true;
7215 for (auto &sBox : selectedIndices) {
7216 if (doIOU(boxesH, batchIndex, sBox.second, priorBox.second,
7217 centerPointBox, iouThreshold, isV4)) {
7218 selected = false;
7219 break;
7220 }
7221 }
7222
7223 if (selected) {
7224 selectedIndices.emplace_back(priorBox);
7225 if (isV4) {
7226 indicesH.at({outPutBoxIndex}) = priorBox.second;
7227 tScore = scoresH.at({priorBox.second});
7228 } else {
7229 indicesH.at({outPutBoxIndex, 0}) = batchIndex;
7230 indicesH.at({outPutBoxIndex, 1}) = classIndex;
7231 indicesH.at({outPutBoxIndex, 2}) = priorBox.second;
7232 tScore = scoresH.at({batchIndex, classIndex, priorBox.second});
7233 }
7234
7235 ++outPutBoxIndex;
7236 ++detectedPerClass;
7237 ++detectedPerBatch;
7238 }
7239 if (maxBoxesPerClass == detectedPerClass) {
7240 break;
7241 }
7242 }
7243
7244 if (tScore < minBox.classValue) {
7245 minBox.classValue = tScore;
7246 minBox.classIndex = classIndex;
7247 if (isV4) {
7248 minBox.boxIndex = indicesH.at({outPutBoxIndex - 1});
7249 } else {
7250 minBox.boxIndex = indicesH.at({outPutBoxIndex - 1, 2});
7251 }
7252 }
7253 }
7254
7255 for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
7256 if (isV4) {
7257 indicesH.at({outPutBoxIndex}) = minBox.boxIndex;
7258 } else {
7259 indicesH.at({outPutBoxIndex, 0}) = minBox.batchIndex;
7260 indicesH.at({outPutBoxIndex, 1}) = minBox.classIndex;
7261 indicesH.at({outPutBoxIndex, 2}) = minBox.boxIndex;
7262 }
7263
7264 ++outPutBoxIndex;
7265 }
    // This output is not used for ONNX NMS; for TF the batch dimension is 1.
7267 for (dim_t i = 0; i < maxBoxesPerClass; ++i) {
7268 numDetectedH.at({batchIndex * maxBoxesPerClass + i}) = detectedPerBatch;
7269 }
7270 }
7271}
7272
7273void BoundInterpreterFunction::fwdNonMaxSuppressionInst(
7274 glow::NonMaxSuppressionInst const *I) {
7275 switch (I->getBoxes()->getElementType()) {
7276 case ElemKind::FloatTy:
7277 if (I->getIndices()->getElementType() == ElemKind::Int32ITy) {
7278 fwdNonMaxSuppressionInstImpl<int32_t>(I);
7279 } else if (I->getIndices()->getElementType() == ElemKind::Int64ITy) {
7280 fwdNonMaxSuppressionInstImpl<int64_t>(I);
7281 } else {
7282 llvm_unreachable("Output type is not supported.");
7283 }
7284 break;
7285 default:
7286 llvm_unreachable("Type is not supported.");
7287 break;
7288 }
7289}
7290
7291//===----------------------------------------------------------------------===//
7292// TensorFlowLite NonMaxSuppression
7293//===----------------------------------------------------------------------===//
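/// Lomuto-style partition step used by partial_sort below: partitions
/// values[low..high] around the pivot values[high] so that larger values come
/// first, applies the same swaps to \p arr, and returns the final position of
/// the pivot.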
7294static int32_t partition(int32_t *arr, int32_t low, int32_t high,
7295 float *values) {
7296 float pivot = values[high];
7297 int32_t i = (low - 1);
7298 float swap_float;
7299 int32_t swap_int;
7300
7301 for (int32_t j = low; j <= high - 1; j++) {
7302 if (values[j] > pivot) {
7303 i++;
7304
7305 swap_float = values[i];
7306 values[i] = values[j];
7307 values[j] = swap_float;
7308
7309 swap_int = arr[i];
7310 arr[i] = arr[j];
7311 arr[j] = swap_int;
7312 }
7313 }
7314
7315 swap_float = values[i + 1];
7316 values[i + 1] = values[high];
7317 values[high] = swap_float;
7318
7319 swap_int = arr[i + 1];
7320 arr[i + 1] = arr[high];
7321 arr[high] = swap_int;
7322
7323 return (i + 1);
7324}
7325
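/// Recursive partial quicksort: orders values[i..j] (and \p arr alongside) in
/// decreasing order until at least the first \p k elements are in their final
/// sorted positions.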
7326static void partial_sort(int32_t *arr, int32_t i, int32_t j, int32_t k,
7327 float *values) {
7328 int32_t p;
7329 if (i < j) {
7330 p = partition(arr, i, j, values);
7331
7332 partial_sort(arr, i, p - 1, k, values);
7333
7334 if (p < k - 1)
7335 partial_sort(arr, p + 1, j, k, values);
7336 }
7337}
7338
7339static void iota(int32_t *first, int32_t *last, int32_t value) {
7340 while (first != last) {
7341 *first++ = value;
7342 value++;
7343 }
7344}
7345
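/// Fills \p indices with the positions of the \p num_to_sort largest entries
/// of \p values, in decreasing order of value. \p aux_values is caller
/// provided scratch space of at least \p num_values floats.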
7346static void decreasing_partial_arg_sort(float *values, int32_t num_values,
7347 int32_t num_to_sort, int32_t *indices,
7348 float *aux_values) {
7349 iota(indices, indices + num_values, 0);
7350
7351 memcpy(aux_values, values, sizeof(float) * num_values);
7352
7353 partial_sort(indices, 0, num_values - 1, num_to_sort, aux_values);
7354}
7355
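/// Copies the scores that are >= \p threshold into \p keep_values and their
/// positions into \p keep_indices, and stores the number of kept entries in
/// \p num_indices.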
7356static void select_detection_above_score_threshold(
7357 float *scores, int32_t num_scores, float threshold, float *keep_values,
7358 int32_t *keep_indices, int32_t *num_indices) {
7359 int32_t idx = 0;
7360 for (int32_t i = 0; i < num_scores; i++) {
7361 if (scores[i] >= threshold) {
7362 keep_indices[idx] = i;
7363 keep_values[idx] = scores[i];
7364 idx++;
7365 }
7366 }
7367 *num_indices = idx;
7368}
7369
7370/// Compute the IOU (Intersection Over Union) metric between two boxes. Each
7371/// of box1 and box2 is a vector with 4 floating-point values with the box
7372/// coordinates in the following format: [ymin, xmin, ymax, xmax].
7373static float tflite_compute_iou(float *box1, float *box2) {
7374
7375 // Compute the areas of the two boxes.
7376 float box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1]);
7377 float box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1]);
7378
7379 // If box coordinates are invalid we return 0.
7380 if (box1Area <= 0 || box2Area <= 0) {
7381 return 0.0f;
7382 }
7383
7384 // Determine the coordinates of the intersection rectangle.
7385 float iYmin = MAX(box1[0], box2[0]);
7386 float iXmin = MAX(box1[1], box2[1]);
7387 float iYmax = MIN(box1[2], box2[2]);
7388 float iXmax = MIN(box1[3], box2[3]);
7389
7390 // Compute the area of the intersection rectangle.
7391 float iArea = MAX(0.0f, iXmax - iXmin) * MAX(0.0f, iYmax - iYmin);
7392
  // Compute the area of the union of the two boxes.
7394 float uArea = box1Area + box2Area - iArea;
7395
7396 // Compute the Intersection Over Union metric.
7397 return iArea / uArea;
7398}
7399
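/// Performs NMS for a single class: keeps the boxes whose score is at least
/// \p nms_score_threshold, sorts them by decreasing score, then greedily
/// selects up to \p max_detections boxes while suppressing any box whose IOU
/// with an already selected box exceeds \p nms_iou_threshold. The indices of
/// the selected boxes are written to \p selected and their count to
/// \p num_selected; the remaining pointers are caller-provided scratch
/// buffers.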
7400static void tflite_helper(float *boxesPtr, int32_t num_boxes,
                          float nms_score_threshold, float nms_iou_threshold,
7402 float *class_scores, int32_t num_scores,
7403 int32_t *selected, int32_t *num_selected,
7404 int32_t max_detections, int32_t *keep_indices,
7405 float *keep_scores, int32_t *sorted_indices_helper) {
7406
7407 *num_selected = 0;
7408
7409 int32_t num_scores_kept;
7410 select_detection_above_score_threshold(class_scores, num_boxes,
7411 nms_score_threshold, keep_scores,
7412 keep_indices, &num_scores_kept);
7413
7414 decreasing_partial_arg_sort(keep_scores, num_scores_kept, num_scores_kept,
7415 sorted_indices_helper, (float *)selected);
7416
7417 int32_t num_boxes_kept = num_scores_kept;
7418 int32_t output_size = MIN(num_boxes_kept, max_detections);
7419
7420 int32_t num_active_candidate = num_boxes_kept;
7421
7422 uint8_t *active_box_candidate = (uint8_t *)keep_scores;
7423
7424 for (int32_t row = 0; row < num_boxes_kept; row++) {
7425 active_box_candidate[row] = 1;
7426 }
7427
7428 for (int32_t i = 0; i < num_boxes_kept; i++) {
7429 if (num_active_candidate == 0 || *num_selected >= output_size)
7430 break;
7431 if (active_box_candidate[i] == 1) {
7432 selected[*num_selected] = keep_indices[sorted_indices_helper[i]];
7433 (*num_selected)++;
7434 active_box_candidate[i] = 0;
7435 num_active_candidate--;
7436 } else {
7437 continue;
7438 }
7439
7440 for (int32_t j = i + 1; j < num_boxes_kept; ++j) {
7441 if (active_box_candidate[j] == 1) {
7442
7443 float *box1 = boxesPtr + 4 * keep_indices[sorted_indices_helper[i]];
7444 float *box2 = boxesPtr + 4 * keep_indices[sorted_indices_helper[j]];
7445 float iou = tflite_compute_iou(box1, box2);
7446
        if (iou > nms_iou_threshold) {
7448 active_box_candidate[j] = 0;
7449 num_active_candidate--;
7450 }
7451 }
7452 }
7453 }
7454}
7455
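/// Reference implementation of TFLite DetectionPostProcess: decodes the box
/// coordinates in-place using the anchors, then runs either regular
/// (per-class) NMS or fast NMS on the best class score of each box, and
/// writes the resulting boxes, classes, scores and detection count to the
/// output pointers. \p scratch is a caller-provided buffer that is carved up
/// into the temporary arrays needed by the two paths.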
7456static void tflite_detection_post_process_f(
7457 float *boxes, float *scores, float *anchors, float *detectionBoxes,
7458 int32_t *detectionClasses, float *detectionScores, int32_t *numDetections,
7459 int8_t *scratch, int32_t numBoxes, int32_t numTotalClasses,
7460 int32_t numClasses, int32_t maxDetections, int32_t maxClassesPerDetection,
7461 int32_t maxDetectionsPerClass, float iouThreshold, float scoreThreshold,
7462 float xScaleInv, float yScaleInv, float hScaleInv, float wScaleInv,
7463 bool regularNMS) {
7464
7465 // Decode the box coordinates in-place using the anchors.
7466 for (int32_t i = 0; i < numBoxes; i++) {
7467
7468 float *box = &boxes[i * 4];
7469 float *anchor = &anchors[i * 4];
7470
7471 float ycenter = box[0] * yScaleInv * anchor[2] + anchor[0];
7472 float xcenter = box[1] * xScaleInv * anchor[3] + anchor[1];
7473
7474 float half_h = 0.5f * expf(box[2] * hScaleInv) * anchor[2];
7475 float half_w = 0.5f * expf(box[3] * wScaleInv) * anchor[3];
7476
7477 box[0] = ycenter - half_h;
7478 box[1] = xcenter - half_w;
7479 box[2] = ycenter + half_h;
7480 box[3] = xcenter + half_w;
7481 }
7482
7483 int32_t max_categories_per_anchor = maxClassesPerDetection;
7484 int32_t num_categories_per_anchor =
7485 MIN(max_categories_per_anchor, numClasses);
7486 int32_t label_offset = numTotalClasses - numClasses;
7487
7488 if (regularNMS) {
7489 int32_t num_detections_per_class = maxDetectionsPerClass;
7490
7491 float *class_scores = (float *)(scratch);
7492 scratch += numBoxes * sizeof(float);
7493
7494 int32_t *box_indices_after_regular_nms = (int32_t *)(scratch);
7495 scratch += (numBoxes + maxDetections) * sizeof(int32_t);
7496
7497 float *scores_after_regular_nms = (float *)(scratch);
7498 scratch += (numBoxes + maxDetections) * sizeof(float);
7499
7500 int32_t size_of_sorted_indices = 0;
7501
7502 int32_t *sorted_indices = (int32_t *)(scratch);
7503 scratch += (numBoxes + maxDetections) * sizeof(int32_t);
7504
7505 float *sorted_values = (float *)(scratch);
7506 scratch += MIN(numBoxes, maxDetectionsPerClass) * sizeof(float);
7507
7508 int32_t *selected = (int32_t *)scratch;
7509 scratch += numBoxes * sizeof(int32_t);
7510
7511 int32_t *keep_indices = (int32_t *)(scratch);
7512 scratch += numBoxes * sizeof(int32_t);
7513
7514 float *keep_scores = (float *)(scratch);
7515 scratch += numBoxes * sizeof(float);
7516
7517 int32_t *sorted_indices_helper = (int32_t *)scratch;
7518 scratch += numBoxes * sizeof(int32_t);
7519
7520 for (int32_t col = 0; col < numClasses; col++) {
7521 for (int32_t row = 0; row < numBoxes; row++) {
7522 class_scores[row] =
7523 *(scores + row * numTotalClasses + col + label_offset);
7524 }
7525
7526 int32_t num_selected;
7527 tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, class_scores,
7528 numBoxes, selected, &num_selected, num_detections_per_class,
7529 keep_indices, keep_scores, sorted_indices_helper);
7530
7531 int32_t output_index = size_of_sorted_indices;
7532 for (int32_t i = 0; i < num_selected; i++) {
7533 int32_t selected_index = selected[i];
7534 box_indices_after_regular_nms[output_index] =
7535 (selected_index * numTotalClasses + col + label_offset);
7536 scores_after_regular_nms[output_index] = class_scores[selected_index];
7537 output_index++;
7538 }
7539
7540 int32_t num_indices_to_sort = MIN(output_index, maxDetections);
7541
7542 decreasing_partial_arg_sort(scores_after_regular_nms, output_index,
7543 num_indices_to_sort, sorted_indices,
7544 keep_scores);
7545
7546 for (int32_t row = 0; row < num_indices_to_sort; row++) {
7547 int32_t temp = sorted_indices[row];
7548 sorted_indices[row] = box_indices_after_regular_nms[temp];
7549 sorted_values[row] = scores_after_regular_nms[temp];
7550 }
7551
7552 for (int32_t row = 0; row < num_indices_to_sort; row++) {
7553 box_indices_after_regular_nms[row] = sorted_indices[row];
7554 scores_after_regular_nms[row] = sorted_values[row];
7555 }
7556
7557 size_of_sorted_indices = num_indices_to_sort;
7558 }
7559
7560 for (int32_t output_box_index = 0;
7561 output_box_index < size_of_sorted_indices; output_box_index++) {
7562
7563 int32_t anchor_index =
7564 box_indices_after_regular_nms[output_box_index] / numTotalClasses;
7565 int32_t class_index = box_indices_after_regular_nms[output_box_index] -
7566 anchor_index * numTotalClasses - label_offset;
7567 float selected_score = scores_after_regular_nms[output_box_index];
7568 float *box = boxes + anchor_index * 4;
7569
7570 *detectionBoxes++ = *box++;
7571 *detectionBoxes++ = *box++;
7572 *detectionBoxes++ = *box++;
7573 *detectionBoxes++ = *box++;
7574 *detectionClasses++ = class_index;
7575 *detectionScores++ = selected_score;
7576 }
7577
7578 *numDetections = size_of_sorted_indices;
7579 } else {
7580 float *max_scores = (float *)scratch;
7581 scratch += numBoxes * sizeof(float);
7582
7583 int32_t *sorted_classes_indices = (int32_t *)scratch;
7584 scratch += numBoxes * MIN(maxDetections, numClasses) * sizeof(int32_t);
7585
7586 int32_t *selected = (int32_t *)scratch;
7587 scratch += numBoxes * sizeof(int32_t);
7588
7589 int32_t *keep_indices = (int32_t *)(scratch);
7590 scratch += numBoxes * sizeof(int32_t);
7591
7592 float *keep_scores = (float *)(scratch);
7593 scratch += numBoxes * sizeof(float);
7594
7595 int32_t *sorted_indices_helper = (int32_t *)scratch;
7596 scratch += numBoxes * sizeof(int32_t);
7597
7598 for (int32_t row = 0; row < numBoxes; row++) {
7599 float *box_scores = scores + row * numTotalClasses + label_offset;
7600 int32_t *class_indices =
7601 sorted_classes_indices + row * num_categories_per_anchor;
7602
7603 decreasing_partial_arg_sort(box_scores, numClasses,
7604 num_categories_per_anchor, keep_indices,
7605 keep_scores);
7606
7607 for (int32_t i = 0; i < num_categories_per_anchor; i++) {
7608 class_indices[i] = keep_indices[i];
7609 }
7610
7611 max_scores[row] = box_scores[class_indices[0]];
7612 }
7613
7614 int32_t selected_size = 0;
7615 tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, max_scores,
7616 numBoxes, selected, &selected_size, maxDetections,
7617 keep_indices, keep_scores, sorted_indices_helper);
7618
7619 int32_t num_detections = 0;
7620 for (int32_t i = 0; i < selected_size; i++) {
7621
7622 int32_t selected_index = selected[i];
7623 float *box = boxes + selected_index * 4;
7624 float *box_scores =
7625 scores + selected_index * numTotalClasses + label_offset;
7626 int32_t *class_indices =
7627 sorted_classes_indices + selected_index * num_categories_per_anchor;
7628
7629 for (int32_t col = 0; (col < num_categories_per_anchor) &&
7630 (num_detections <= selected_size);
7631 ++col) {
7632 *detectionBoxes++ = box[0];
7633 *detectionBoxes++ = box[1];
7634 *detectionBoxes++ = box[2];
7635 *detectionBoxes++ = box[3];
7636 *detectionClasses++ = class_indices[col];
7637 *detectionScores++ = box_scores[class_indices[col]];
7638 num_detections++;
7639 }
7640 }
7641
7642 *numDetections = selected_size;
7643 }
7644}
7645
7646void BoundInterpreterFunction::fwdTFLiteDetectionPostProcessInst(
7647 glow::TFLiteDetectionPostProcessInst const *I) {
7648 auto boxes = I->getBoxes();
7649 auto scores = I->getScores();
7650 auto anchors = I->getAnchors();
7651 auto detectionBoxes = I->getDetectionBoxes();
7652 auto detectionClasses = I->getDetectionClasses();
7653 auto detectionScores = I->getDetectionScores();
7654 auto numDetections = I->getNumDetections();
7655 auto scratch = I->getScratch();
7656
7657 // Get raw pointers.
7658 float *boxesPtr = (float *)getTensor(boxes)->getUnsafePtr();
7659 float *scoresPtr = (float *)getTensor(scores)->getUnsafePtr();
7660 float *anchorsPtr = (float *)getTensor(anchors)->getUnsafePtr();
7661 float *detectionBoxesPtr = (float *)getTensor(detectionBoxes)->getUnsafePtr();
7662 int32_t *detectionClassesPtr =
7663 (int32_t *)getTensor(detectionClasses)->getUnsafePtr();
7664 float *detectionScoresPtr =
7665 (float *)getTensor(detectionScores)->getUnsafePtr();
7666 int32_t *numDetectionsPtr =
7667 (int32_t *)getTensor(numDetections)->getUnsafePtr();
7668 int8_t *scratchPtr = (int8_t *)getTensor(scratch)->getUnsafePtr();
7669
7670 // Get parameters.
7671 int32_t numBoxes = boxes->dims()[1];
7672 int32_t numTotalClasses = scores->dims()[2];
7673 int32_t numClasses = I->getNumClasses();
7674 int32_t maxDetections = I->getMaxDetections();
7675 int32_t maxClassesPerDetection = I->getMaxClassesPerDetection();
7676 int32_t maxDetectionsPerClass = I->getMaxDetectionsPerClass();
7677 float iouThreshold = I->getIouThreshold();
7678 float scoreThreshold = I->getScoreThreshold();
7679 float xScaleInv = 1.0f / I->getXScale();
7680 float yScaleInv = 1.0f / I->getYScale();
7681 float hScaleInv = 1.0f / I->getHScale();
7682 float wScaleInv = 1.0f / I->getWScale();
7683 bool regularNMS = I->getRegularNMS();
7684
7685 // Compute TFLite NMS.
7686 tflite_detection_post_process_f(
7687 boxesPtr, scoresPtr, anchorsPtr, detectionBoxesPtr, detectionClassesPtr,
7688 detectionScoresPtr, numDetectionsPtr, scratchPtr, numBoxes,
7689 numTotalClasses, numClasses, maxDetections, maxClassesPerDetection,
7690 maxDetectionsPerClass, iouThreshold, scoreThreshold, xScaleInv, yScaleInv,
7691 hScaleInv, wScaleInv, regularNMS);
7692}
7693
7694void BoundInterpreterFunction::fwdAudioSpectrogramInstFloatImpl(
7695 glow::AudioSpectrogramInst const *I) {
7696
7697 auto spectrogram = I->getSpectrogram();
7698 auto input = I->getInput();
7699 auto window = I->getWindow();
7700 int64_t windowSize = I->getWindowSize();
7701 int64_t windowStride = I->getWindowStride();
7702
7703 auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
7704 auto inputH = getTensor(input)->getHandle<float>();
7705 auto windowH = getTensor(window)->getHandle<float>();
7706
7707 // Compute window count.
7708 int64_t inputLength = input->size();
7709 int64_t windowCount =
7710 std::floor((inputLength - windowSize) / windowStride) + 1;
7711
7712 // Compute FFT length (next power of 2) and spectrogram length.
7713 dim_t fftLen = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
7714 dim_t specLen = fftLen / 2 + 1;
7715
7716 // Allocate temporary buffers.
7717 auto winOut = std::make_unique<float[]>(windowSize);
7718 auto fftRealOut = std::make_unique<float[]>(specLen);
7719 auto fftImagOut = std::make_unique<float[]>(specLen);
7720
7721 // Compute the spectrogram.
7722 for (dim_t winIdx = 0; int64_t(winIdx) < windowCount; winIdx++) {
7723
7724 // Windowing.
7725 for (int64_t n = 0; n < windowSize; n++) {
7726 winOut[n] = inputH.raw(winIdx * windowStride + n) * windowH.raw(n);
7727 }
7728
    // Compute the spectrum with a direct DFT (not an actual FFT):
    //   X[k] = sum_n winOut[n] * exp(-j * 2 * pi * n * k / fftLen).
7730 for (dim_t k = 0; k < specLen; k++) {
7731 fftRealOut[k] = 0;
7732 fftImagOut[k] = 0;
7733 for (int n = 0; n < windowSize; n++) {
7734 fftRealOut[k] +=
7735 winOut[n] * cos(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
7736 fftImagOut[k] -=
7737 winOut[n] * sin(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
7738 }
7739 }
7740
7741 // Compute spectrum magnitude/power.
7742 if (I->getMagnitudeSquared()) {
7743 for (dim_t k = 0; k < specLen; k++) {
7744 spectrogramH.at({winIdx, k}) =
7745 fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k];
7746 }
7747 } else {
7748 for (dim_t k = 0; k < specLen; k++) {
7749 spectrogramH.at({winIdx, k}) =
7750 sqrt(fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k]);
7751 }
7752 }
7753 }
7754}
7755
7756void BoundInterpreterFunction::fwdAudioSpectrogramInst(
7757 glow::AudioSpectrogramInst const *I) {
7758 auto inputTy = I->getInput()->getElementType();
7759 auto spectrogramTy = I->getSpectrogram()->getElementType();
7760 if ((inputTy == ElemKind::FloatTy) && (spectrogramTy == ElemKind::FloatTy)) {
7761 fwdAudioSpectrogramInstFloatImpl(I);
7762 } else {
7763 llvm_unreachable("Type is not supported.");
7764 }
7765}
7766
7767void BoundInterpreterFunction::fwdMFCCInstFloatImpl(glow::MFCCInst const *I) {
7768
7769 auto coefficients = I->getCoefficients();
7770 auto spectrogram = I->getSpectrogram();
7771 auto melWeights = I->getMelWeights();
7772 auto melRanges = I->getMelRanges();
7773 auto dctMat = I->getDctMat();
7774 int64_t filterBankCount = I->getFilterBankCount();
7775 int64_t numCoefficients = I->getNumCoefficients();
7776
7777 auto coefficientsH = getTensor(coefficients)->getHandle<float>();
7778 auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
7779 auto melWeightsH = getTensor(melWeights)->getHandle<float>();
7780 auto melRangesH = getTensor(melRanges)->getHandle<int32_t>();
7781 auto dctMatH = getTensor(dctMat)->getHandle<float>();
7782
7783 // Perform MFCC for all the windows.
7784 auto winNum = spectrogram->dims()[0];
7785 auto melBuff = std::make_unique<float[]>(filterBankCount);
7786 for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
7787
7788 // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
7789 // assume the spectrogram is a power value and not a magnitude.
7790 dim_t melBinCoeffIdx = 0;
7791 for (int64_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
7792 int32_t freqIdxStart = melRangesH.raw(2 * melIdx + 0);
7793 int32_t freqIdxStop = melRangesH.raw(2 * melIdx + 1);
7794 float melPwr = 0.0f;
7795 for (dim_t freqIdx = freqIdxStart; int32_t(freqIdx) <= freqIdxStop;
7796 freqIdx++) {
7797 melPwr += std::sqrt(spectrogramH.at({winIdx, freqIdx})) *
7798 melWeightsH.raw(melBinCoeffIdx++);
7799 }
7800 melBuff[melIdx] = melPwr;
7801 }
7802
7803 // Take logarithm in-place (avoid log(0)).
7804 for (int64_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
7805 float melPwr = melBuff[melIdx];
7806 melBuff[melIdx] = (melPwr == 0.0)
7807 ? logf(std::numeric_limits<float>::min())
7808 : logf(melPwr);
7809 }
7810
7811 // Compute DCT transform.
7812 for (dim_t k = 0; int64_t(k) < numCoefficients; k++) {
7813 float dctOut = 0.0f;
7814 for (dim_t n = 0; int64_t(n) < filterBankCount; n++) {
7815 dctOut += dctMatH.at({k, n}) * melBuff[n];
7816 }
7817 coefficientsH.at({winIdx, k}) = dctOut;
7818 }
7819 }
7820}
7821
7822void BoundInterpreterFunction::fwdMFCCInst(glow::MFCCInst const *I) {
7823 auto spectrogramTy = I->getSpectrogram()->getElementType();
7824 auto coefficientsTy = I->getCoefficients()->getElementType();
7825 if ((spectrogramTy == ElemKind::FloatTy) &&
7826 (coefficientsTy == ElemKind::FloatTy)) {
7827 fwdMFCCInstFloatImpl(I);
7828 } else {
7829 llvm_unreachable("Type is not supported.");
7830 }
7831}
7832
7833namespace {
7834/// Positions of the input values to be used for bilinear interpolation for
7835/// each sample point and the weights to use for each.
7836template <typename T> struct BinGrid {
7837 dim_t left;
7838 dim_t top;
7839 dim_t right;
7840 dim_t bottom;
7841 T leftW;
7842 T topW;
7843 T rightW;
7844 T bottomW;
7845};
7846} // namespace
7847
7848/// Function to calculate the xy coordinates of the resized image (grid)
7849/// Ref: https://arxiv.org/pdf/1703.06870.pdf and OnnxRuntime implementation
7850/// \p featureMapHeight and \p featureMapWidth are the dimensions of the
7851/// input feature map to the operator, \p outputHeight and \p outputWidth are
7852/// the dimensions of the operator output tensor, \p samplingRatioH and \p
7853/// samplingRatioW are the number of sampling points to use in each bin in the
7854/// height and width directions respectively (total sample points is
7855/// samplingRatioH * samplingRatioW), \p boxHeight and \p boxWidth are the
7856/// height and width of the RoI box, \p yRef and \p xRef are the adjustment to
/// be made for each sampling point, this is either the top left corner of the
7858/// box for RoiAlign or a vector to be added to center point after rotation
7859/// for RoiAlignRotated, \p rotated is true if the op is RoiAlignRotated, \p
7860/// theta is the rotation angle in the case of RoiAlignRotated and is unused
7861/// in RoiAlign, \p boxCenterH and \p boxCenterW are the center of the box
7862/// used for rotation in the case of RoiAlignRotated and unused in the case of
7863/// RoiAlign. \returns a vector of BinGrids, each one to be used to compute a
7864/// single sample point value.
7865template <typename T>
7866static std::vector<BinGrid<T>> getROIAlignInterpolationCoordinates(
7867 dim_t featureMapHeight, dim_t featureMapWidth, dim_t outputHeight,
7868 dim_t outputWidth, dim_t samplingRatioH, dim_t samplingRatioW, T boxHeight,
7869 T boxWidth, T yRef, T xRef, bool rotated, T theta, T boxCenterH,
7870 T boxCenterW) {
7871
7872 T sinTheta = T(0.0f);
7873 T cosTheta = T(0.0f);
7874 if (rotated) {
7875 sinTheta = T(std::sin(float(theta)));
7876 cosTheta = T(std::cos(float(theta)));
7877 }
7878
7879 std::vector<BinGrid<T>> binGrids;
7880
  // Height and width of each bin in the final output.
7882 const T binH = boxHeight / T(outputHeight);
7883 const T binW = boxWidth / T(outputWidth);
7884 const T roiBinSizeH = binH / T(samplingRatioH);
7885 const T roiBinSizeW = binW / T(samplingRatioW);
7886 for (dim_t oh = 0; oh < outputHeight; oh++) {
7887 for (dim_t ow = 0; ow < outputWidth; ow++) {
7888 for (dim_t gh = 0; gh < samplingRatioH; gh++) {
7889 for (dim_t gw = 0; gw < samplingRatioW; gw++) {
7890 // x,y coordinates or vector w.r.t input dimensions
7891 T inY = yRef + (T(oh) * binH) + ((T(gh) + T(0.5f)) * roiBinSizeH);
7892 T inX = xRef + (T(ow) * binW) + ((T(gw) + T(0.5f)) * roiBinSizeW);
7893
7894 // If ROI is rotated, rotate by theta around the box center then
7895 // translate
7896 if (rotated) {
7897 T inYY = inY;
7898 T inXX = inX;
7899 inY = inYY * cosTheta - inXX * sinTheta + boxCenterH;
7900 inX = inXX * cosTheta + inYY * sinTheta + boxCenterW;
7901 }
7902
7903 // zero pad mal-formed boxes
7904 if (inY < T(-1) || inY > T(featureMapHeight)) {
7905 BinGrid<T> bg = BinGrid<T>{0, 0, 0, 0, 0, 0, 0, 0};
7906 binGrids.push_back(bg);
7907 continue;
7908 }
7909 if (inX < T(-1) || inX > T(featureMapWidth)) {
7910 BinGrid<T> bg = BinGrid<T>{0, 0, 0, 0, 0, 0, 0, 0};
7911 binGrids.push_back(bg);
7912 continue;
7913 }
7914
7915 // clip to input dimensions
7916 T y = std::min(std::max(inY, T(0)), T(featureMapHeight - 1));
7917 T x = std::min(std::max(inX, T(0)), T(featureMapWidth - 1));
7918
7919 // calc interpolation parameters
7920 const dim_t yl = dim_t(std::floor(float(y)));
7921 const dim_t xl = dim_t(std::floor(float(x)));
7922 const dim_t yh = std::min(yl + 1, featureMapHeight - 1);
7923 const dim_t xh = std::min(xl + 1, featureMapWidth - 1);
7924
7925 BinGrid<T> bg;
7926 bg.left = xl;
7927 bg.top = yl;
7928 bg.right = xh;
7929 bg.bottom = yh;
7930
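          // Bilinear interpolation weights; the sampled value is
          //   topW * leftW * v(top, left) + topW * rightW * v(top, right) +
          //   bottomW * leftW * v(bottom, left) +
          //   bottomW * rightW * v(bottom, right).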
7931 bg.rightW = x - T(xl);
7932 bg.bottomW = y - T(yl);
7933 bg.leftW = T(1.0) - bg.rightW;
7934 bg.topW = T(1.0) - bg.bottomW;
7935
7936 binGrids.push_back(bg);
7937 } // end of w
7938 } // end of h
7939 } // end of W
7940 } // end of H
7941
7942 return binGrids;
7943}
7944
7945// Implementation of ROIAlign as described in
7946// https://arxiv.org/pdf/1703.06870.pdf ROIAlign is similar to crop_and_resize
7947// + pooling with minor modifications in the crop_and_resize.
7948template <typename T>
7949void BoundInterpreterFunction::fwdROIAlignInstFloatImpl(
7950 glow::ROIAlignInst const *I) {
7951 auto featureMap = I->getFeatureMap();
7952 auto boxes = I->getBoxes();
7953 auto batchIndices = I->getBatchIndices();
7954 auto result = I->getResult();
7955
7956 auto boxesH = getTensor(boxes)->getHandle<T>();
7957 auto featureMapH = getTensor(featureMap)->getHandle<T>();
7958 auto resultH = getTensor(result)->getHandle<T>();
7959
7960 const bool rotated = I->getRotated();
7961 const PoolingMode mode = PoolingMode(I->getMode());
7962 const bool aligned = I->getAligned();
7963 const dim_t samplingRatio = I->getSamplingRatio();
7964 const T spatialScale = I->getSpatialScale();
7965
7966 const dim_t featureMapHeight = featureMapH.dims()[1];
7967 const dim_t featureMapWidth = featureMapH.dims()[2];
7968 const dim_t numBoxes = resultH.dims()[0];
7969 const dim_t outputHeight = resultH.dims()[1];
7970 const dim_t outputWidth = resultH.dims()[2];
7971 const dim_t depth = resultH.dims()[3];
7972
7973 const T offset = aligned ? T(0.5) : T(0);
7974
7975 bool useSeparateBatchIndexVector = true;
7976 dim_t boxesStartCol = 0;
7977 if (rotated || boxes->dims()[1] == 5) {
7978 boxesStartCol = 1;
7979 useSeparateBatchIndexVector = false;
7980 }
7981
7982 // Extract batch indices from the batchIndices tensor if it is used (only
7983 // ONNX uses it, and it may provide Int64ITy tensors).
7984 std::vector<dim_t> batchIndicesExtracted;
7985 if (useSeparateBatchIndexVector) {
7986 Tensor *batchIndicesTensor = getTensor(batchIndices);
7987 auto batchIndicesElemKind = batchIndicesTensor->getElementType();
7988 for (dim_t b = 0; b < numBoxes; b++) {
7989 if (batchIndicesElemKind == ElemKind::Int32ITy) {
7990 batchIndicesExtracted.push_back(
7991 batchIndicesTensor->getHandle<int32_t>().at({b}));
7992 } else {
7993 batchIndicesExtracted.push_back(
7994 batchIndicesTensor->getHandle<int64_t>().at({b}));
7995 }
7996 }
7997 }
7998
7999 for (dim_t b = 0; b < numBoxes; b++) {
8000 dim_t batchIndex;
8001 if (useSeparateBatchIndexVector) {
8002 batchIndex = batchIndicesExtracted[b];
8003 } else {
8004 batchIndex = dim_t(float(boxesH.at({b, 0})));
8005 }
8006
8007 // Values used to determine sampling points during bilinear interpolation.
8008 // yRef and xRef have different interpretations for rotated vs unrotated
8009 // cases (vector vs coordinates) but are used very similarly.
8010 T yRef;
8011 T xRef;
8012 T boxHeight;
8013 T boxWidth;
8014
8015 // Values only used in rotated case.
8016 T theta = T(0.0);
8017 T boxCenterH = T(0.0);
8018 T boxCenterW = T(0.0);
8019
8020 if (rotated) {
8021 // Do not round
8022 boxCenterW = boxesH.at({b, boxesStartCol + 0}) * spatialScale - offset;
8023 boxCenterH = boxesH.at({b, boxesStartCol + 1}) * spatialScale - offset;
8024 boxWidth = boxesH.at({b, boxesStartCol + 2}) * spatialScale;
8025 boxHeight = boxesH.at({b, boxesStartCol + 3}) * spatialScale;
8026 theta = boxesH.at({b, boxesStartCol + 4}) * T(M_PI) / T(180.0);
8027
8028 if (aligned) {
8029 assert(boxWidth >= T(0.0) && boxHeight >= T(0.0) &&
8030 "ROIs in ROIAlign must not have non-negative size!");
8031 } else { // backward compatibility
8032 // Force malformed ROIs to be 1x1
8033 boxHeight = std::max(boxHeight, T(1.0));
8034 boxWidth = std::max(boxWidth, T(1.0));
8035 }
8036
8037 // These are computed w.r.t. the center of the RoI (x, y).
8038 // Appropriate translation needs to be applied after.
8039 yRef = (T(-1.0) * boxHeight) / T(2.0);
8040 xRef = (T(-1.0) * boxWidth) / T(2.0);
8041 } else {
8042 llvm::SmallVector<T, 4> box = {
8043 boxesH.at({b, boxesStartCol + 0}) * spatialScale - offset,
8044 boxesH.at({b, boxesStartCol + 1}) * spatialScale - offset,
8045 boxesH.at({b, boxesStartCol + 2}) * spatialScale - offset,
8046 boxesH.at({b, boxesStartCol + 3}) * spatialScale - offset};
8047
8048 if (aligned) {
8049 CHECK_GE(box[3] - box[1], T(0.0)) << "Roi height cannot be negative.";
8050 CHECK_GE(box[2] - box[0], T(0.0)) << "Roi width cannot be negative.";
8051 } else {
8052 // Caffe2 backwards compatibility for mal-formed ROIs:
8053 // Force ROI size to be at least 1x1.
8054 box[2] = std::max(box[2], box[0] + T(1.0));
8055 box[3] = std::max(box[3], box[1] + T(1.0));
8056 }
8057
8058 yRef = box[1];
8059 xRef = box[0];
8060 boxHeight = (box[3] - box[1]);
8061 boxWidth = (box[2] - box[0]);
8062 }
8063
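// Number of sample points per output bin in each dimension; if samplingRatio
// is not positive it is chosen adaptively as ceil(boxSize / outputSize).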
8064 const dim_t samplingRatioH =
8065 (samplingRatio > 0) ? samplingRatio
8066 : std::ceil(float(boxHeight) / outputHeight);
8067 const dim_t samplingRatioW = (samplingRatio > 0)
8068 ? samplingRatio
8069 : std::ceil(float(boxWidth) / outputWidth);
8070
8071 // Get the x,y sampling coordinates in the resized image (grid).
8072 std::vector<BinGrid<T>> binGrids = getROIAlignInterpolationCoordinates<T>(
8073 featureMapHeight, featureMapWidth, outputHeight, outputWidth,
8074 samplingRatioH, samplingRatioW, boxHeight, boxWidth, yRef, xRef,
8075 rotated, theta, boxCenterH, boxCenterW);
8076
8077 uint64_t binCount = 0;
8078 for (dim_t oh = 0; oh < outputHeight; ++oh) {
8079 for (dim_t ow = 0; ow < outputWidth; ++ow) {
8080 for (dim_t d = 0; d < depth; ++d) {
8081 std::vector<T> values;
8082 for (dim_t gh = 0; gh < samplingRatioH; ++gh) {
8083 for (dim_t gw = 0; gw < samplingRatioW; ++gw) {
8084 BinGrid<T> bg = binGrids[binCount++];
8085 // The four values of the input feature map surrounding the point of
8086 // interest (POI) in the resized image
8087 const T topLeft =
8088 featureMapH.at({batchIndex, bg.top, bg.left, d});
8089 const T topRight =
8090 featureMapH.at({batchIndex, bg.top, bg.right, d});
8091 const T bottomLeft =
8092 featureMapH.at({batchIndex, bg.bottom, bg.left, d});
8093 const T bottomRight =
8094 featureMapH.at({batchIndex, bg.bottom, bg.right, d});
8095
8096 // bilinear interpolation
8097 const T value = (topLeft * (bg.topW * bg.leftW)) +
8098 (topRight * (bg.topW * bg.rightW)) +
8099 (bottomLeft * (bg.bottomW * bg.leftW)) +
8100 (bottomRight * (bg.bottomW * bg.rightW));
8101 // collect the sampled value; it is pooled over the whole bin below
8102 values.push_back(value);
8103 } // end of w
8104 } // end of h
8105 // {Average or Max} pooling
8106 resultH.at({b, oh, ow, d}) =
8107 (mode == PoolingMode::AVG)
8108 ? std::accumulate(values.begin(), values.end(), T(0.0)) /
8109 T(values.size())
8110 : *std::max_element(values.begin(), values.end());
8111
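// Rewind binCount so the same bin grid is reused for the next depth channel;
// it is advanced past this bin only after the `d` loop below.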
8112 binCount = binCount - (samplingRatioH * samplingRatioW);
8113 } // end of d
8114 binCount = binCount + (samplingRatioH * samplingRatioW);
8115 } // end of W
8116 } // end of H
8117 } // end of b
8118}
8119
8120void BoundInterpreterFunction::fwdROIAlignInst(glow::ROIAlignInst const *I) {
8121 dispatchFloatingPointImpl(fwdROIAlignInstFloatImpl,
8122 I->getFeatureMap()->getElementType(), I);
8123}
8124
8125// Forward transform that maps proposal boxes to ground-truth boxes using
8126// bounding-box regression deltas.
8127// boxes: pixel coordinates of the bounding boxes
8128// size (M, 4), format [x1; y1; x2; y2], x2 >= x1, y2 >= y1
8129// deltas: bounding box translations and scales
8130// size (M, 4), format [dx; dy; dw; dh]
8131// dx, dy: scale-invariant translation of the center of the bounding box
8132// dw, dh: log-space scaling of the width and height of the bounding box
8133// weights: weights [wx, wy, ww, wh] for the deltas
8134// bboxXformClip: clip value for the log-space scaling deltas (dw, dh);
8135// larger deltas are clamped to this value before exponentiation
8136// correct_transform_coords: Correct bounding box transform coordinates. Set to
8137// true to match the detectron code, set to false for backward
8138// compatibility
8139// return: pixel coordinates of the bounding boxes
8140// size (M, 4), format [x1; y1; x2; y2]
8141// see "Rich feature hierarchies for accurate object detection and semantic
8142// segmentation" Appendix C for more details
8143// reference: detectron/lib/utils/boxes.py bbox_transform()
8144template <typename T>
8145static void bbox_transform_upright(
8146 Handle<T> &boxesOut, const Handle<T> &boxes, const Handle<T> &deltas,
8147 dim_t startRowBoxesOut, dim_t startColBoxesOut, dim_t startRowBoxes,
8148 dim_t startColBoxes, dim_t startRowDeltas, dim_t startColDeltas, dim_t rows,
8149 llvm::ArrayRef<float> weights, const T &bboxXformClip, T scaleBeforeInv,
8150 const bool legacyPlusOne = false) {
8151
8152 if (boxes.dims()[0] == 0) {
8153 return;
8154 }
8155
8156 std::vector<T> widths(rows), heights(rows), ctrX(rows), ctrY(rows);
8157 for (dim_t i = 0; i < rows; i++) {
8158 widths[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) *
8159 scaleBeforeInv -
8160 boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv +
8161 T(((legacyPlusOne) ? 1 : 0));
8162 heights[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) *
8163 scaleBeforeInv -
8164 boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) *
8165 scaleBeforeInv +
8166 T(((legacyPlusOne) ? 1 : 0));
8167
8168 ctrX[i] = boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv +
8169 T(0.5) * widths[i];
8170 ctrY[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) *
8171 scaleBeforeInv +
8172 T(0.5) * heights[i];
8173 }
8174
8175 std::vector<T> dx(rows), dy(rows), dw(rows), dh(rows);
8176 for (dim_t i = 0; i < rows; i++) {
8177 dx[i] = deltas.at({startRowDeltas + i, startColDeltas}) / T(weights[0]);
8178 dy[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(1)}) /
8179 T(weights[1]);
8180 dw[i] =
8181 std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(2)}) /
8182 T(weights[2]),
8183 bboxXformClip);
8184 dh[i] =
8185 std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(3)}) /
8186 T(weights[3]),
8187 bboxXformClip);
8188 }
8189
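// Invert the bbox regression parameterization (R-CNN Appendix C):
//   pred_ctr = d * size + ctr,  pred_size = exp(d) * size.
// dw/dh were clipped to bboxXformClip above so exp() cannot blow up.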
8190 std::vector<T> predCtrX(rows), predCtrY(rows), predW(rows), predH(rows);
8191 for (dim_t i = 0; i < rows; i++) {
8192 predCtrX[i] = dx[i] * widths[i] + ctrX[i];
8193 predCtrY[i] = dy[i] * heights[i] + ctrY[i];
8194 predW[i] = T(std::exp(float(dw[i]))) * widths[i];
8195 predH[i] = T(std::exp(float(dh[i]))) * heights[i];
8196 }
8197
8198 for (dim_t i = 0; i < rows; i++) {
8199 // x1
8200 boxesOut.at({startRowBoxesOut + i, startColBoxesOut}) =
8201 predCtrX[i] - T(0.5) * predW[i];
8202 // y1
8203 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(1)}) =
8204 predCtrY[i] - T(0.5) * predH[i];
8205 // x2
8206 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(2)}) =
8207 predCtrX[i] + T(0.5) * predW[i] - T(((legacyPlusOne) ? 1 : 0));
8208 // y2
8209 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(3)}) =
8210 predCtrY[i] + T(0.5) * predH[i] - T(((legacyPlusOne) ? 1 : 0));
8211 }
8212}
8213
8214// Like bbox_transform_upright, but works on rotated boxes.
8215// boxes: pixel coordinates of the bounding boxes
8216// size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)]
8217// deltas: bounding box translations and scales
8218// size (M, 5), format [dx; dy; dw; dh; da]
8219// dx, dy: scale-invariant translation of the center of the bounding box
8220// dw, dh: log-space scaling of the width and height of the bounding box
8221// da: delta for angle in radians
8222// return: pixel coordinates of the bounding boxes
8223// size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)]
8224template <typename T>
8225static void bbox_transform_rotated(
8226 Handle<T> &boxesOut, const Handle<T> &boxes, const Handle<T> &deltas,
8227 dim_t startRowBoxesOut, dim_t startColBoxesOut, dim_t startRowBoxes,
8228 dim_t startColBoxes, dim_t startRowDeltas, dim_t startColDeltas, dim_t rows,
8229 llvm::ArrayRef<float> weights, const T &bboxXformClip, T scaleBeforeInv,
8230 const bool angleBoundOn, ssize_t angleBoundLo, ssize_t angleBoundHi) {
8231
8232 if (boxes.dims()[0] == 0) {
8233 return;
8234 }
8235
8236 const T PI = 3.1415926535897931;
8237
8238 std::vector<T> dx(rows), dy(rows), dw(rows), dh(rows), da(rows);
8239 for (dim_t i = 0; i < rows; i++) {
8240 dx[i] = deltas.at({startRowDeltas + i, startColDeltas}) / T(weights[0]);
8241 dy[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(1)}) /
8242 T(weights[1]);
8243 dw[i] =
8244 std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(2)}) /
8245 T(weights[2]),
8246 bboxXformClip);
8247 dh[i] =
8248 std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(3)}) /
8249 T(weights[3]),
8250 bboxXformClip);
8251 // Convert back to degrees
8252 da[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(4)}) *
8253 T(180.0) / PI;
8254 }
8255
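// The rotated transform translates the center by (dx * width, dy * height),
// scales width/height by exp(dw)/exp(dh), and shifts the angle by da
// (already converted to degrees above).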
8256 for (dim_t i = 0; i < rows; i++) {
8257 // new ctr_x
8258 boxesOut.at({startRowBoxesOut + i, startColBoxesOut}) =
8259 dx[i] * boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) *
8260 scaleBeforeInv +
8261 boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv;
8262 // new ctr_y
8263 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(1)}) =
8264 dy[i] * boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) *
8265 scaleBeforeInv +
8266 boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) *
8267 scaleBeforeInv;
8268 // new width
8269 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(2)}) =
8270 T(std::exp(float(dw[i]))) *
8271 boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) *
8272 scaleBeforeInv;
8273 // new height
8274 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(3)}) =
8275 T(std::exp(float(dh[i]))) *
8276 boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) *
8277 scaleBeforeInv;
8278 // new angle
8279 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) =
8280 da[i] + boxes.at({startRowBoxes + i, startColBoxes + dim_t(4)});
8281 }
8282
8283 if (angleBoundOn) {
8284 const ssize_t period = angleBoundHi - angleBoundLo;
8285
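// Shift out-of-range angles by one full period to bring them back into
// [angleBoundLo, angleBoundHi].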
8286 for (dim_t i = 0; i < rows; i++) {
8287 if (ssize_t(boxesOut.at({startRowBoxesOut + i,
8288 startColBoxesOut + dim_t(4)})) < angleBoundLo) {
8289 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) +=
8290 T(period);
8291 } else if (ssize_t(boxesOut.at(
8292 {startRowBoxesOut + i, startColBoxesOut + dim_t(4)})) >
8293 angleBoundHi) {
8294 boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) -=
8295 T(period);
8296 }
8297 }
8298 }
8299}
8300
8301// Clip boxes to image boundaries
8302// boxes: pixel coordinates of bounding box, size (M * 4)
8303template <typename T>
8304void clip_boxes_upright(Handle<T> &boxes, dim_t startRowBoxes,
8305 dim_t startColBoxes, dim_t rows, int height, int width,
8306 T scaleAfter, bool legacyPlusOne = false,
8307 std::vector<dim_t> uprightRows = {}) {
8308 for (dim_t i = 0; i < rows; i++) {
8309 if (uprightRows.size() == rows && !uprightRows[i]) {
8310 continue;
8311 }
8312 // x1 >= 0 && x1 < width
8313 boxes.at({startRowBoxes + i, startColBoxes}) =
8314 scaleAfter *
8315 std::max(std::min(boxes.at({startRowBoxes + i, startColBoxes}),
8316 T(width - int(((legacyPlusOne) ? 1 : 0)))),
8317 T(0));
8318 // y1 >= 0 && y1 < height
8319 boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) =
8320 scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i,
8321 startColBoxes + dim_t(1)}),
8322 T(height - ((legacyPlusOne) ? 1 : 0))),
8323 T(0));
8324
8325 // x2 >= 0 && x2 < width
8326 boxes.at({startRowBoxes + i, startColBoxes + 2}) =
8327 scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i,
8328 startColBoxes + dim_t(2)}),
8329 T(width - ((legacyPlusOne) ? 1 : 0))),
8330 T(0));
8331 // y2 >= 0 && y2 < height
8332 boxes.at({startRowBoxes + i, startColBoxes + 3}) =
8333 scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i,
8334 startColBoxes + dim_t(3)}),
8335 T(height - ((legacyPlusOne) ? 1 : 0))),
8336 T(0));
8337 }
8338}
8339
8340// Similar to clip_boxes_upright but handles rotated boxes with angle info.
8341// boxes: size (M, 5), format [ctr_x; ctr_y; width; height; angle (in
8342// degrees)]
8343//
8344// Clipping is only performed for boxes that are almost upright
8345// (within a given `angle_thresh` tolerance) to maintain backward
8346// compatibility for non-rotated boxes.
8347//
8348// We don't clip rotated boxes due to a couple of reasons:
8349// (1) There are potentially multiple ways to clip a rotated box to make it
8350// fit within the image.
8351// (2) It's tricky to make the entire rectangular box fit within the image and
8352// still be able to not leave out pixels of interest.
8353// Therefore, we rely on upstream ops like RoIAlignRotated safely handling
8354// this.
8355template <typename T>
8356void clip_boxes_rotated(Handle<T> &boxes, dim_t startRowBoxes,
8357 dim_t startColBoxes, dim_t rows, int imH, int imW,
8358 T scaleAfter, float angleThresh = 1.0,
8359 bool legacyPlusOne = false) {
8360 std::vector<dim_t> uprightRows(rows, 0);
8361 for (dim_t i = 0; i < rows; i++) {
8362 if (std::abs(float(boxes.at(
8363 {startRowBoxes + i, startColBoxes + dim_t(4)}))) <= angleThresh) {
8364 const T ctrX = boxes.at({startRowBoxes + i, startColBoxes});
8365 const T ctrY = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)});
8366 const T width = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)});
8367 const T height = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)});
8368 boxes.at({startRowBoxes + i, startColBoxes}) =
8369 ctrX - (width - T(((legacyPlusOne) ? 1 : 0))) / T(2.0);
8370 boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) =
8371 ctrY - (height - T(((legacyPlusOne) ? 1 : 0))) / T(2.0);
8372 boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) =
8373 ctrX + (width - T(((legacyPlusOne) ? 1 : 0))) / T(2.0);
8374 boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) =
8375 ctrY + (height - T(((legacyPlusOne) ? 1 : 0))) / T(2.0);
8376 uprightRows[i] = 1;
8377 }
8378 }
8379 clip_boxes_upright(boxes, startRowBoxes, startColBoxes, rows, imH, imW,
8380 /* scaleAfter */ T(1.0), legacyPlusOne, uprightRows);
8381
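// Convert the clipped near-upright boxes back from corner format
// [x1, y1, x2, y2] to center/size format, then apply scaleAfter to the first
// four columns of every row (the angle column stays unscaled).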
8382 for (dim_t i = 0; i < rows; i++) {
8383 if (uprightRows[i] == 1) {
8384 const T x1 = boxes.at({startRowBoxes + i, startColBoxes});
8385 const T y1 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)});
8386 const T x2 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)});
8387 const T y2 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)});
8388 boxes.at({startRowBoxes + i, startColBoxes}) = (x1 + x2) / T(2.0);
8389 boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) =
8390 (y1 + y2) / T(2.0);
8391 boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) =
8392 x2 - x1 + T(((legacyPlusOne) ? 1 : 0));
8393 boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) =
8394 y2 - y1 + T(((legacyPlusOne) ? 1 : 0));
8395 }
8396
8397 for (dim_t j = 0; j < 4; j++) {
8398 boxes.at({startRowBoxes + i, startColBoxes + j}) *= scaleAfter;
8399 }
8400 }
8401}
8402
8403template <typename T>
8404void BoundInterpreterFunction::fwdBBoxTransformInstFloatImpl(
8405 glow::BBoxTransformInst const *I) {
8406 auto roiIn = I->getRois();
8407 auto deltaIn = I->getDeltas();
8408 auto imInfoIn = I->getImInfo();
8409
8410 auto boxOut = I->getBoxOut();
8411 auto roiBatchSplits = I->getRoiBatchSplits();
8412
8413 auto weights = I->getWeights();
8414 const dim_t boxDim = I->getRotated() ? 5 : 4;
8415 auto applyScale = I->getApplyScale();
8416 auto rotated = I->getRotated();
8417 auto angleBoundOn = I->getAngleBoundOn();
8418 auto angleBoundLo = I->getAngleBoundLo();
8419 auto angleBoundHi = I->getAngleBoundHi();
8420 auto angleThresh = I->getClipAngleThresh();
8421 auto legacyPlusOne = I->getLegacyPlusOne();
8422 const dim_t N = roiIn->dims()[0];
8423 const dim_t numClasses = deltaIn->dims()[1] / boxDim;
8424 const dim_t batchSize = imInfoIn->dims()[0];
8425
8426 auto roisH = getTensor(roiIn)->getHandle<T>();
8427 auto deltasH = getTensor(deltaIn)->getHandle<T>();
8428 auto roiBatchSplitsH = getTensor(roiBatchSplits)->getHandle<T>();
8429
8430 // Count the number of RoIs per batch
8431 std::vector<int> numRoisPerBatch(batchSize, 0);
8432 if (roiIn->dims()[1] == boxDim) {
8433 numRoisPerBatch[0] = N;
8434 } else {
8435 for (dim_t i = 0; i < N; ++i) {
8436 const int roiBatchId = roisH.at({i, 0});
8437 numRoisPerBatch[roiBatchId]++;
8438 }
8439 }
8440
8441 auto imInfoH = getTensor(imInfoIn)->getHandle<T>();
8442 auto boxOutH = getTensor(boxOut)->getHandle<T>();
8443 getTensor(boxOut)->zero();
8444
8445 // Default clip value for the log-space scaling deltas (dw, dh) used in
8446 // bounding box transformation (bbox_transform()); larger deltas are clamped.
8447 const T bboxXformClip = std::log(1000.0 / 16.0);
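// log(1000 / 16) ~= 4.135, i.e. a predicted box may grow by at most ~62.5x.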
8448
8449 // We assume roiIn and deltaIn over multiple batches are grouped
8450 // together in increasing order as generated by GenerateProposalsOp
8451 dim_t offset = 0;
8452 for (dim_t i = 0; i < batchSize; ++i) {
8453 const dim_t numRois = numRoisPerBatch[i];
8454 const T scaleBefore = imInfoH.at({i, 2});
8455 const T scaleAfter = applyScale ? scaleBefore : T(1.0);
8456 dim_t imgH = dim_t(float(imInfoH.at({i, 0}) / scaleBefore) + 0.5);
8457 dim_t imgW = dim_t(float(imInfoH.at({i, 1}) / scaleBefore) + 0.5);
8458
8459 // Apply the transform to the rectangle starting at (startRowRoi, startColRoi)
8460 // with height (rows) of numRois and width (cols) of boxDim.
8461 dim_t startRowRoi = offset;
8462 dim_t startColRoi = roiIn->dims()[1] != boxDim ? 1 : 0;
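// If the RoI tensor carries a leading batch-index column (width != boxDim),
// box coordinates start at column 1; otherwise at column 0.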
8463 dim_t rows = numRois;
8464 T scaleBeforeInv = T(1) / scaleBefore;
8465
8466 // Apply the scale before and after the transform on the fly.
8467 // Do not apply scale to the angle column of rotated boxes.
8468 for (dim_t k = 0; k < numClasses; k++) {
8469 dim_t startRowDelta = offset;
8470 dim_t startColDelta = k * boxDim;
8471 if (rotated) {
8472 bbox_transform_rotated<T>(boxOutH, roisH, deltasH, startRowDelta,
8473 startColDelta, startRowRoi, startColRoi,
8474 startRowDelta, startColDelta, rows, weights,
8475 bboxXformClip, scaleBeforeInv, angleBoundOn,
8476 angleBoundLo, angleBoundHi);
8477 clip_boxes_rotated<T>(boxOutH, startRowDelta, startColDelta, rows, imgH,
8478 imgW, scaleAfter, angleThresh, legacyPlusOne);
8479 } else {
8480 bbox_transform_upright<T>(boxOutH, roisH, deltasH, startRowDelta,
8481 startColDelta, startRowRoi, startColRoi,
8482 startRowDelta, startColDelta, rows, weights,
8483 bboxXformClip, scaleBeforeInv, legacyPlusOne);
8484 clip_boxes_upright<T>(boxOutH, startRowDelta, startColDelta, rows, imgH,
8485 imgW, scaleAfter, legacyPlusOne);
8486 }
8487 }
8488
8489 offset += rows;
8490 }
8491
8492 for (dim_t i = 0; i < batchSize; i++) {
8493 roiBatchSplitsH.at({i}) = numRoisPerBatch[i];
8494 }
8495}
8496
8497void BoundInterpreterFunction::fwdBBoxTransformInst(
8498 glow::BBoxTransformInst const *I) {
8499 dispatchFloatingPointImpl(fwdBBoxTransformInstFloatImpl,
8500 I->getRois()->getElementType(), I);
8501}
8502
8503void BoundInterpreterFunction::fwdExternalFunctionCallInst(
8504 glow::ExternalFunctionCallInst const *) {
8505 LOG(FATAL) << "ExternalFunctionCallInst is not supported yet";
8506}
8507