/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <assert.h>
#include <chrono>
#include <cmath>
#include <math.h>
#include <numeric>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

#include "libjit_defs.h"

namespace {

template <class ElemTy>
static void libjit_dump_tensor_console_impl(ElemTy *tensor, dim_t *dims,
                                            dim_t numDims) {
  // Check for 0-dimensional tensor.
  if (!numDims) {
    printf("[ Scalar containing: %.3f ]\n", (float)tensor[0]);
    return;
  }

  // Output shape.
  printf("shape: ( ");
  for (size_t i = 0; i < numDims; ++i) {
    printf("%zu ", (size_t)dims[i]);
  }
  printf(")\n");

  ElemTy mx = tensor[0];
  ElemTy mn = tensor[0];

  size_t size = 1;
  size_t sliceSize[numDims];
  for (size_t i = 0; i < numDims; ++i) {
    size *= dims[i];
  }

  for (ssize_t i = numDims - 1, curSliceSize = 1; i >= 0; --i) {
    sliceSize[i] = curSliceSize;
    curSliceSize *= dims[i];
  }

  for (size_t i = 0, e = size; i < e; i++) {
    mx = MAX(mx, tensor[i]);
    mn = MIN(mn, tensor[i]);
  }

  // Check for zero tensor.
  if (mn == .0 && mx == .0) {
    printf("[ Zero tensor ]\n");
    return;
  }

  // Output max and min.
  printf("max: %.3f min: %.3f\n", (float)mx, (float)mn);

  const unsigned maxNumElem = 100;

  printf("[");

  for (size_t i = 0, e = MIN(maxNumElem, size); i < e; i++) {

    // Print one opening brace at the beginning of every row, slice, and
    // tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      if (i % sliceSize[j] == 0) {
        // This iteration of the outer loop is a new row, slice, or tensor.
        printf("[");
      }
    }

    // Print the value at the current index.
    printf("%.3f", (float)tensor[i]);

    // Print one closing brace at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % sliceSize[j] == 0u) {
        printf("]");
      }
    }

    printf(", ");

    // Print one newline at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % sliceSize[j] == 0u) {
        // The next iteration of the outer loop will be a new row, slice, or
        // tensor.
        printf("\n");
      }
    }
  }

  if (size > maxNumElem) {
    printf("...");
  }

  printf("]\n");
}

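/// Dumps the \p tensorElemSize elements of \p tensor to the text file
/// \p filename as a comma-separated list, preceded by \p header when the
/// header string is non-empty.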
template <class ElemTy>
static void libjit_dump_tensor_txt_impl(ElemTy *tensor, size_t tensorElemSize,
                                        const char *filename,
                                        const char *header) {
  FILE *fh = fopen(filename, "w");
  if (!fh) {
    printf("ERROR opening file: '%s'!\n"
           "File name might be too long!\n",
           filename);
    return;
  }
  if (strlen(header)) {
    fprintf(fh, "%s\n", header);
  }
  for (size_t idx = 0, end = tensorElemSize; idx < end; idx++) {
    fprintf(fh, "%f, ", (double)tensor[idx]);
  }
  fclose(fh);
}

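/// Computes the flat offset of the element addressed by \p indices (of length
/// \p numIndices) inside a row-major tensor with shape \p dims (of length
/// \p numDims). Missing trailing indices are treated as 0. For example, with
/// dims = {2, 3, 4} and indices = {1, 2} the result is 1 * 12 + 2 * 4 = 20.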
template <typename ElemTy>
static dim_t get_element_ptr(const ElemTy *tensor, const dim_t *dims,
                             dim_t numDims, const dim_t *indices,
                             dim_t numIndices) {
  dim_t index = 0;
  dim_t subdimensionSize = 1;
  for (dim_t i = numDims; i > 0; i--) {
    dim_t curIndicesValue = (i <= numIndices) ? indices[i - 1] : 0;
    index += subdimensionSize * curIndicesValue;
    subdimensionSize *= dims[i - 1];
  }
  return index;
}

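/// Inserts \p slice (shape \p sliceDim, rank \p numDimsSlice) into \p tensor
/// (shape \p tensorDim, rank \p numDimsTensor) at the coordinates given by
/// \p offset. When \p count > 1, the slice is inserted \p count times along
/// dimension \p axis, with each copy shifted by the slice size in that
/// dimension. Ranks of up to 6 are handled by the explicit loop nests below.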
template <typename ElemTy>
static void libjit_insert_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
                                 dim_t *tensorDim, dim_t *sliceDim,
                                 dim_t numDimsTensor, dim_t numDimsSlice,
                                 dim_t offsetDim, dim_t count, dim_t axis) {
  // Destination coordinates.
  dim_t C[6];

  // A local copy of the offsets buffer. We copy the buffer to make it clear
  // to the optimizer that the inputs don't alias. This loop is optimized away.
  dim_t offsets_cpy[6];
  for (dim_t i = 0; i < numDimsSlice; i++) {
    offsets_cpy[i] = offset[i];
  }

  if (numDimsSlice == 6) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++)
            for (dim_t w = 0; w < sliceDim[3]; w++)
              for (dim_t q = 0; q < sliceDim[4]; q++)
                for (dim_t r = 0; r < sliceDim[5]; r++) {
                  const dim_t countAxisOffset = c * sliceDim[axis];
                  C[0] =
                      x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
                  C[1] =
                      y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
                  C[2] =
                      z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
                  C[3] =
                      w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
                  C[4] =
                      q + offsets_cpy[4] + ((axis == 4) ? countAxisOffset : 0);
                  C[5] =
                      r + offsets_cpy[5] + ((axis == 5) ? countAxisOffset : 0);
                  tensor[libjit_getXYZWQR(tensorDim, C[0], C[1], C[2], C[3],
                                          C[4], C[5])] =
                      slice[libjit_getXYZWQR(sliceDim, x, y, z, w, q, r)];
                }
    return;
  }

  if (numDimsSlice == 5) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++)
            for (dim_t w = 0; w < sliceDim[3]; w++)
              for (dim_t q = 0; q < sliceDim[4]; q++) {
                const dim_t countAxisOffset = c * sliceDim[axis];
                C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
                C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
                C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
                C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
                C[4] = q + offsets_cpy[4] + ((axis == 4) ? countAxisOffset : 0);
                tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
                                       C[4])] =
                    slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)];
              }
    return;
  }

  if (numDimsSlice == 4) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++)
            for (dim_t w = 0; w < sliceDim[3]; w++) {
              const dim_t countAxisOffset = c * sliceDim[axis];
              C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
              C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
              C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
              C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
              tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])] =
                  slice[libjit_getXYZW(sliceDim, x, y, z, w)];
            }
    return;
  }

  if (numDimsSlice == 3) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++) {
            const dim_t countAxisOffset = c * sliceDim[axis];
            C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
            C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
            C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
            tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])] =
                slice[libjit_getXYZ(sliceDim, x, y, z)];
          }
    return;
  }

  if (numDimsSlice == 2) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++) {
          const dim_t countAxisOffset = c * sliceDim[axis];
          C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
          C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
          tensor[libjit_getXY(tensorDim, C[0], C[1])] =
              slice[libjit_getXY(sliceDim, x, y)];
        }
    return;
  }

  if (numDimsSlice == 1) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++) {
        const dim_t countAxisOffset = c * sliceDim[axis];
        tensor[x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0)] =
            slice[x];
      }
    return;
  }
}

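/// Extracts a slice of shape \p sliceDim (rank \p numDimsSlice) from
/// \p tensor (shape \p tensorDim, rank \p numDimsTensor) starting at the
/// coordinates given by \p offset, writing the result into \p slice. Ranks of
/// up to 5 are handled by the explicit loop nests below.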
template <typename ElemTy>
static void libjit_extract_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
                                  dim_t *tensorDim, dim_t *sliceDim,
                                  dim_t numDimsTensor, dim_t numDimsSlice,
                                  dim_t offsetDim) {
  // Source coordinates.
  dim_t C[5];

  // A local copy of the offsets buffer. We copy the buffer to make it clear
  // to the optimizer that the inputs don't alias. This loop is optimized away.
  dim_t offsets_cpy[5];
  for (dim_t i = 0; i < numDimsSlice; i++) {
    offsets_cpy[i] = offset[i];
  }

  if (numDimsSlice == 5) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++)
          for (dim_t w = 0; w < sliceDim[3]; w++)
            for (dim_t q = 0; q < sliceDim[4]; q++) {
              C[0] = x + offsets_cpy[0];
              C[1] = y + offsets_cpy[1];
              C[2] = z + offsets_cpy[2];
              C[3] = w + offsets_cpy[3];
              C[4] = q + offsets_cpy[4];
              slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)] =
                  tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
                                         C[4])];
            }
    return;
  }

  if (numDimsSlice == 4) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++)
          for (dim_t w = 0; w < sliceDim[3]; w++) {
            C[0] = x + offsets_cpy[0];
            C[1] = y + offsets_cpy[1];
            C[2] = z + offsets_cpy[2];
            C[3] = w + offsets_cpy[3];
            slice[libjit_getXYZW(sliceDim, x, y, z, w)] =
                tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])];
          }
    return;
  }

  if (numDimsSlice == 3) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++) {
          C[0] = x + offsets_cpy[0];
          C[1] = y + offsets_cpy[1];
          C[2] = z + offsets_cpy[2];
          slice[libjit_getXYZ(sliceDim, x, y, z)] =
              tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])];
        }
    return;
  }

  if (numDimsSlice == 2) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++) {
        C[0] = x + offsets_cpy[0];
        C[1] = y + offsets_cpy[1];
        slice[libjit_getXY(sliceDim, x, y)] =
            tensor[libjit_getXY(tensorDim, C[0], C[1])];
      }
    return;
  }

  if (numDimsSlice == 1) {
    for (dim_t x = 0; x < sliceDim[0]; x++) {
      slice[x] = tensor[x + offsets_cpy[0]];
    }
    return;
  }
}

/// Helper struct for TopK: pairs a value with its original index.
template <typename T, typename TI> struct value_index {
  TI index;
  T value;
};

/// Comparator for TopK: sorts by value in descending order, breaking ties by
/// ascending index.
template <typename T, typename TI>
static int value_index_sort(const void *va, const void *vb) {
  value_index<T, TI> *a = (value_index<T, TI> *)va;
  value_index<T, TI> *b = (value_index<T, TI> *)vb;
  if (a->value != b->value)
    return a->value > b->value ? -1 : 1;
  return a->index < b->index ? -1 : 1;
}

/// Generic Top-K function. Here, \p scratch is some allocated buffer space, \p
/// size is the size of the input, and \p n is the size of the last dimension
/// of the input.
template <typename T, typename TI>
static void libjit_topk(T *values, TI *indices, const T *input, void *scratch,
                        dim_t k, dim_t n, dim_t size) {
  dim_t in = 0;
  dim_t out = 0;

  // Initialize scratch with 0.
  memset(scratch, 0, 2 * n * sizeof(TI));

  value_index<T, TI> *buffer = (value_index<T, TI> *)scratch;

  // Specialize TopK for the case where K is 1.
  if (k == 1) {
    while (in < size) {
      // Find the largest value by iterating over the array instead of calling
      // 'sort'.
      value_index<T, TI> mx = {0, input[in]};
      for (TI i = 1; i < TI(n); i++) {
        if (input[i + in] > mx.value) {
          mx = {i, input[i + in]};
        }
      }
      indices[out] = mx.index;
      values[out] = mx.value;
      out++;
      in += n;
    }
    return;
  }

  while (in < size) {
    for (dim_t i = 0; i < n; i++) {
      buffer[i].index = i;
      buffer[i].value = input[in++];
    }
    qsort(buffer, n, sizeof(value_index<T, TI>), value_index_sort<T, TI>);
    for (dim_t i = 0; i < k; i++) {
      indices[out] = buffer[i].index;
      values[out] = buffer[i].value;
      out++;
    }
  }
}

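/// Gathers \p numIndices slices of \p sliceSize elements from each of the
/// \p numSamples samples of \p data (each sample spanning \p sampleSize
/// elements) and writes them contiguously into \p dest. The slice fetched
/// from each sample is selected by \p indices.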
template <typename T, typename IDX>
static void libjit_gather(T *dest, const T *data, const IDX *indices,
                          dim_t numIndices, dim_t sliceSize, dim_t numSamples,
                          dim_t sampleSize) {
  // The index of the slice that is being written.
  dim_t outIdx = 0;

  // For each sample in our batch:
  for (dim_t sample = 0; sample < numSamples; sample++) {
    dim_t sampleStart = sample * sampleSize;

    // For each slice that we fetch:
    for (dim_t i = 0; i < numIndices; i++) {
      dim_t slice = indices[i];

      // Copy the slice.
      memcpy(dest + outIdx * sliceSize, data + sampleStart + slice * sliceSize,
             sliceSize * sizeof(T));

      // Point to the next location in the destination tensor.
      outIdx++;
    }
  }
}

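/// GatherND: for each of the \p batchCount batches, reads \p outSliceCount
/// index tuples of length \p indicesDimLast from \p indicesPtr, flattens each
/// tuple into an input slice index using the precomputed dimension products
/// \p indicesDimProd, and copies \p sliceSize bytes per slice from \p data
/// (which holds \p inpSliceCount slices per batch) into \p out.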
template <typename DataT, typename IndexT>
static void
libjit_gather_nd(DataT *out, const DataT *data, const IndexT *indicesPtr,
                 dim_t batchCount, dim_t inpSliceCount, dim_t outSliceCount,
                 dim_t sliceSize, dim_t indicesDimLast, dim_t *indicesDimProd) {
  const char *dataPtr = (const char *)data;
  char *outPtr = (char *)out;

  for (dim_t batchIdx = 0; batchIdx < batchCount; ++batchIdx) {
    for (dim_t outSliceIdx = 0; outSliceIdx < outSliceCount; ++outSliceIdx) {

      // Compute input slice index.
      dim_t inpSliceIdx = 0;
      for (size_t idx = 0; idx < indicesDimLast; ++idx) {
        inpSliceIdx += (*indicesPtr++) * indicesDimProd[idx];
      }

      // Copy data.
      memcpy(outPtr, dataPtr + inpSliceIdx * sliceSize, sliceSize);
      outPtr += sliceSize;
    }

    // Increment input pointer for next batch.
    dataPtr += inpSliceCount * sliceSize;
  }
}

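/// GatherRanges: for each of the \p numExamples examples, copies
/// \p exampleSize (start, length) ranges from \p data into \p output and
/// records the total number of elements gathered for the example in
/// \p lengths.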
template <typename T, typename U>
static void libjit_gatherranges(T *output, U *lengths, const T *data,
                                const U *ranges, dim_t numExamples,
                                dim_t exampleSize) {
  // Indices into the output and range buffers.
  dim_t outputIdx = 0;
  dim_t rangesIdx = 0;

  // For each example:
  for (dim_t example = 0; example < numExamples; ++example) {
    // Keep track of the total length of the gathered ranges for the example.
    U totalLen = 0;

    // For each range:
    for (dim_t range = 0; range < exampleSize; ++range) {
      // Get the start and length of the range.
      const U start = ranges[rangesIdx];
      const U len = ranges[rangesIdx + 1];

      // Copy the specified elements.
      memcpy(output + outputIdx, data + start, len * sizeof(T));

      // len elements were copied, so increment the output index by len.
      outputIdx += len;

      // Each range is of the form (start, len), so increment the ranges
      // index by 2 to get to the next range.
      rangesIdx += 2;

      // Increment the total length for the example by len.
      totalLen += len;
    }

    // Record the total length of gathered ranges for the current example in
    // the lengths buffer.
    lengths[example] = totalLen;
  }
}

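/// ScatterData (copy mode): each of the \p numIndices rows of \p indices is
/// an index tuple of length \p indexSize, flattened against \p dataDims to
/// select a destination slice of \p sliceSize elements in \p data, which is
/// overwritten with the corresponding slice from \p slices. The "add"
/// variants below accumulate instead of overwriting.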
template <typename T, typename T2>
static void libjit_scatterdatacopy(T *data, const dim_t *dataDims,
                                   const T2 *indices, const T *slices,
                                   dim_t numIndices, dim_t indexSize,
                                   dim_t sliceSize) {
  for (dim_t i = 0; i < numIndices; i++) {
    dim_t destDataIdx = indices[i * indexSize];
    for (dim_t j = 1; j < indexSize; j++) {
      destDataIdx *= dataDims[j];
      destDataIdx += indices[i * indexSize + j];
    }
    memcpy(data + destDataIdx * sliceSize, slices + i * sliceSize,
           sliceSize * sizeof(T));
  }
}

template <typename T, typename T2>
static void libjit_scatterdataaddfloat(T *data, const dim_t *dataDims,
                                       const T2 *indices, const T *slices,
                                       dim_t numIndices, dim_t indexSize,
                                       dim_t sliceSize) {
  for (dim_t i = 0; i < numIndices; i++) {
    dim_t destDataIdx = indices[i * indexSize];
    for (dim_t j = 1; j < indexSize; j++) {
      destDataIdx *= dataDims[j];
      destDataIdx += indices[i * indexSize + j];
    }
    for (dim_t j = 0; j < sliceSize; j++) {
      data[destDataIdx * sliceSize + j] += slices[i * sliceSize + j];
    }
  }
}

template <typename T, typename T2>
static void libjit_scatterdataaddquantized(T *data, const dim_t *dataDims,
                                           const T2 *indices, const T *slices,
                                           dim_t numIndices, dim_t indexSize,
                                           dim_t sliceSize, float dataScale,
                                           int32_t dataOffset, float sliceScale,
                                           int32_t sliceOffset) {

  for (size_t i = 0; i < numIndices; i++) {
    size_t destDataIdx = indices[i * indexSize];
    for (size_t j = 1; j < indexSize; j++) {
      destDataIdx *= dataDims[j];
      destDataIdx += indices[i * indexSize + j];
    }
    for (size_t j = 0; j < sliceSize; j++) {
      float lhs = (data[destDataIdx * sliceSize + j] - dataOffset) * dataScale;
      float rhs = (slices[i * sliceSize + j] - sliceOffset) * sliceScale;
      T result = libjit_clip_i8((lhs + rhs) / dataScale + dataOffset);
      data[destDataIdx * sliceSize + j] = result;
    }
  }
}

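/// Generic transpose: permutes the dimensions of \p inW (shape \p idim) into
/// \p outW (shape \p odim) according to \p shuffle, where
/// odim[i] == idim[shuffle[i]]. For example, shuffle = {0, 2, 3, 1}
/// transposes NCHW into NHWC. Ranks of up to 6 are supported.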
template <typename T>
static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim,
                                     const dim_t *odim, const dim_t *shuffle,
                                     dim_t numDims) {
  // Transpose 2d matrices one tile at a time. This access pattern ensures
  // that the whole tile is kept in L1 cache. When scanning the whole row at
  // once we invalidate many cache lines when we touch a single column.
  const unsigned tileSize = 64;

  // Source coordinate.
  dim_t SC[6];

  if (numDims == 6) {
    for (dim_t x = 0; x < odim[0]; x++)
      for (dim_t y = 0; y < odim[1]; y++)
        for (dim_t z = 0; z < odim[2]; z++)
          for (dim_t w = 0; w < odim[3]; w++)
            for (dim_t q = 0; q < odim[4]; q++)
              for (dim_t r = 0; r < odim[5]; r++) {
                SC[shuffle[0]] = x;
                SC[shuffle[1]] = y;
                SC[shuffle[2]] = z;
                SC[shuffle[3]] = w;
                SC[shuffle[4]] = q;
                SC[shuffle[5]] = r;
                outW[libjit_getXYZWQR(odim, x, y, z, w, q, r)] =
                    inW[libjit_getXYZWQR(idim, SC[0], SC[1], SC[2], SC[3],
                                         SC[4], SC[5])];
              }
    return;
  }

  if (numDims == 5) {
    for (dim_t x = 0; x < odim[0]; x++)
      for (dim_t y = 0; y < odim[1]; y++)
        for (dim_t z = 0; z < odim[2]; z++)
          for (dim_t w = 0; w < odim[3]; w++)
            for (dim_t q = 0; q < odim[4]; q++) {
              SC[shuffle[0]] = x;
              SC[shuffle[1]] = y;
              SC[shuffle[2]] = z;
              SC[shuffle[3]] = w;
              SC[shuffle[4]] = q;
              outW[libjit_getXYZWQ(odim, x, y, z, w, q)] =
                  inW[libjit_getXYZWQ(idim, SC[0], SC[1], SC[2], SC[3], SC[4])];
            }
    return;
  }
  if (numDims == 4) {
    for (dim_t x = 0; x < odim[0]; x++)
      for (dim_t y = 0; y < odim[1]; y++)
        for (dim_t z = 0; z < odim[2]; z++)
          for (dim_t w = 0; w < odim[3]; w++) {
            SC[shuffle[0]] = x;
            SC[shuffle[1]] = y;
            SC[shuffle[2]] = z;
            SC[shuffle[3]] = w;
            outW[libjit_getXYZW(odim, x, y, z, w)] =
                inW[libjit_getXYZW(idim, SC[0], SC[1], SC[2], SC[3])];
          }
    return;
  }
  if (numDims == 3) {
    for (dim_t x = 0; x < odim[0]; x++) {
      // Process the tiles in the innermost two dimensions:
      for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
        for (dim_t sz = 0; sz < odim[2]; sz += tileSize) {
          // Process the inner tile:
          for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
            for (dim_t z = sz; z < MIN(sz + tileSize, odim[2]); z++) {
              SC[shuffle[0]] = x;
              SC[shuffle[1]] = y;
              SC[shuffle[2]] = z;
              outW[libjit_getXYZ(odim, x, y, z)] =
                  inW[libjit_getXYZ(idim, SC[0], SC[1], SC[2])];
            }
          }
        }
      }
    }
    return;
  }

  if (numDims == 2) {
    // Process the tiles in the matrix:
    for (dim_t sx = 0; sx < odim[0]; sx += tileSize) {
      for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
        // Process the inner tile:
        for (dim_t x = sx; x < MIN(sx + tileSize, odim[0]); x++) {
          for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
            SC[shuffle[0]] = x;
            SC[shuffle[1]] = y;
            outW[libjit_getXY(odim, x, y)] =
                inW[libjit_getXY(idim, SC[0], SC[1])];
          }
        }
      }
    }
    return;
  }
}

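/// Reverses \p inW along dimension \p axis into \p outW, where \p dims (of
/// length \p numDims) is the shape of both tensors. For example, flipping
/// {1, 2, 3, 4} along axis 0 yields {4, 3, 2, 1}.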
template <typename T>
static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims,
                                dim_t axis, dim_t numDims) {

  // Product of outer dimensions excluding the flip dimension.
  dim_t outerLen = 1;
  for (dim_t idx = 0; idx < axis; idx++) {
    outerLen *= dims[idx];
  }

  // Flip dimension.
  dim_t len = dims[axis];

  // Product of inner dimensions excluding the flip dimension.
  dim_t innerLen = 1;
  for (dim_t idx = axis + 1; idx < numDims; idx++) {
    innerLen *= dims[idx];
  }

  // Flip axis such that input data is read linearly.
  const T *inpPtr = inW;
  T *outPtr = outW + (len - 1) * innerLen;
  for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) {
    for (dim_t idx = 0; idx < len; idx++) {
      for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) {
        *outPtr++ = *inpPtr++;
      }
      outPtr -= 2 * innerLen;
    }
    outPtr += 2 * len * innerLen;
  }
}

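/// Embedding lookup: for each index in \p indices (whose shape is \p indDims
/// with \p indSize dimensions), copies the corresponding row of
/// \p embedding_dim elements from the \p weights table (\p num_embedding
/// rows) into \p dest. Entries equal to \p padIdx are left zeroed; the
/// \p scale and \p sparse modes are currently unsupported (see the asserts
/// below).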
template <typename ty>
static void libjit_embedding_generic(ty *dest, ty *weights, int64_t *indices,
                                     const dim_t *indDims, dim_t indSize,
                                     dim_t num_embedding, dim_t embedding_dim,
                                     int64_t padIdx, bool scale, bool sparse) {
  dim_t indLen = 1;
  for (dim_t idx = 0; idx < indSize; ++idx) {
    indLen *= indDims[idx];
  }

  assert(!scale && "Currently only support scale_grad_by_freq == 'false'");
  assert(!sparse && "Currently only support sparse == 'false'");
  if (padIdx > -1) {
    assert(static_cast<dim_t>(padIdx) <= num_embedding &&
           "padIdx must be within num_embedding");
  }
  memset(dest, 0, indLen * embedding_dim * sizeof(ty));

  for (int64_t i = 0; i < indLen; i++) {
    int64_t index = indices[i];
    if (index != padIdx) {
      for (dim_t j = 0; j < embedding_dim; j++) {
        dest[i * embedding_dim + j] = weights[index * embedding_dim + j];
      }
    }
  }
}

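/// ArgMax along \p axis: the input shape \p dims (rank \p numDims) is viewed
/// as [outerLen, axisLen, innerLen]; for every (outer, inner) pair, the index
/// of the maximum element along the axis is written linearly into \p outW.
/// libjit_arg_min_generic below is the symmetric minimum version.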
template <typename inpT, typename outT>
static void libjit_arg_max_generic(const inpT *inpW, outT *outW,
                                   const dim_t *dims, size_t numDims,
                                   size_t axis) {

  // Product of outer dimensions excluding the axis dimension.
  dim_t outerLen = 1;
  for (dim_t idx = 0; idx < axis; ++idx) {
    outerLen *= dims[idx];
  }

  // Axis dimension length.
  dim_t axisLen = dims[axis];

  // Product of inner dimensions excluding the axis dimension.
  dim_t innerLen = 1;
  for (dim_t idx = axis + 1; idx < numDims; ++idx) {
    innerLen *= dims[idx];
  }

  // Traverse data such that output is written linearly.
  const inpT *inpPtr = inpW;
  outT *outPtr = outW;
  for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
    for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
      inpT maxVal = std::numeric_limits<inpT>::lowest();
      outT maxIdx = 0;
      for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
        inpT inpVal = *inpPtr;
        if (inpVal > maxVal) {
          maxVal = inpVal;
          maxIdx = axisIdx;
        }
        inpPtr += innerLen;
      }
      inpPtr = inpPtr - axisLen * innerLen + 1;
      *outPtr++ = maxIdx;
    }
    inpPtr = inpPtr - innerLen + axisLen * innerLen;
  }
}

template <typename inpT, typename outT>
static void libjit_arg_min_generic(const inpT *inpW, outT *outW,
                                   const dim_t *dims, size_t numDims,
                                   size_t axis) {

  // Product of outer dimensions excluding the axis dimension.
  dim_t outerLen = 1;
  for (dim_t idx = 0; idx < axis; ++idx) {
    outerLen *= dims[idx];
  }

  // Axis dimension length.
  dim_t axisLen = dims[axis];

  // Product of inner dimensions excluding the axis dimension.
  dim_t innerLen = 1;
  for (dim_t idx = axis + 1; idx < numDims; ++idx) {
    innerLen *= dims[idx];
  }

  // Traverse data such that output is written linearly.
  const inpT *inpPtr = inpW;
  outT *outPtr = outW;
  for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
    for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
      inpT minVal = std::numeric_limits<inpT>::max();
      outT minIdx = 0;
      for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
        inpT inpVal = *inpPtr;
        if (inpVal < minVal) {
          minVal = inpVal;
          minIdx = axisIdx;
        }
        inpPtr += innerLen;
      }
      inpPtr = inpPtr - axisLen * innerLen + 1;
      *outPtr++ = minIdx;
    }
    inpPtr = inpPtr - innerLen + axisLen * innerLen;
  }
}

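/// Generic max pooling over the NHWC input \p inW (shape \p inWdims) into
/// \p outW (shape \p outWdims), with \p kernelSizes, \p strides and \p pads
/// given as (height, width) pairs ((top, left) for the pads). Output
/// positions whose effective pooling window is empty receive \p defVal.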
template <typename T>
static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims,
                                    const dim_t *outWdims, dim_t *kernelSizes,
                                    dim_t *strides, dim_t *pads, T defVal) {

  size_t kernelH = kernelSizes[0];
  size_t kernelW = kernelSizes[1];

  size_t strideH = strides[0];
  size_t strideW = strides[1];

  size_t padT = pads[0];
  size_t padL = pads[1];

  // For each input in the batch.
  for (size_t n = 0; n < inWdims[0]; n++) {

    // For each output height.
    ssize_t i_h_min = -(ssize_t)padT;
    for (size_t o_h = 0; o_h < outWdims[1]; o_h++, i_h_min += strideH) {

      // Effective kernel height limits.
      ssize_t f_h_min = libjit_conv_flt_min(i_h_min);
      ssize_t f_h_max = libjit_conv_flt_max(inWdims[1], kernelH, i_h_min);
      ssize_t f_h_len = libjit_conv_flt_len(f_h_min, f_h_max);
      const T *inpPtrH = inW + (i_h_min + f_h_min) * inWdims[2] * inWdims[3];

      // For each output width.
      ssize_t i_w_min = -(ssize_t)padL;
      for (size_t o_w = 0; o_w < outWdims[2]; o_w++, i_w_min += strideW) {

        // Effective kernel width limits.
        ssize_t f_w_min = libjit_conv_flt_min(i_w_min);
        ssize_t f_w_max = libjit_conv_flt_max(inWdims[2], kernelW, i_w_min);
        ssize_t f_w_len = libjit_conv_flt_len(f_w_min, f_w_max);
        const T *inpPtr = inpPtrH + (i_w_min + f_w_min) * inWdims[3];

        // For each output channel.
        for (size_t o_c = 0; o_c < outWdims[3]; o_c++) {

          // Initialize max.
          T max = std::numeric_limits<T>::lowest();

          // For each kernel height.
          for (size_t f_h = 0; f_h < f_h_len; f_h++) {

            // For each kernel width.
            for (size_t f_w = 0; f_w < f_w_len; f_w++) {

              // Take the maximum along the kernel width.
              max = std::max(max, *inpPtr);
              inpPtr += inWdims[3];
            }

            // Advance input pointer for next kernel height.
            inpPtr = inpPtr - f_w_len * inWdims[3] + inWdims[2] * inWdims[3];
          }

          // Store max. If the effective pooling window size is empty then we
          // return the default value.
          if (f_h_len > 0 && f_w_len > 0) {
            *outW++ = max;
          } else {
            *outW++ = defVal;
          }

          // Advance input pointer for next output channel.
          inpPtr = inpPtr - f_h_len * inWdims[2] * inWdims[3] + 1;
        }
      }
    }

    // Advance input pointer for next batch.
    inW += inWdims[1] * inWdims[2] * inWdims[3];
  }
}

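/// Max pooling that additionally records, for every output element, the flat
/// NHWC index of the input element that produced the maximum into \p argmax.
/// libjit_max_pool_argmax_grad_generic below reuses these indices to route
/// gradients back to the max locations.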
template <typename T, typename T2>
static void
libjit_max_pool_argmax_generic(const T *inW, T *outW, T2 *argmax,
                               const dim_t *inWdims, const dim_t *outWdims,
                               dim_t *kernels, dim_t *strides, dim_t *pads) {
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernels[0];
  dim_t kernel_w = kernels[1];
  // For each input in the batch:
  for (dim_t n = 0; n < outWdims[0]; n++) {

    // For each (x,y) step in the input/output tensor:
    sdim_t x = -(sdim_t)pad_t;
    for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
      sdim_t y = -(sdim_t)pad_l;
      for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {

        // For each channel in the output tensor:
        for (dim_t z = 0; z < outWdims[3]; z++) {
          int64_t argmaxNHWC = 0;
          int first = 1;
          T max = 0;

          for (dim_t kx = 0; kx < kernel_h; kx++) {
            for (dim_t ky = 0; ky < kernel_w; ky++) {
              sdim_t ox = x + kx;
              sdim_t oy = y + ky;

              if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
                  oy >= (sdim_t)inWdims[2]) {
                continue;
              }
              const dim_t flatIndex =
                  libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z);
              T val = inW[flatIndex];
              if (first || (val >= max)) {
                first = 0;
                max = val;
                argmaxNHWC = flatIndex;
              }
            }
          }

          const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
          outW[flatIndex] = max;
          argmax[flatIndex] = argmaxNHWC;
        } // C
      }   // W
    }     // H
  }       // N
}

template <typename T>
void libjit_resizenearest_generic(T *dst, const T *src, const float *scale,
                                  const dim_t *inWdims, const dim_t *outWdims) {

  for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
    auto ib = std::min(dim_t(ob / (scale[0])), inWdims[0] - 1);
    for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
      auto ih = std::min(dim_t(oh / (scale[1])), inWdims[1] - 1);
      for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
        auto iw = std::min(dim_t(ow / (scale[2])), inWdims[2] - 1);
        for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
          auto ic = std::min(dim_t(oc / (scale[3])), inWdims[3] - 1);
          const dim_t inIndex = libjit_getXYZW(inWdims, ib, ih, iw, ic);
          const dim_t outIndex = libjit_getXYZW(outWdims, ob, oh, ow, oc);
          dst[outIndex] = src[inIndex];
        }
      }
    }
  }
}

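/// Bilinear resize of the NHWC input \p src into \p dst using the
/// per-dimension \p scale factors. Each output pixel maps back to fractional
/// input coordinates (ihf, iwf); the result interpolates the four surrounding
/// input pixels, first along the height axis (v00->v10, v01->v11), then along
/// the width axis, with coordinates clamped at the borders.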
template <typename T>
static void
libjit_resizebilinear_generic(T *dst, const T *src, const float *scale,
                              const dim_t *inWdims, const dim_t *outWdims) {
  for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
    for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
      for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
        float ihf = oh / scale[1];
        float iwf = ow / scale[2];
        dim_t ih = dim_t(ihf);
        dim_t iw = dim_t(iwf);

        auto ih0 = std::min(ih, inWdims[1] - 1);
        auto ih1 = std::min(ih + 1, inWdims[1] - 1);
        auto iw0 = std::min(iw, inWdims[2] - 1);
        auto iw1 = std::min(iw + 1, inWdims[2] - 1);

        for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
          float v00 = src[libjit_getXYZW(inWdims, ob, ih0, iw0, oc)];
          float v01 = src[libjit_getXYZW(inWdims, ob, ih0, iw1, oc)];
          float v10 = src[libjit_getXYZW(inWdims, ob, ih1, iw0, oc)];
          float v11 = src[libjit_getXYZW(inWdims, ob, ih1, iw1, oc)];

          float hd = v00 + (v10 - v00) * (ihf - ih);
          float hw = v01 + (v11 - v01) * (ihf - ih);
          float result = hd + (hw - hd) * (iwf - iw);
          dst[libjit_getXYZW(outWdims, ob, oh, ow, oc)] = result;
        }
      }
    }
  }
}

template <typename T>
static void
libjit_batchedadd_quantized(int8_t *dest, const int8_t *batch, const T *slice,
                            dim_t numSlice, dim_t sliceSize, int32_t destOffset,
                            int32_t batchOffset, int32_t sliceOffset,
                            int32_t batchPre, int32_t batchPost,
                            int32_t batchScale, int32_t slicePre,
                            int32_t slicePost, int32_t sliceScale) {
  for (dim_t n = 0; n < numSlice; n++) {
    dim_t base = n * sliceSize;
    for (dim_t i = 0; i < sliceSize; i++) {
      int32_t b = batch[base + i] - batchOffset;
      int32_t s = slice[i] - sliceOffset;
      int32_t x = libjit_scale<int32_t>(b, batchPre, batchPost, batchScale, 0);
      int32_t y = libjit_scale<int32_t>(s, slicePre, slicePost, sliceScale, 0);
      dest[base + i] = libjit_clip_i8(x + y + destOffset);
    }
  }
}

static void find_min_max_f(float *tensor, dim_t size, float &min, float &max) {
  min = tensor[0];
  max = tensor[0];

  for (dim_t i = 1; i < size; ++i) {
    float tensorVal = tensor[i];
    if (tensorVal < min)
      min = tensorVal;

    if (tensorVal > max)
      max = tensorVal;

    // Sanity check for NaN and Infinity.
    assert(!std::isnan(tensor[i]) && "NaN value found!");
    assert(!std::isinf(tensor[i]) && "Infinity value found!");
  }
}

static int check_all_zeros(float *arrayToCheck, dim_t size) {
  for (dim_t i = 0; i < size; ++i) {
    if (arrayToCheck[i] != 0) {
      return 0;
    }
  }
  return 1;
}

/// Get the bin number for inserting \p value into a histogram that has
/// \p nBins bins of width \p binWidth, starting at \p minValue.
static dim_t get_bin(dim_t nBins, float binWidth, float minValue, float value) {
  dim_t result =
      binWidth == 0
          ? 0
          : MIN(static_cast<dim_t>((value - minValue) / binWidth), nBins - 1);
  return result;
}

template <typename T>
static void libjit_space_to_depth_generic(const T *inPtr, T *outPtr,
                                          dim_t blockSize, const dim_t *inDims,
                                          const dim_t *outDims) {
  dim_t inHeight = inDims[1];
  dim_t inWidth = inDims[2];
  dim_t inDepth = inDims[3];

  dim_t outBatch = outDims[0];
  dim_t outHeight = outDims[1];
  dim_t outWidth = outDims[2];
  dim_t outDepth = outDims[3];

  for (dim_t b = 0; b < outBatch; ++b) {
    for (dim_t h = 0; h < outHeight; ++h) {
      for (dim_t w = 0; w < outWidth; ++w) {
        for (dim_t c = 0; c < outDepth; ++c) {
          // NHWC
          // c +
          // w * outDepth +
          // h * outDepth * outWidth +
          // b * outDepth * outWidth * outHeight
          dim_t outIndex = c + outDepth * (w + outWidth * (h + b * outHeight));

          // The block layer this output channel belongs to.
          dim_t blockDepthLayer = c / inDepth;
          // The width offset resets to 0 at every multiple of the block size.
          dim_t iw = w * blockSize + blockDepthLayer % blockSize;
          // The height traversal advances by 1 at every multiple of the block
          // size.
          dim_t ih = h * blockSize + blockDepthLayer / blockSize;
          // The index into the input depth resets to 0 at every multiple of
          // inDepth.
          dim_t id = c % inDepth;

          dim_t inIndex = id + inDepth * (iw + inWidth * (ih + b * inHeight));
          outPtr[outIndex] = inPtr[inIndex];
        }
      }
    }
  }
}

template <typename DstType, typename SrcType>
static void
libjit_copy_kernel_with_conversion(DstType *dstPtr, const SrcType *srcPtr,
                                   const dim_t *dims, dim_t numDims) {
  dim_t dimSize = 1;
  for (dim_t i = 0; i < numDims; ++i) {
    dimSize *= dims[i];
  }

  for (dim_t i = 0; i < dimSize; ++i) {
    dstPtr[i] = DstType(srcPtr[i]);
  }
}

/// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
/// we can iterate over the shape here, regardless of the shape of the tensor.
#define DEFINE_REDUCE_MINMAX_KERNEL(minmax)                                    \
  template <typename T>                                                        \
  static void libjit_reduce##minmax(T *dest, const T *batch, size_t destSize,  \
                                    const dim_t *destDims,                     \
                                    const dim_t *batchDims, T init) {          \
    for (dim_t i = 0; i < destSize; i++) {                                     \
      dest[i] = init;                                                          \
    }                                                                          \
                                                                               \
    unsigned int axis[6];                                                      \
    for (dim_t i = 0; i < 6; i++) {                                            \
      axis[i] = (destDims[i] > 1);                                             \
    }                                                                          \
                                                                               \
    for (dim_t x = 0, dx = 0; x < batchDims[0]; x++, dx += axis[0]) {          \
      for (dim_t y = 0, dy = 0; y < batchDims[1]; y++, dy += axis[1]) {        \
        for (dim_t z = 0, dz = 0; z < batchDims[2]; z++, dz += axis[2]) {      \
          for (dim_t w = 0, dw = 0; w < batchDims[3]; w++, dw += axis[3]) {    \
            for (dim_t q = 0, dq = 0; q < batchDims[4]; q++, dq += axis[4]) {  \
              for (dim_t r = 0, dr = 0; r < batchDims[5];                      \
                   r++, dr += axis[5]) {                                       \
                T fdest =                                                      \
                    dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)];  \
                T fnew = batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)]; \
                dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)] =     \
                    std::minmax(fdest, fnew);                                  \
              }                                                                \
            }                                                                  \
          }                                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
  }

// Define libjit_reducemax
DEFINE_REDUCE_MINMAX_KERNEL(max)

// Define libjit_reducemin
DEFINE_REDUCE_MINMAX_KERNEL(min)

#undef DEFINE_REDUCE_MINMAX_KERNEL

template <typename T, typename T2>
static void libjit_cross_entropy_loss_generic(T *CE, T *P, T2 *labels,
                                              dim_t *dims) {
  CE[0] = 0.0;
  for (dim_t n = 0; n < dims[0]; ++n) {
    auto y = labels[n];
    auto p_n = P[libjit_getXY(dims, n, y)];
    CE[0] -= log(p_n);
  }
}

template <typename T, typename T2>
static void libjit_sparse_lengths_sum_generic(T *dest, T *data, T2 *indices,
                                              int32_t *lengths, dim_t segments,
                                              dim_t lineSize) {
  memset(dest, 0, segments * lineSize * sizeof(float));
  dim_t curIndex = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0; j < lengths[i]; j++) {
      dim_t line = indices[curIndex];
      for (dim_t k = 0; k < lineSize; k++) {
        dest[i * lineSize + k] += data[line * lineSize + k];
      }
      curIndex++;
    }
  }
}

template <typename T, typename T2>
static void
libjit_sparse_lengths_weighted_sum_generic(T *dest, T *data, float *weights,
                                           T2 *indices, int32_t *lengths,
                                           dim_t segments, dim_t lineSize) {
  memset(dest, 0, segments * lineSize * sizeof(float));
  dim_t curIndex = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0; j < lengths[i]; j++) {
      float weight = weights[curIndex];
      dim_t line = indices[curIndex];
      for (dim_t k = 0; k < lineSize; k++) {
        dest[i * lineSize + k] += weight * data[line * lineSize + k];
      }
      curIndex++;
    }
  }
}

template <typename T, typename T2>
static void libjit_sparse_lengths_weighted_sum_grad_generic(
    const T *destGrad, T *dataGrad, T *weightsGrad, const T *data,
    const T *weights, const T2 *indices, const int32_t *lengths, dim_t segments,
    dim_t lineSize, dim_t dataGradRawSize) {
  // The data gradients not touched by this operation should
  // be 0, so set the entire buffer to 0 to start with.
  memset(dataGrad, 0, dataGradRawSize);

  for (dim_t i = 0, curIndex = 0; i < segments; ++i) {
    for (int32_t j = 0; j < lengths[i]; ++j, ++curIndex) {
      // For each index in each segment:
      // 1) accumulate into the corresponding data gradient the product of
      // the gradient of the result it was added to and the weight that it
      // was multiplied by during the SparseLengthsWeightedSum operation.
      //
      // 2) accumulate into each weight gradient the reduced sum of the
      // elementwise product of the result slice that the corresponding
      // weight produced and the input slice that the weight was multiplied
      // with.
      float weightGrad = 0.0f;
      float weight = weights[curIndex];
      dim_t line = indices[curIndex];
      for (dim_t k = 0; k < lineSize; ++k) {
        dataGrad[line * lineSize + k] += weight * destGrad[i * lineSize + k];
        weightGrad += destGrad[i * lineSize + k] * data[line * lineSize + k];
      }
      weightsGrad[curIndex] = weightGrad;
    }
  }
}

template <typename T, typename T2>
static void libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
    T *dest, uint8_t *data, T *scales, T *offsets, T *weights, T2 *indices,
    int32_t *lengths, dim_t segments, dim_t lineSize) {
  memset(dest, 0, segments * lineSize * sizeof(float));
  dim_t curIndex = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0; j < lengths[i]; j++) {
      const float weight = weights[curIndex];
      const dim_t line = indices[curIndex];
      const float scale = scales[line];
      const float offset = offsets[line];
      for (dim_t k = 0; k < lineSize; k++) {
        const float fData = scale * data[line * lineSize + k] + offset;
        dest[i * lineSize + k] += weight * fData;
      }
      curIndex++;
    }
  }
}

template <typename T, typename T2>
static void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
    T *dest, int8_t *data, T *weights, T2 *indices, int32_t *lengths,
    dim_t segments, dim_t inLineSize, dim_t outLineSize) {
  memset(dest, 0, segments * outLineSize * sizeof(float));
  dim_t curIndex = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0, e = lengths[i]; j < e; j++) {
      const float weight = weights[curIndex];
      const dim_t line = indices[curIndex];
      const int8_t *currRowScaleOffsetPtr =
          data + ((line + 1) * inLineSize) - 2 * sizeof(float);
      float scale, offset;
      memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
      memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
      for (dim_t k = 0; k < outLineSize; k++) {
        const float fData =
            (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
        dest[i * outLineSize + k] += weight * fData;
      }
      curIndex++;
    }
  }
}

template <typename T, typename T2>
static void libjit_sparse_to_dense_generic(T *dest, const T2 *indices,
                                           const T *values, dim_t numIndices,
                                           dim_t destSize, dim_t valueSize) {
  memset(dest, 0, destSize * sizeof(float));

  for (dim_t i = 0, valuesOffset = 0; i < numIndices;
       ++i, valuesOffset += valueSize) {
    dim_t idx = indices[i];
    dim_t destOffset = idx * valueSize;

    for (size_t j = 0; j < valueSize; ++j) {
      dest[destOffset + j] += values[valuesOffset + j];
    }
  }
}

struct ClassBox {
  float score{0.0f};
  size_t index{0};
};

struct Box {
  float v0{0.0f};
  float v1{0.0f};
  float v2{0.0f};
  float v3{0.0f};
};

struct OutBox {
  float classValue{0.0f};
  size_t batchIndex{0};
  size_t classIndex{0};
  size_t boxIndex{0};
};

static void maxMin(float lhs, float rhs, float &min, float &max) {
  if (lhs >= rhs) {
    min = rhs;
    max = lhs;
  } else {
    min = lhs;
    max = rhs;
  }
}

static bool checkIOU(const Box &sb, const Box &cb, float iouThreshold,
                     size_t centerPointBox) {
  float xSMin = 0.0f;
  float ySMin = 0.0f;
  float xSMax = 0.0f;
  float ySMax = 0.0f;

  float xCMin = 0.0f;
  float yCMin = 0.0f;
  float xCMax = 0.0f;
  float yCMax = 0.0f;

  // Standardize coordinates so that (xmin, ymin) is the upper left corner of
  // a box and (xmax, ymax) is the lower right corner of the box.
  if (!centerPointBox) {
    // 0 means coordinates for diagonal ends of a box.
    // Coordinates can either be absolute or normalized.
    maxMin(sb.v0, sb.v2, xSMin, xSMax);
    maxMin(sb.v1, sb.v3, ySMin, ySMax);

    maxMin(cb.v0, cb.v2, xCMin, xCMax);
    maxMin(cb.v1, cb.v3, yCMin, yCMax);
  } else {
    float halfWidthS = sb.v2 / 2.0f;
    float halfHeightS = sb.v3 / 2.0f;
    float halfWidthC = cb.v2 / 2.0f;
    float halfHeightC = cb.v3 / 2.0f;

    xSMin = sb.v0 - halfWidthS;
    ySMin = sb.v1 - halfHeightS;
    xSMax = sb.v0 + halfWidthS;
    ySMax = sb.v1 + halfHeightS;

    xCMin = cb.v0 - halfWidthC;
    yCMin = cb.v1 - halfHeightC;
    xCMax = cb.v0 + halfWidthC;
    yCMax = cb.v1 + halfHeightC;
  }

  // Find the upper left and lower right corners of the box formed by the
  // intersection.
  float xMin = MAX(xSMin, xCMin);
  float yMin = MAX(ySMin, yCMin);
  float xMax = MIN(xSMax, xCMax);
  float yMax = MIN(ySMax, yCMax);

  float intersectionArea = MAX((0.0f), xMax - xMin) * MAX((0.0f), yMax - yMin);

  if (intersectionArea == 0.0f) {
    return false;
  }

  float sArea = (xSMax - xSMin) * (ySMax - ySMin);
  float cArea = (xCMax - xCMin) * (yCMax - yCMin);
  float unionArea = sArea + cArea - intersectionArea;

  return intersectionArea > iouThreshold * unionArea;
}

// ONNX
// Class/Score [BatchNum][ClassNum][BoxNum]
// Box [BatchNum][BoxNum][4]
// Result [BatchNum*MaxOutputPerBatch][3]
// V4
// Class/Score [BatchNum][BoxNum]
// Boxes [BatchNum][BoxNum][4]
// Result [BatchNum*MaxOutputPerBatch]
// NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
template <typename T>
static void
libjit_nms_generic(T *indices, T *numDetected, const float *boxTensor,
                   const dim_t *boxTensorDims, dim_t boxTensorDimSize,
                   const float *scoresTensor, const dim_t *scoresTensorDims,
                   dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
                   dim_t resultTensorDimSize, unsigned centerPointBox,
                   unsigned maxOutputBoxesPerClass, float iouThreshold,
                   float scoreThreshold, bool isV4) {
  int boxesBoxDim = boxTensorDimSize - 2;

  size_t numBatches = 1;
  size_t numClasses = 1;
  size_t numBoxes = boxTensorDims[boxesBoxDim];

  size_t maxOutputPerBatch = 0;
  if (!isV4) {
    int boxesBatchDim = boxTensorDimSize - 3;
    int scoresBatchDim = scoresTensorDimSize - 3;

    int scoresBoxDim = scoresTensorDimSize - 1;
    int scoresClassDim = scoresTensorDimSize - 2;

    assert(scoresTensorDims[scoresBoxDim] == boxTensorDims[boxesBoxDim] &&
           "Mismatch between number of scores and number of boxes.");
    assert(scoresTensorDims[scoresBatchDim] == boxTensorDims[boxesBatchDim] &&
           "Scores and Box Batch Dimensions don't match.");
    (void)boxesBatchDim;
    (void)scoresBoxDim;
    numBatches = scoresTensorDims[scoresBatchDim];
    numClasses = scoresTensorDims[scoresClassDim];
    numBoxes = boxTensorDims[boxesBoxDim];
    maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 2] / numBatches;
  } else {
    maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 1] / numBatches;
  }

  static_assert(sizeof(Box) == 4 * sizeof(float),
                "Can't reinterpret raw float data as a Box.");
  const Box *boxes = reinterpret_cast<const Box *>(boxTensor);

  auto cmpFunc = [](const ClassBox &cb1, const ClassBox &cb2) -> bool {
    return cb1.score > cb2.score;
  };

  size_t outPutBoxIndex = 0;
  for (size_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
    int32_t detectedPerBatch = 0;
    OutBox minBox{scoresTensor[batchIndex * numClasses], batchIndex, 0, 0};
    for (size_t classIndex = 0; classIndex < numClasses; ++classIndex) {
      ClassBox selectedIndices[numBoxes];
      ClassBox potentialBoxes[numBoxes];
      size_t indexPBoxes = 0;
      const float *currClass =
          &scoresTensor[(batchIndex * numClasses + classIndex) * numBoxes];
      for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
        float classScore = currClass[boxIndex];
        if (classScore > scoreThreshold) {
          ClassBox &b = potentialBoxes[indexPBoxes++];
          b.score = classScore;
          b.index = boxIndex;
        }
      }

      std::sort(potentialBoxes, potentialBoxes + indexPBoxes, cmpFunc);

      size_t indexSBoxes = 0;
      size_t detectedPerClass = 0;
      float tScore = minBox.classValue;
      for (unsigned int i = 0; i < indexPBoxes; ++i) {
        ClassBox &pbI = potentialBoxes[i];
        const Box &potentialBox = boxes[batchIndex * numBoxes + pbI.index];
        bool selected = true;
        for (unsigned int j = 0; j < indexSBoxes && selected; ++j) {
          ClassBox &sbI = selectedIndices[j];
          const Box &selectedBox = boxes[batchIndex * numBoxes + sbI.index];
          selected = !checkIOU(selectedBox, potentialBox, iouThreshold,
                               centerPointBox);
        }

        if (selected) {
          selectedIndices[indexSBoxes++] = pbI;
          if (isV4) {
            indices[outPutBoxIndex] = pbI.index;
          } else {
            indices[outPutBoxIndex * 3 + 0] = batchIndex;
            indices[outPutBoxIndex * 3 + 1] = classIndex;
            indices[outPutBoxIndex * 3 + 2] = pbI.index;
          }

          tScore = pbI.score;
          ++outPutBoxIndex;
          ++detectedPerClass;
          ++detectedPerBatch;
        }

        if (detectedPerClass == maxOutputBoxesPerClass) {
          break;
        }
      }

      if (tScore < minBox.classValue) {
        minBox.classValue = tScore;
        if (isV4) {
          minBox.boxIndex = indices[outPutBoxIndex - 1];
        } else {
          minBox.boxIndex = indices[(outPutBoxIndex - 1) * 3 + 2];
        }
        minBox.classIndex = classIndex;
      }
    }

    // Fill the rest of the output for this batch with the minimum-scoring box.
    for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
      if (isV4) {
        indices[outPutBoxIndex] = minBox.boxIndex;
      } else {
        indices[outPutBoxIndex * 3 + 0] = minBox.batchIndex;
        indices[outPutBoxIndex * 3 + 1] = minBox.classIndex;
        indices[outPutBoxIndex * 3 + 2] = minBox.boxIndex;
      }

      ++outPutBoxIndex;
    }
    // For ONNX NMS this output is unused; for TF the batch dimension is 1.
    for (size_t i = 0; i < maxOutputBoxesPerClass; ++i) {
      numDetected[batchIndex * maxOutputBoxesPerClass + i] = detectedPerBatch;
    }
  }
}

template <typename T, typename T2>
void libjit_softmax_grad_generic(T *inG, T *outW, const T2 *selectedW,
                                 const dim_t *idim, const dim_t *selectdim) {
  for (dim_t n = 0; n < idim[0]; n++) {
    for (dim_t i = 0; i < idim[1]; i++) {
      float delta = (selectedW[libjit_getXY(selectdim, n, 0)] == T2(i));
      inG[libjit_getXY(idim, n, i)] = outW[libjit_getXY(idim, n, i)] - delta;
    }
  }
}

template <typename T, typename T2>
void libjit_max_pool_argmax_grad_generic(T *inG, const T *outG,
                                         const T2 *argmax, const dim_t *inGdims,
                                         const dim_t *outWdims) {
  // NHWC format is assumed.
  for (dim_t n = 0; n < outWdims[0]; n++) {
    for (dim_t z = 0; z < outWdims[3]; z++) {
      // Clear inG.
      for (dim_t x = 0; x < inGdims[1]; x++) {
        for (dim_t y = 0; y < inGdims[2]; y++) {
          inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
        }
      }

      for (dim_t ax = 0; ax < outWdims[1]; ax++) {
        for (dim_t ay = 0; ay < outWdims[2]; ay++) {
          // Reuse the precomputed linear index of the max element from argmax.
          const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
          float df = outG[flatIndex];
          inG[argmax[flatIndex]] += df;
        } // W
      }   // H
    }     // C
  }       // N
}
} // namespace

extern "C" {

/// Macro to define a mini-kernel for data-parallel operations. The body of the
/// kernel is auto-generated by the macro.
/// \p name the name of the kernel
/// \p type the type of the tensor elements and of the return value
/// \p body the operation to be performed
#define DEFINE_DATA_PARALLEL_KERNEL(name, type, body)                          \
  type name(dim_t idx, const type *LHS, const type *RHS, const type *op3) {    \
    return body;                                                               \
  }

/// Macro to define a mini-kernel for data-parallel operations. The body of the
/// kernel is not auto-generated by the macro.
/// \p name the name of the kernel
#define DEFINE_DATA_PARALLEL_KERNEL_FUNC(name)                                 \
  float name(dim_t idx, const float *LHS, const float *RHS, const float *op3)

/// Macro to define a mini-kernel for data-parallel operations with immediate
/// operands.
/// \p name the name of the kernel
/// \p type the type of the tensor elements and of the return value
/// \p body the operation to be performed
#define DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(name, type, body)         \
  type name(dim_t idx, type val, const type *LHS, const type *RHS) {           \
    return body;                                                               \
  }

/// Macro to define a mini-kernel for data-parallel arithmetic quantized
/// operations. The body of the kernel is auto-generated by the macro.
/// \p name the name of the kernel
/// \p type the type of the tensor elements
/// \p body the operation to be performed
#define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(name, type, body)                \
  type name(dim_t idx, const type *LHS, const type *RHS, int32_t destOffset,   \
            int32_t lhsOffset, int32_t rhsOffset, int32_t lhsPre,              \
            int32_t lhsPost, int32_t lhsScale, int32_t rhsPre,                 \
            int32_t rhsPost, int32_t rhsScale) {                               \
    int32_t lhs = libjit_scale<int32_t>(LHS[idx] - lhsOffset, lhsPre, lhsPost, \
                                        lhsScale, 0);                          \
    int32_t rhs = libjit_scale<int32_t>(RHS[idx] - rhsOffset, rhsPre, rhsPost, \
                                        rhsScale, 0);                          \
    return libjit_clip_i8((body) + destOffset);                                \
  }

/// Macro to define a mini-kernel for data-parallel multiplicative quantized
/// operations. The body of the kernel is auto-generated by the macro.
/// \p name the name of the kernel
/// \p body the operation to be performed
#define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(name, body)                    \
  int8_t name(dim_t idx, const int8_t *LHS, const int8_t *RHS,                 \
              int32_t destOffset, int32_t lhsOffset, int32_t rhsOffset,        \
              int32_t pre, int32_t post, int32_t scale) {                      \
    int32_t lhs = LHS[idx] - lhsOffset;                                        \
    int32_t rhs = RHS[idx] - rhsOffset;                                        \
    return libjit_clip_i8(                                                     \
        libjit_scale<int32_t>((body), pre, post, scale, destOffset));          \
  }

/// Define mini-kernels for all data parallel operations. They are invoked from
/// the generated kernels for sequences of data parallel operations.
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_max_kernel_f, float,
                            MAX(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_min_kernel_f, float,
                            MIN(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_f, float, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_u, int64_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i8, int8_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i16, int16_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i32, int32_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_b, int8_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_f, float,
                            LHS[idx] + RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_i32, int32_t,
                            LHS[idx] + RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sub_kernel_f, float,
                            LHS[idx] - RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_f, float,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_u, int64_t,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_i32, int32_t,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_f, float,
                            LHS[idx] * RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_i32, int32_t,
                            LHS[idx] * RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_pow_kernel_f, float,
                            pow(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_log_kernel_f, float, log(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_exp_kernel_f, float, exp(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_abs_kernel_f, float,
                            std::abs(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_neg_kernel_f, float, -LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_floor_kernel_f, float,
                            std::floor(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_ceil_kernel_f, float,
                            std::ceil(LHS[idx]))
// The rounding mode required by ONNX, NumPy, and TensorFlow is round-to-even,
// which rounds values with a fractional part of 0.5 to the nearest even
// integer.
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_round_kernel_f, float,
                            std::nearbyintf(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sqrt_kernel_f, float,
                            std::sqrt(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_erf_kernel_f, float,
                            std::erf(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_rsqrt_kernel_f, float,
                            1 / std::sqrt(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_reciprocal_kernel_f, float,
                            1 / LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sin_kernel_f, float,
                            std::sin(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_cos_kernel_f, float,
                            std::cos(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_add_kernel_i8, int8_t,
                                      lhs + rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_sub_kernel_i8, int8_t,
                                      lhs - rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_max_kernel_i8, int8_t,
                                      MAX(lhs, rhs))
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_min_kernel_i8, int8_t,
                                      MIN(lhs, rhs))
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_mul_kernel_i8, lhs * rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_div_kernel_i8, lhs / rhs)

DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_u, size_t,
                            LHS[idx] + RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_u, size_t,
                            LHS[idx] * RHS[idx])

/// These variables are used by Glow backends to determine the actual types
/// used for the size_t, dim_t and int variables when libjit was compiled.
1693size_t libjit_sizeTVar;
1694dim_t libjit_dimTVar;
1695int libjit_intVar;
1696
1697/// Specialize the Modulo kernel into two functions based on the
1698/// value of SignFollowDivisor.
1699int64_t libjit_element_modulo_kernel_sign_follow_u(dim_t idx,
1700 const int64_t divisor,
1701 const int64_t *input) {
1702 int64_t res = input[idx] % divisor;
1703 if (res && ((res > 0) != (divisor > 0))) {
1704 res += divisor;
1705 }
1706 return res;
1707}
1708
1709int64_t libjit_element_modulo_kernel_no_sign_follow_u(dim_t idx,
1710 const int64_t divisor,
1711 const int64_t *input) {
1712 return input[idx] % divisor;
1713}
1714
1715int32_t libjit_element_modulo_kernel_sign_follow_i32(dim_t idx,
1716 const int64_t divisor,
1717 const int32_t *input) {
1718 int32_t res = input[idx] % divisor;
1719 if (res && ((res > 0) != (divisor > 0))) {
1720 res += divisor;
1721 }
1722 return res;
1723}
1724
1725int32_t libjit_element_modulo_kernel_no_sign_follow_i32(dim_t idx,
1726 const int64_t divisor,
1727 const int32_t *input) {
1728 return input[idx] % divisor;
1729}
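
// A minimal usage sketch (illustrative only): the "sign follow" variants
// match Python-style modulo, where a non-zero result takes the sign of the
// divisor, while the "no sign follow" variants keep C/C++ truncated-division
// semantics.
static void libjit_example_modulo_variants() {
  const int64_t input[] = {-3};
  int64_t a = libjit_element_modulo_kernel_no_sign_follow_u(0, 2, input);
  int64_t b = libjit_element_modulo_kernel_sign_follow_u(0, 2, input);
  // a == -1 (C/C++ semantics); b == 1 (-1 + 2, since -1 and 2 differ in
  // sign).
  (void)a;
  (void)b;
}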
1730
1731//===----------------------------------------------------------------------===//
1732// Logical operations
1733//===----------------------------------------------------------------------===//
1734int8_t libjit_element_not_kernel_b(dim_t idx, const bool *input) {
1735 return !input[idx];
1736}
1737
1738int8_t libjit_element_and_kernel_b(dim_t idx, const bool *LHS,
1739 const bool *RHS) {
1740 return LHS[idx] && RHS[idx];
1741}
1742
1743int8_t libjit_element_or_kernel_b(dim_t idx, const bool *LHS, const bool *RHS) {
1744 return LHS[idx] || RHS[idx];
1745}
1746
1747int8_t libjit_element_xor_kernel_b(dim_t idx, const bool *LHS,
1748 const bool *RHS) {
1749 return LHS[idx] ^ RHS[idx];
1750}
1751
1752//===----------------------------------------------------------------------===//
1753// Compare operations
1754//===----------------------------------------------------------------------===//
1755#define DEFINE_CMP_KERNEL_QUANTIZED(name, type, cmp) \
1756 int8_t name(dim_t idx, const type *LHS, const type *RHS, int32_t lhsOffset, \
1757 int32_t rhsOffset, int32_t pre, int32_t post, int32_t scale) { \
1758 int32_t lhs = LHS[idx] - lhsOffset; \
1759 int32_t rhs = RHS[idx] - rhsOffset; \
1760 return (libjit_scale<int32_t>(lhs, pre, post, scale, 0) cmp rhs) ? 1 : 0; \
1761 }
1762DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_eq_kernel_i8, int8_t, ==)
1763DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_neq_kernel_i8, int8_t, !=)
1764DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lt_kernel_i8, int8_t, <)
1765DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lte_kernel_i8, int8_t, <=)
1766#undef DEFINE_CMP_KERNEL_QUANTIZED
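
// Note on the quantized comparisons above: LHS and RHS may carry different
// quantization parameters, so the offset-corrected LHS is first rescaled
// through libjit_scale (pre/post/scale, with a zero offset) into the RHS
// representation; the comparison itself is then an ordinary int32_t compare.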
1767
1768#define DEFINE_CMP_KERNEL_NON_QUANTIZED(name, type, cmp) \
1769 int8_t name(dim_t idx, const type *LHS, const type *RHS) { \
1770 return (LHS[idx] cmp RHS[idx]) ? 1 : 0; \
1771 }
1772
1773DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_f, float, ==)
1774DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_i32, int32_t, ==)
1775DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_u, size_t, ==)
1776
1777DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_f, float, !=)
1778DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_i32, int32_t, !=)
1779DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_u, size_t, !=)
1780
1781DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_f, float, <)
1782DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_i32, int32_t, <)
1783DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_u, size_t, <)
1784
1785DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_f, float, <=)
1786DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_i32, int32_t, <=)
1787DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_u, size_t, <=)
1788#undef DEFINE_CMP_KERNEL_NON_QUANTIZED
1789
1790int8_t libjit_element_is_nan_kernel_f(dim_t idx, const float *input) {
1791 return std::isnan(input[idx]) ? 1 : 0;
1792}
1793
// Tanh cannot be vectorized by LLVM yet. Therefore we use the following
// formula instead: 1 - 2 / (exp(x * 2) + 1), which is also used by Caffe2
// and provides good accuracy.
// Once LLVM supports the vectorization of tanh, we can replace this
// computation with a direct tanh call.
// When the LIBJIT compile option "-ffast-math" is enabled, the intermediate
// computation expf(x) for the Tanh operator is not handled properly for very
// large positive inputs, which results in NaN values in the Tanh output.
// Therefore, when "-ffast-math" is enabled, we compute Tanh in a form that
// avoids large arguments to the "expf" function.
1804#ifdef FFAST_MATH
1805DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
1806 float inpVal = LHS[idx];
1807 float tanhVal = -1 + 2 / (expf(-2 * std::abs(inpVal)) + 1);
1808 return std::copysignf(tanhVal, inpVal);
1809}
1810#else
1811DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
1812 return 1 - 2 / (expf(LHS[idx] * 2) + 1);
1813}
1814#endif // FFAST_MATH
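
// Derivation of the fast-math form above (a worked identity, for reference):
// for y = |x|, -1 + 2 / (expf(-2 * y) + 1) = (e^(2y) - 1) / (e^(2y) + 1) =
// tanh(y), and copysign restores the sign, so the result equals tanh(x).
// Because the argument of expf is always <= 0 here, the intermediate value
// never overflows, which is what makes this form safe under "-ffast-math".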
1815
1816int8_t libjit_intlookuptable_kernel_i8(dim_t idx, const int8_t *src,
1817 const int8_t *mapping) {
1818 return mapping[src[idx] + 128];
1819}
1820
1821int16_t libjit_intlookuptable_kernel_i16(dim_t idx, const int16_t *src,
1822 const int16_t *mapping) {
1823 return mapping[src[idx] + 32768];
1824}
1825
1826float libjit_elementselect_kernel_f(dim_t idx, const int8_t *cond,
1827 const float *LHS, const float *RHS) {
1828 return (cond[idx] != 0) ? LHS[idx] : RHS[idx];
1829}
1830
1831int8_t libjit_elementselect_kernel_i8(dim_t idx, const int8_t *cond,
1832 const int8_t *LHS, const int8_t *RHS,
1833 int32_t destOffset, int32_t lhsOffset,
1834 int32_t rhsOffset, int32_t lhsPre,
1835 int32_t lhsPost, int32_t lhsScale,
1836 int32_t rhsPre, int32_t rhsPost,
1837 int32_t rhsScale) {
1838 return (cond[idx] != 0)
1839 ? libjit_clip_i8(libjit_scale<int32_t>(
1840 LHS[idx] - lhsOffset, lhsPre, lhsPost, lhsScale, destOffset))
1841 : libjit_clip_i8(libjit_scale<int32_t>(RHS[idx] - rhsOffset,
1842 rhsPre, rhsPost, rhsScale,
1843 destOffset));
1844}
1845
1846float libjit_element_relu_f(dim_t idx, const float *src) {
1847 float srcVal = src[idx];
1848 return MAX(srcVal, 0);
1849}
1850
1851int8_t libjit_element_relu_i8(dim_t idx, const int8_t *src, int8_t srcOffset,
1852 int8_t destOffset, int32_t destPre,
1853 int32_t destPost, int32_t destScale) {
1854 int32_t reluVal = MAX(src[idx], srcOffset);
1855 int32_t scaledVal = libjit_scale<int32_t>(reluVal - srcOffset, destPre,
1856 destPost, destScale, destOffset);
1857 return libjit_clip_i8(scaledVal);
1858}
1859
1860float libjit_element_clip_f(dim_t idx, const float *src, float min, float max) {
1861 float srcVal = src[idx];
1862 return MIN(MAX(srcVal, min), max);
1863}
1864
1865int8_t libjit_element_clip_i8(dim_t idx, const int8_t *src, int8_t clipMin,
1866 int8_t clipMax, int8_t srcOffset,
1867 int8_t destOffset, int32_t destPre,
1868 int32_t destPost, int32_t destScale) {
1869 int32_t clipVal = MIN(MAX(src[idx], clipMin), clipMax);
1870 int32_t scaledVal = libjit_scale<int32_t>(clipVal - srcOffset, destPre,
1871 destPost, destScale, destOffset);
1872 return libjit_clip_i8(scaledVal);
1873}
1874
1875float libjit_element_leaky_relu_f(dim_t idx, const float *src, float alpha) {
1876 float srcVal = src[idx];
1877 return (srcVal >= 0) ? srcVal : alpha * srcVal;
1878}
1879
1880int8_t libjit_element_leaky_relu_i8(dim_t idx, const int8_t *src,
1881 int8_t srcOffset, int8_t destOffset,
1882 int32_t posPre, int32_t posPost,
1883 int32_t posScale, int32_t negPre,
1884 int32_t negPost, int32_t negScale) {
1885 int32_t srcVal = src[idx];
1886 int32_t scaledVal =
1887 (srcVal >= srcOffset)
1888 ? libjit_scale<int32_t>(srcVal - srcOffset, posPre, posPost, posScale,
1889 destOffset)
1890 : libjit_scale<int32_t>(srcVal - srcOffset, negPre, negPost, negScale,
1891 destOffset);
1892 return libjit_clip_i8(scaledVal);
1893}
1894
// When the LIBJIT compile option "-ffast-math" is enabled, the intermediate
// computation expf(x) for the Sigmoid operator is not handled properly for
// very large positive inputs, which results in NaN values in the Sigmoid
// output. Therefore, when "-ffast-math" is enabled, we compute the Sigmoid
// in a form that avoids large arguments to the "expf" function.
1900#ifdef FFAST_MATH
1901DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
1902 float inpVal = LHS[idx];
1903 float sigmoidVal = 1 / (1 + expf(-std::abs(inpVal)));
1904 return (float)(std::signbit(inpVal)) + std::copysignf(sigmoidVal, inpVal);
1905}
1906#else
1907DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
1908 float e = expf(-LHS[idx]);
1909 return 1 / (e + 1);
1910}
1911#endif // FFAST_MATH
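
// Derivation of the fast-math form above (a worked identity, for reference):
// let s = 1 / (1 + expf(-|x|)) = sigmoid(|x|). For x >= 0 the result is
// 0 + copysign(s, x) = s = sigmoid(x); for x < 0 it is 1 - s =
// 1 - sigmoid(|x|) = sigmoid(x). As with Tanh, the argument of expf is
// always <= 0, so no intermediate value overflows.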
1912
1913DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_f, float, val)
1914DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_u, int64_t,
1915 val)
1916DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i8, int8_t,
1917 val)
1918DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i32, int32_t,
1919 val)
1920DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_b, int8_t, val)
1921
1922#undef DEFINE_DATA_PARALLEL_KERNEL
#undef DEFINE_DATA_PARALLEL_KERNEL_FUNC
1925#undef DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND
1926
1927void libjit_batchedadd_f(float *dest, const float *batch, const float *slice,
1928 dim_t numSlice, dim_t sliceSize) {
1929 // For each layer in the batch:
1930 for (dim_t n = 0; n < numSlice; n++) {
1931 dim_t base = n * sliceSize;
1932 // For each element in the slice.
1933 for (dim_t i = 0; i < sliceSize; i++) {
1934 dest[base + i] = batch[base + i] + slice[i];
1935 }
1936 }
1937}
1938
1939void libjit_batchedadd_i8(int8_t *dest, const int8_t *batch,
1940 const int8_t *slice, dim_t numSlice, dim_t sliceSize,
1941 int32_t destOffset, int32_t batchOffset,
1942 int32_t sliceOffset, int32_t batchPre,
1943 int32_t batchPost, int32_t batchScale,
1944 int32_t slicePre, int32_t slicePost,
1945 int32_t sliceScale) {
1946 libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1947 destOffset, batchOffset, sliceOffset, batchPre,
1948 batchPost, batchScale, slicePre, slicePost,
1949 sliceScale);
1950}
1951
1952void libjit_batchedadd_i32_i8(int8_t *dest, const int8_t *batch,
1953 const int32_t *slice, dim_t numSlice,
1954 dim_t sliceSize, int32_t destOffset,
1955 int32_t batchOffset, int32_t sliceOffset,
1956 int32_t batchPre, int32_t batchPost,
1957 int32_t batchScale, int32_t slicePre,
1958 int32_t slicePost, int32_t sliceScale) {
1959 libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1960 destOffset, batchOffset, sliceOffset, batchPre,
1961 batchPost, batchScale, slicePre, slicePost,
1962 sliceScale);
1963}
1964
/// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
/// we can iterate over the shape here, regardless of the shape of the tensor.
1969#define DEFINE_BATCHEDREDUCE_KERNEL_FLOAT(name, type, init, op) \
1970 void libjit_##name(type *dest, const type *batch, dim_t destSize, \
1971 const dim_t *destDims, const dim_t *batchDims, \
1972 dim_t axis) { \
1973 for (dim_t i = 0; i < destSize; i++) \
1974 dest[i] = init; \
1975 for (dim_t x = 0; x < batchDims[0]; x++) \
1976 for (dim_t y = 0; y < batchDims[1]; y++) \
1977 for (dim_t z = 0; z < batchDims[2]; z++) \
1978 for (dim_t w = 0; w < batchDims[3]; w++) \
1979 for (dim_t q = 0; q < batchDims[4]; q++) \
1980 for (dim_t r = 0; r < batchDims[5]; r++) { \
1981 dim_t I[] = {x, y, z, w, q, r}; \
1982 I[axis] = 0; \
1983 dest[libjit_getXYZWQR(destDims, I[0], I[1], I[2], I[3], I[4], \
1984 I[5])] = \
1985 dest[libjit_getXYZWQR(destDims, I[0], I[1], I[2], I[3], \
1986 I[4], I[5])] op \
1987 batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)]; \
1988 } \
1989 }
1990
1991DEFINE_BATCHEDREDUCE_KERNEL_FLOAT(batchedreduceadd_f, float, 0.0, +)
1992DEFINE_BATCHEDREDUCE_KERNEL_FLOAT(batchedreduceprod_f, float, 1.0, *)
1993#undef DEFINE_BATCHEDREDUCE_KERNEL_FLOAT
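
// A worked example of the pre-expanded shapes (illustrative only): to reduce
// axis 1 of a {2, 3} batch, LLVMIRGen passes batchDims = {2, 3, 1, 1, 1, 1}
// and destDims = {2, 1, 1, 1, 1, 1}. The six nested loops then visit every
// element exactly once, and I[axis] = 0 collapses the reduced dimension so
// that dest[x][0] accumulates batch[x][y] over all y.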
1994
/// Macro to define reducemin/max wrapper kernels.
1996#define DEFINE_REDUCE_MINMAX(func, suffix, type, init) \
1997 void func##_##suffix(type *dest, const type *batch, size_t destSize, \
1998 const dim_t *destDims, const dim_t *batchDims) { \
1999 func(dest, batch, destSize, destDims, batchDims, init); \
2000 }
2001
/// Define reducemin wrapper kernels for float, int64_t and int32_t.
2003DEFINE_REDUCE_MINMAX(libjit_reducemin, f, float,
2004 std::numeric_limits<float>::infinity());
2005DEFINE_REDUCE_MINMAX(libjit_reducemin, u, int64_t,
2006 std::numeric_limits<int64_t>::max());
2007DEFINE_REDUCE_MINMAX(libjit_reducemin, i32, int32_t,
2008 std::numeric_limits<int32_t>::max());
2009
/// Define reducemax wrapper kernels for float, int64_t and int32_t.
2011DEFINE_REDUCE_MINMAX(libjit_reducemax, f, float,
2012 (-std::numeric_limits<float>::infinity()));
2013DEFINE_REDUCE_MINMAX(libjit_reducemax, u, int64_t,
2014 std::numeric_limits<int64_t>::min());
2015DEFINE_REDUCE_MINMAX(libjit_reducemax, i32, int32_t,
2016 std::numeric_limits<int32_t>::min());
2017
#undef DEFINE_REDUCE_MINMAX
2019
2020/// Same as the non-quantized version, the dimensions here are pre-expanded in
2021/// LLVMIRGen. However, for quantization, we must accumulate in the inner-most
2022/// loop with higher precision (int32_t) and then clip the result back into the
2023/// dest tensor. Thus we add max_tensor_dimensions different cases for this to
2024/// ensure the axis is used as the inner-most loop.
2025void libjit_batchedreduceadd_i8(int8_t *dest, const int8_t *batch,
2026 const dim_t *destDims, const dim_t *batchDims,
2027 int32_t destOffset, int32_t batchOffset,
2028 int32_t batchPre, int32_t batchPost,
2029 int32_t batchScale, dim_t axis) {
2030 switch (axis) {
2031#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS) \
2032 case _D5_AXIS: \
2033 for (dim_t i##_D0 = 0; i##_D0 < batchDims[_D0]; i##_D0++) \
2034 for (dim_t i##_D1 = 0; i##_D1 < batchDims[_D1]; i##_D1++) \
2035 for (dim_t i##_D2 = 0; i##_D2 < batchDims[_D2]; i##_D2++) \
2036 for (dim_t i##_D3 = 0; i##_D3 < batchDims[_D3]; i##_D3++) \
2037 for (dim_t i##_D4 = 0; i##_D4 < batchDims[_D4]; i##_D4++) { \
                          int32_t sum = 0;                                     \
2039 for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < batchDims[_D5_AXIS]; \
2040 i##_D5_AXIS++) { \
2041 sum += batch[libjit_getXYZWQR(batchDims, i0, i1, i2, i3, i4, \
2042 i5)] - \
2043 batchOffset; \
2044 } \
2045 dim_t i##_D5_AXIS = 0; \
2046 int32_t res = libjit_scale<int32_t>(sum, batchPre, batchPost, \
2047 batchScale, destOffset); \
2048 dest[libjit_getXYZWQR(destDims, i0, i1, i2, i3, i4, i5)] = \
2049 libjit_clip_i8(res); \
2050 } \
2051 return;
2052
2053 // Each loop order, with the inner-most dimension/index equal to the axis.
2054 LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
2055 LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
2056 LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
2057 LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
2058 LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
2059 LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
2060#undef LOOP_AXIS_CASE
2061 }
2062}
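
// A minimal sketch of why the accumulation above widens to int32_t
// (illustrative only): the sum of even a handful of int8_t values leaves the
// int8_t range, so we accumulate in int32_t and rescale/clip only once per
// output element.
static void libjit_example_widened_accumulation() {
  const int8_t row[] = {100, 100, 100, 100, 100, 100};
  int32_t sum = 0; // An int8_t accumulator would overflow at 600.
  for (size_t i = 0; i < sizeof(row); i++) {
    sum += row[i];
  }
  (void)sum; // sum == 600
}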
2063
2064void libjit_cross_entropy_loss_f_u(float *CE, float *P, size_t *labels,
2065 dim_t *dims) {
2066 libjit_cross_entropy_loss_generic(CE, P, labels, dims);
2067}
2068
2069void libjit_cross_entropy_loss_f_i32(float *CE, float *P, int32_t *labels,
2070 dim_t *dims) {
2071 libjit_cross_entropy_loss_generic(CE, P, labels, dims);
2072}
2073
2074//===----------------------------------------------------------------------===//
2075// Gather
2076//===----------------------------------------------------------------------===//
2077void libjit_gather64_f(float *dest, const float *data, const int64_t *indices,
2078 dim_t numIndices, dim_t sliceSize, dim_t numSamples,
2079 dim_t sampleSize) {
2080 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2081 sampleSize);
2082}
2083
2084void libjit_gather64_i8(int8_t *dest, const int8_t *data,
2085 const int64_t *indices, dim_t numIndices,
2086 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
2087 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2088 sampleSize);
2089}
2090
2091void libjit_gather64_u(int64_t *dest, const int64_t *data,
2092 const int64_t *indices, dim_t numIndices,
2093 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
2094 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2095 sampleSize);
2096}
2097
2098void libjit_gather32_f(float *dest, const float *data, const int32_t *indices,
2099 dim_t numIndices, dim_t sliceSize, dim_t numSamples,
2100 dim_t sampleSize) {
2101 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2102 sampleSize);
2103}
2104
2105void libjit_gather32_i8(int8_t *dest, const int8_t *data,
2106 const int32_t *indices, dim_t numIndices,
2107 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
2108 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2109 sampleSize);
2110}
2111
2112void libjit_gather32_u(int64_t *dest, const int64_t *data,
2113 const int32_t *indices, dim_t numIndices,
2114 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
2115 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2116 sampleSize);
2117}
2118
2119void libjit_gather32_i32(int32_t *dest, const int32_t *data,
2120 const int32_t *indices, dim_t numIndices,
2121 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
2122 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
2123 sampleSize);
2124}
2125
2126//===----------------------------------------------------------------------===//
2127// Gather ND
2128//===----------------------------------------------------------------------===//
2129void libjit_gather_nd_f_u(float *dest, const float *data,
2130 const int64_t *indices, dim_t batchCount,
2131 dim_t inpSliceCount, dim_t outSliceCount,
2132 dim_t sliceSize, dim_t indicesDimLast,
2133 dim_t *indicesDimProd) {
2134 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2135 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2136}
2137
2138void libjit_gather_nd_i8_u(int8_t *dest, const int8_t *data,
2139 const int64_t *indices, dim_t batchCount,
2140 dim_t inpSliceCount, dim_t outSliceCount,
2141 dim_t sliceSize, dim_t indicesDimLast,
2142 dim_t *indicesDimProd) {
2143 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2144 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2145}
2146
2147void libjit_gather_nd_i32_u(int32_t *dest, const int32_t *data,
2148 const int64_t *indices, dim_t batchCount,
2149 dim_t inpSliceCount, dim_t outSliceCount,
2150 dim_t sliceSize, dim_t indicesDimLast,
2151 dim_t *indicesDimProd) {
2152 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2153 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2154}
2155
2156void libjit_gather_nd_u_u(int64_t *dest, const int64_t *data,
2157 const int64_t *indices, dim_t batchCount,
2158 dim_t inpSliceCount, dim_t outSliceCount,
2159 dim_t sliceSize, dim_t indicesDimLast,
2160 dim_t *indicesDimProd) {
2161 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2162 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2163}
2164
2165void libjit_gather_nd_f_i32(float *dest, const float *data,
2166 const int32_t *indices, dim_t batchCount,
2167 dim_t inpSliceCount, dim_t outSliceCount,
2168 dim_t sliceSize, dim_t indicesDimLast,
2169 dim_t *indicesDimProd) {
2170 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2171 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2172}
2173
2174void libjit_gather_nd_i8_i32(int8_t *dest, const int8_t *data,
2175 const int32_t *indices, dim_t batchCount,
2176 dim_t inpSliceCount, dim_t outSliceCount,
2177 dim_t sliceSize, dim_t indicesDimLast,
2178 dim_t *indicesDimProd) {
2179 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2180 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2181}
2182
2183void libjit_gather_nd_i32_i32(int32_t *dest, const int32_t *data,
2184 const int32_t *indices, dim_t batchCount,
2185 dim_t inpSliceCount, dim_t outSliceCount,
2186 dim_t sliceSize, dim_t indicesDimLast,
2187 dim_t *indicesDimProd) {
2188 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2189 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2190}
2191
2192void libjit_gather_nd_u_i32(int64_t *dest, const int64_t *data,
2193 const int32_t *indices, dim_t batchCount,
2194 dim_t inpSliceCount, dim_t outSliceCount,
2195 dim_t sliceSize, dim_t indicesDimLast,
2196 dim_t *indicesDimProd) {
2197 libjit_gather_nd(dest, data, indices, batchCount, inpSliceCount,
2198 outSliceCount, sliceSize, indicesDimLast, indicesDimProd);
2199}
2200
2201//===----------------------------------------------------------------------===//
2202// Gather Ranges
2203//===----------------------------------------------------------------------===//
2204void libjit_gatherranges64_f(float *output, int64_t *lengths, const float *data,
2205 const int64_t *ranges, dim_t numExamples,
2206 dim_t exampleSize) {
2207 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2208}
2209
2210void libjit_gatherranges64_i8(int8_t *output, int64_t *lengths,
2211 const int8_t *data, const int64_t *ranges,
2212 dim_t numExamples, dim_t exampleSize) {
2213 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2214}
2215
2216void libjit_gatherranges64_u(int64_t *output, int64_t *lengths,
2217 const int64_t *data, const int64_t *ranges,
2218 dim_t numExamples, dim_t exampleSize) {
2219 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2220}
2221
2222void libjit_gatherranges32_f(float *output, int32_t *lengths, const float *data,
2223 const int32_t *ranges, dim_t numExamples,
2224 dim_t exampleSize) {
2225 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2226}
2227
2228void libjit_gatherranges32_i8(int8_t *output, int32_t *lengths,
2229 const int8_t *data, const int32_t *ranges,
2230 dim_t numExamples, dim_t exampleSize) {
2231 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2232}
2233
2234void libjit_gatherranges32_u(uint64_t *output, int32_t *lengths,
2235 const uint64_t *data, const int32_t *ranges,
2236 dim_t numExamples, dim_t exampleSize) {
2237 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2238}
2239
2240void libjit_gatherranges32_i32(int32_t *output, int32_t *lengths,
2241 const int32_t *data, const int32_t *ranges,
2242 dim_t numExamples, dim_t exampleSize) {
2243 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
2244}
2245
2246void libjit_lengths_range_fill_i32(const int32_t *lengths, int32_t *output,
2247 const dim_t lengthsSize) {
2248 dim_t curIdx = 0;
2249 for (dim_t i = 0, e = lengthsSize; i < e; i++) {
2250 for (int32_t j = 0, f = lengths[i]; j < f; j++) {
2251 output[curIdx++] = j;
2252 }
2253 }
2254}
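
// A minimal usage sketch (illustrative only): each length expands into the
// sequence 0 .. length-1, and the sequences are concatenated into the output.
static void libjit_example_lengths_range_fill() {
  const int32_t lengths[] = {2, 3};
  int32_t output[5];
  libjit_lengths_range_fill_i32(lengths, output, 2);
  // output == {0, 1, 0, 1, 2}
}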
2255
2256void libjit_scatterdata_f_i32(float *data, const dim_t *dataDims,
2257 const int32_t *indices, const float *slices,
2258 dim_t numIndices, dim_t indexSize,
2259 dim_t sliceSize, bool isCumulative) {
2260 if (isCumulative) {
2261 libjit_scatterdataaddfloat(data, dataDims, indices, slices, numIndices,
2262 indexSize, sliceSize);
2263 } else {
2264 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2265 indexSize, sliceSize);
2266 }
2267}
2268
2269void libjit_scatterdata_i8_u(int8_t *data, const dim_t *dataDims,
2270 const int64_t *indices, const int8_t *slices,
2271 dim_t numIndices, dim_t indexSize, dim_t sliceSize,
2272 bool isCumulative, float dataScale,
2273 int32_t dataOffset, float sliceScale,
2274 int32_t sliceOffset) {
2275 if (isCumulative) {
2276 libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2277 indexSize, sliceSize, dataScale, dataOffset,
2278 sliceScale, sliceOffset);
2279 } else {
2280 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2281 indexSize, sliceSize);
2282 }
2283}
2284
2285void libjit_scatterdata_i8_i32(int8_t *data, const dim_t *dataDims,
2286 const int32_t *indices, const int8_t *slices,
2287 dim_t numIndices, dim_t indexSize,
2288 dim_t sliceSize, bool isCumulative,
2289 float dataScale, int32_t dataOffset,
2290 float sliceScale, int32_t sliceOffset) {
2291 if (isCumulative) {
2292 libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2293 indexSize, sliceSize, dataScale, dataOffset,
2294 sliceScale, sliceOffset);
2295 } else {
2296 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2297 indexSize, sliceSize);
2298 }
2299}
2300
2301void libjit_lengths_to_ranges_i32(int32_t *ranges, const int32_t *lengths,
2302 dim_t size) {
2303 int32_t offset = 0;
2304 for (dim_t i = 0; i < size; i++) {
2305 auto length = lengths[i];
2306 ranges[i * 2] = offset;
2307 ranges[i * 2 + 1] = length;
2308 offset += length;
2309 }
2310}
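
// A minimal usage sketch (illustrative only): lengths are converted into
// (offset, length) pairs by keeping a running sum of the lengths.
static void libjit_example_lengths_to_ranges() {
  const int32_t lengths[] = {2, 3, 1};
  int32_t ranges[6];
  libjit_lengths_to_ranges_i32(ranges, lengths, 3);
  // ranges == {0, 2,  2, 3,  5, 1}
}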
2311
2312void libjit_sparse_lengths_sum_f_u(float *dest, float *data, size_t *indices,
2313 int32_t *lengths, dim_t segments,
2314 dim_t lineSize) {
2315 libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
2316 lineSize);
2317}
2318
2319void libjit_sparse_lengths_sum_f_i32(float *dest, float *data, int32_t *indices,
2320 int32_t *lengths, dim_t segments,
2321 dim_t lineSize) {
2322 libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
2323 lineSize);
2324}
2325
2326void libjit_sparse_lengths_weighted_sum_f_u(float *dest, float *data,
2327 float *weights, size_t *indices,
2328 int32_t *lengths, dim_t segments,
2329 dim_t lineSize) {
2330 libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
2331 lengths, segments, lineSize);
2332}
2333
2334void libjit_sparse_lengths_weighted_sum_f_i32(float *dest, float *data,
2335 float *weights, int32_t *indices,
2336 int32_t *lengths, dim_t segments,
2337 dim_t lineSize) {
2338 libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
2339 lengths, segments, lineSize);
2340}
2341
2342void libjit_embedding_f(float *dest, float *weights, int64_t *indices,
2343 const dim_t *indDims, dim_t indSize, dim_t numEmbedding,
2344 dim_t embeddingDim, int64_t padIdx, bool scale,
2345 bool sparse) {
2346 libjit_embedding_generic(dest, weights, indices, indDims, indSize,
2347 numEmbedding, embeddingDim, padIdx, scale, sparse);
2348}
2349
2350void libjit_embedding_bag_f(float *dest, float *data, float *weights,
2351 int32_t *indices, int32_t *offsets, dim_t segments,
2352 dim_t lineSize, dim_t totalLength,
2353 bool hasEndOffset) {
2354 if (hasEndOffset) {
2355 --segments;
2356 }
2357 memset(dest, 0, segments * lineSize * sizeof(float));
2358 dim_t curIndex = 0;
2359 for (dim_t i = 0; i < segments; i++) {
2360 int32_t start = offsets[i];
2361 int32_t end =
2362 !hasEndOffset && i == segments - 1 ? totalLength : offsets[i + 1];
2363 for (int32_t j = start; j < end; j++) {
2364 float weight = weights[curIndex];
2365 dim_t line = indices[curIndex];
2366 for (dim_t k = 0; k < lineSize; k++) {
2367 dest[i * lineSize + k] += weight * data[line * lineSize + k];
2368 }
2369 curIndex++;
2370 }
2371 }
2372}
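
// A worked example of the offsets convention above (illustrative only): with
// indices = {4, 9, 3}, offsets = {0, 2}, totalLength = 3 and
// hasEndOffset = false, segment 0 pools the weighted rows 4 and 9 of data and
// segment 1 pools row 3; the last segment ends at totalLength. With
// hasEndOffset = true, offsets carries one extra entry that terminates the
// last segment explicitly.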
2373
2374void libjit_sparse_lengths_weighted_sum_grad_f_u(
2375 const float *destGrad, float *dataGrad, float *weightsGrad,
2376 const float *data, const float *weights, const size_t *indices,
2377 const int32_t *lengths, dim_t segments, dim_t lineSize,
2378 dim_t dataGradRawSize) {
2379 libjit_sparse_lengths_weighted_sum_grad_generic(
2380 destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
2381 segments, lineSize, dataGradRawSize);
2382}
2383
2384void libjit_sparse_lengths_weighted_sum_grad_f_i32(
2385 const float *destGrad, float *dataGrad, float *weightsGrad,
2386 const float *data, const float *weights, const int32_t *indices,
2387 const int32_t *lengths, dim_t segments, dim_t lineSize,
2388 dim_t dataGradRawSize) {
2389 libjit_sparse_lengths_weighted_sum_grad_generic(
2390 destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
2391 segments, lineSize, dataGradRawSize);
2392}
2393
2394void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
2395 float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
2396 size_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
2397 libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2398 dest, data, scales, offsets, weights, indices, lengths, segments,
2399 lineSize);
2400}
2401
2402void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
2403 float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
2404 int32_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
2405 libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2406 dest, data, scales, offsets, weights, indices, lengths, segments,
2407 lineSize);
2408}
2409
2410void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
2411 float *dest, int8_t *data, float *weights, size_t *indices,
2412 int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2413 libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2414 dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
2415}
2416
2417void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
2418 float *dest, int8_t *data, float *weights, int32_t *indices,
2419 int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2420 libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2421 dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
2422}
2423
2424void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
2425 float *dest, int8_t *data, float *weights, dim_t *indices, int32_t *lengths,
2426 dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2427 memset(dest, 0, segments * outLineSize * sizeof(float));
2428 dim_t curIndex = 0;
2429 for (dim_t i = 0; i < segments; i++) {
2430 for (int32_t j = 0, e = lengths[i]; j < e; j++) {
2431 const float weight = weights[curIndex];
2432 const dim_t line = indices[curIndex];
2433 const int8_t *currRowScaleOffsetPtr =
2434 data + ((line + 1) * inLineSize) - 2 * sizeof(float);
2435 float scale, offset;
2436 memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
2437 memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
2438 for (dim_t k = 0; k < outLineSize; k++) {
2439 const float fData =
2440 (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
2441 dest[i * outLineSize + k] += weight * fData;
2442 }
2443 curIndex++;
2444 }
2445 }
2446}
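
// The fused row layout assumed above (illustrative diagram): each row of
// "data" is inLineSize bytes; the leading bytes hold the uint8_t quantized
// values and the trailing 2 * sizeof(float) bytes hold the row's scale and
// offset:
//
//   [ q_0 | q_1 | ... | float scale | float offset ]
//
// so element k of a row dequantizes as fData = scale * q_k + offset.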
2447
2448void libjit_embedding_bag_byte_rowwise_offsets_f(
2449 float *dest, int8_t *data, float *weights, int32_t *indices,
2450 int32_t *offsets, dim_t segments, dim_t numIndices, dim_t inLineSize,
2451 dim_t outLineSize, bool hasEndOffset) {
2452 if (hasEndOffset) {
2453 --segments;
2454 }
2455 memset(dest, 0, segments * outLineSize * sizeof(float));
2456 for (dim_t i = 0; i < segments; i++) {
2457 dim_t start = offsets[i];
2458 dim_t end =
2459 !hasEndOffset && i == segments - 1 ? numIndices : offsets[i + 1];
2460 for (dim_t j = start; j < end; j++) {
2461 const float weight = weights[j];
2462 const dim_t line = indices[j];
2463 const int8_t *currRowScaleOffsetPtr =
2464 data + ((line + 1) * inLineSize) - 2 * sizeof(float);
2465 float scale, offset;
2466 memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
2467 memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
2468 for (dim_t k = 0; k < outLineSize; k++) {
2469 const float fData =
2470 (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
2471 dest[i * outLineSize + k] += weight * fData;
2472 }
2473 }
2474 }
2475}
2476
2477void libjit_sparse_to_dense_f_u(float *dest, const size_t *indices,
2478 const float *values, dim_t numIndices,
2479 dim_t destSize, dim_t valueSize) {
2480 libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
2481 valueSize);
2482}
2483
2484void libjit_sparse_to_dense_f_i32(float *dest, const int32_t *indices,
2485 const float *values, dim_t numIndices,
2486 dim_t destSize, dim_t valueSize) {
2487 libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
2488 valueSize);
2489}
2490
2491void libjit_lengths_sum_f(float *dest, const float *data,
2492 const int32_t *lengths, dim_t destSize,
2493 dim_t lengthsSize, dim_t sliceSize) {
2494 memset(dest, 0, destSize * sizeof(float));
2495
2496 dim_t offsetOut = 0;
2497 dim_t offsetIn = 0;
2498
2499 for (dim_t i = 0; i < lengthsSize; ++i) {
2500 for (int32_t j = 0; j < lengths[i]; ++j) {
2501 for (dim_t k = 0; k < sliceSize; ++k) {
2502 dest[offsetOut + k] += data[offsetIn + k];
2503 }
2504 offsetIn += sliceSize;
2505 }
2506 offsetOut += sliceSize;
2507 }
2508}
2509
2510void libjit_local_response_normalization_f(
2511 float *outW, const float *inW, float *scaleCache, const dim_t *outWdims,
2512 const dim_t *inWdims, dim_t halfWindow, float alpha, float beta, float k) {
2513 dim_t window = 2 * halfWindow + 1;
2514 float normedAlpha = alpha / window;
2515
2516 for (dim_t n = 0; n < inWdims[0]; n++) {
2517 for (dim_t h = 0; h < inWdims[1]; h++) {
2518 for (dim_t w = 0; w < inWdims[2]; w++) {
2519 for (dim_t c = 0; c < inWdims[3]; c++) {
2520 float m2 = 0.0;
2521 for (dim_t i = (c >= halfWindow ? c - halfWindow : 0);
2522 i <= MIN(c + halfWindow, inWdims[3] - 1); i++) {
2523 float val = inW[libjit_getXYZW(inWdims, n, h, w, i)];
2524 m2 += val * val;
2525 }
2526
2527 float scale = k + normedAlpha * m2;
2528 scaleCache[libjit_getXYZW(inWdims, n, h, w, c)] = scale;
2529 float normFactor = pow(scale, -beta);
2530 outW[libjit_getXYZW(outWdims, n, h, w, c)] =
2531 inW[libjit_getXYZW(inWdims, n, h, w, c)] * normFactor;
2532 } // C
2533 } // W
2534 } // H
2535 } // N
2536}
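
// In formula form (a worked restatement of the loops above), with
// window = 2 * halfWindow + 1:
//   scale(c) = k + (alpha / window) * sum of in(i)^2 for i in
//              [c - halfWindow, c + halfWindow], clamped to valid channels;
//   out(c)   = in(c) * scale(c)^(-beta).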
2537
2538void libjit_local_response_normalization_grad_f(
2539 float *inG, const float *outG, const float *inW, const float *outW,
2540 const float *scaleCache, const dim_t *outWdims, dim_t halfWindow,
2541 float alpha, float beta) {
2542 dim_t window = 2 * halfWindow + 1;
2543 float normedAlpha = alpha / window;
2544 float coeff = 2 * normedAlpha * beta;
2545
2546 for (dim_t n = 0; n < outWdims[0]; n++) {
2547 for (dim_t h = 0; h < outWdims[1]; h++) {
2548 for (dim_t w = 0; w < outWdims[2]; w++) {
2549 // Prepare right half of sliding window based at c = 0
2550 float sum = 0.0;
2551 for (dim_t i = 0; i < MIN(halfWindow, outWdims[3]); i++) {
2552 float outg = outG[libjit_getXYZW(outWdims, n, h, w, i)];
2553 float outw = outW[libjit_getXYZW(outWdims, n, h, w, i)];
2554 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, i)];
2555 sum += outg * (outw / scale);
2556 }
2557
2558 for (dim_t c = 0; c < outWdims[3]; c++) {
2559 if (c > halfWindow) {
2560 dim_t j = c - halfWindow - 1;
2561 float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
2562 float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
2563 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
2564 sum -= outg * (outw / scale);
2565 }
2566
2567 dim_t j = c + halfWindow;
2568 if (j < outWdims[3]) {
2569 float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
2570 float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
2571 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
2572 sum += outg * (outw / scale);
2573 }
2574
2575 float outg = outG[libjit_getXYZW(outWdims, n, h, w, c)];
2576 float inw = inW[libjit_getXYZW(outWdims, n, h, w, c)];
2577 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, c)];
2578 inG[libjit_getXYZW(outWdims, n, h, w, c)] =
2579 outg * pow(scale, -beta) - coeff * inw * sum;
2580 }
2581 } // W
2582 } // H
2583 } // N
2584}
2585
2586void libjit_max_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
2587 const dim_t *outWdims, dim_t *kernelSizes,
2588 dim_t *strides, dim_t *pads, int32_t outOffset) {
2589 libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
2590 pads, static_cast<int8_t>(outOffset));
2591}
2592
2593void libjit_max_pool_f(const float *inW, float *outW, const dim_t *inWdims,
2594 const dim_t *outWdims, dim_t *kernelSizes,
2595 dim_t *strides, dim_t *pads) {
2596 libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
2597 pads, static_cast<float>(0));
2598}
2599
2600void libjit_max_pool_argmax_i8_u(const int8_t *inW, int8_t *outW,
2601 int64_t *argmax, const dim_t *inWdims,
2602 const dim_t *outWdims, dim_t *kernels,
2603 dim_t *strides, dim_t *pads) {
2604 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2605 strides, pads);
2606}
2607
2608void libjit_max_pool_argmax_f_u(const float *inW, float *outW, int64_t *argmax,
2609 const dim_t *inWdims, const dim_t *outWdims,
2610 dim_t *kernels, dim_t *strides, dim_t *pads) {
2611 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2612 strides, pads);
2613}
2614
2615void libjit_max_pool_argmax_i8_i32(const int8_t *inW, int8_t *outW,
2616 int32_t *argmax, const dim_t *inWdims,
2617 const dim_t *outWdims, dim_t *kernels,
2618 dim_t *strides, dim_t *pads) {
2619 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2620 strides, pads);
2621}
2622
2623void libjit_max_pool_argmax_f_i32(const float *inW, float *outW,
2624 int32_t *argmax, const dim_t *inWdims,
2625 const dim_t *outWdims, dim_t *kernels,
2626 dim_t *strides, dim_t *pads) {
2627 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2628 strides, pads);
2629}
2630
2631void libjit_arg_max_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
2632 size_t inWNumDims, size_t axis) {
2633 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2634}
2635
2636void libjit_arg_max_i8_i32(const int8_t *inW, int32_t *outW,
2637 const dim_t *inWdims, size_t inWNumDims,
2638 size_t axis) {
2639 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2640}
2641
2642void libjit_arg_max_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
2643 size_t inWNumDims, size_t axis) {
2644 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2645}
2646
2647void libjit_arg_max_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
2648 size_t inWNumDims, size_t axis) {
2649 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2650}
2651
2652void libjit_arg_min_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
2653 size_t inWNumDims, size_t axis) {
2654 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2655}
2656
2657void libjit_arg_min_i8_i32(const int8_t *inW, int32_t *outW,
2658 const dim_t *inWdims, size_t inWNumDims,
2659 size_t axis) {
2660 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2661}
2662
2663void libjit_arg_min_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
2664 size_t inWNumDims, size_t axis) {
2665 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2666}
2667
2668void libjit_arg_min_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
2669 size_t inWNumDims, size_t axis) {
2670 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2671}
2672
2673void libjit_max_pool_argmax_grad_f_u(float *inG, const float *outG,
2674 const int64_t *argmax,
2675 const dim_t *inGdims,
2676 const dim_t *outWdims) {
2677 libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
2678}
2679
2680void libjit_max_pool_argmax_grad_f_i32(float *inG, const float *outG,
2681 const int32_t *argmax,
2682 const dim_t *inGdims,
2683 const dim_t *outWdims) {
2684 libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
2685}
2686
2687void libjit_resizenearest_f(float *dst, const float *src, const float *scale,
2688 const dim_t *inWdims, const dim_t *outWdims) {
2689 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2690}
2691
2692void libjit_resizenearest_i8(int8_t *dst, const int8_t *src, const float *scale,
2693 const dim_t *inWdims, const dim_t *outWdims) {
2694 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2695}
2696
2697void libjit_resizenearest_i32(int32_t *dst, const int32_t *src,
2698 const float *scale, const dim_t *inWdims,
2699 const dim_t *outWdims) {
2700 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2701}
2702
2703void libjit_resizenearest_u(int64_t *dst, const int64_t *src,
2704 const float *scale, const dim_t *inWdims,
2705 const dim_t *outWdims) {
2706 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2707}
2708
2709void libjit_resizebilinear_f(float *dst, const float *src, const float *scale,
2710 const dim_t *inWdims, const dim_t *outWdims) {
2711 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2712}
2713
2714void libjit_resizebilinear_i8(int8_t *dst, const int8_t *src,
2715 const float *scale, const dim_t *inWdims,
2716 const dim_t *outWdims) {
2717 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2718}
2719
2720void libjit_resizebilinear_i32(int32_t *dst, const int32_t *src,
2721 const float *scale, const dim_t *inWdims,
2722 const dim_t *outWdims) {
2723 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2724}
2725
2726void libjit_resizebilinear_u(int64_t *dst, const int64_t *src,
2727 const float *scale, const dim_t *inWdims,
2728 const dim_t *outWdims) {
2729 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2730}
2731
2732void libjit_avg_pool_f(const float *inW, float *outW, const dim_t *inWdims,
2733 const dim_t *outWdims, dim_t *kernelSizes,
2734 dim_t *strides, dim_t *pads, bool countIncludePads) {
2735
2736 size_t kernelH = kernelSizes[0];
2737 size_t kernelW = kernelSizes[1];
2738
2739 size_t strideH = strides[0];
2740 size_t strideW = strides[1];
2741
2742 size_t padT = pads[0];
2743 size_t padL = pads[1];
2744
2745 // For each input in the batch.
2746 for (size_t n = 0; n < inWdims[0]; n++) {
2747
2748 // For each output height.
2749 ssize_t i_h_min = -(ssize_t)padT;
2750 for (size_t o_h = 0; o_h < outWdims[1]; o_h++, i_h_min += strideH) {
2751
2752 // Effective kernel height limits.
2753 ssize_t f_h_min = libjit_conv_flt_min(i_h_min);
2754 ssize_t f_h_max = libjit_conv_flt_max(inWdims[1], kernelH, i_h_min);
2755 ssize_t f_h_len = libjit_conv_flt_len(f_h_min, f_h_max);
2756 const float *inpPtrH =
2757 inW + (i_h_min + f_h_min) * inWdims[2] * inWdims[3];
2758
2759 // For each output width.
2760 ssize_t i_w_min = -(ssize_t)padL;
2761 for (size_t o_w = 0; o_w < outWdims[2]; o_w++, i_w_min += strideW) {
2762
2763 // Effective kernel width limits.
2764 ssize_t f_w_min = libjit_conv_flt_min(i_w_min);
2765 ssize_t f_w_max = libjit_conv_flt_max(inWdims[2], kernelW, i_w_min);
2766 ssize_t f_w_len = libjit_conv_flt_len(f_w_min, f_w_max);
2767 const float *inpPtr = inpPtrH + (i_w_min + f_w_min) * inWdims[3];
2768
2769 // For each output channel.
2770 for (size_t o_c = 0; o_c < outWdims[3]; o_c++) {
2771
2772 // Initialize sum.
2773 float sum = 0;
2774
2775 // For each kernel height.
2776 for (size_t f_h = 0; f_h < f_h_len; f_h++) {
2777
2778 // For each kernel width.
2779 for (size_t f_w = 0; f_w < f_w_len; f_w++) {
2780
2781 // Accumulate along the kernel width.
2782 sum += (*inpPtr);
2783 inpPtr += inWdims[3];
2784 }
2785
2786 // Advance input pointer for next kernel height.
2787 inpPtr = inpPtr - f_w_len * inWdims[3] + inWdims[2] * inWdims[3];
2788 }
2789
2790 // Normalize and store.
2791 float area =
2792 countIncludePads ? (kernelH * kernelW) : (f_h_len * f_w_len);
2793 *outW++ = (area == 0) ? 0 : sum / area;
2794
2795 // Advance input pointer for next output channel.
2796 inpPtr = inpPtr - f_h_len * inWdims[2] * inWdims[3] + 1;
2797 }
2798 }
2799 }
2800
2801 // Advance input pointer for next batch.
2802 inW += inWdims[1] * inWdims[2] * inWdims[3];
2803 }
2804}
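
// A worked example of countIncludePads (illustrative only): for a 3x3 kernel
// placed so that one row and one column of taps fall in the padding, the
// effective window is 2x2. With countIncludePads the sum is divided by the
// raw kernel area 9; without it, by the 4 samples actually read
// (f_h_len * f_w_len).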
2805
2806void libjit_avg_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
2807 const dim_t *outWdims, dim_t *kernelSizes,
2808 dim_t *strides, dim_t *pads, bool countIncludePads,
2809 int32_t outOffset, int32_t inOffset, int32_t outPre,
2810 int32_t outPost, int32_t outScale) {
2811
2812 size_t kernelH = kernelSizes[0];
2813 size_t kernelW = kernelSizes[1];
2814
2815 size_t strideH = strides[0];
2816 size_t strideW = strides[1];
2817
2818 size_t padT = pads[0];
2819 size_t padL = pads[1];
2820
2821 // For each input in the batch.
2822 for (size_t n = 0; n < inWdims[0]; n++) {
2823
2824 // For each output height.
2825 ssize_t i_h_min = -(ssize_t)padT;
2826 for (size_t o_h = 0; o_h < outWdims[1]; o_h++, i_h_min += strideH) {
2827
2828 // Effective kernel height limits.
2829 ssize_t f_h_min = libjit_conv_flt_min(i_h_min);
2830 ssize_t f_h_max = libjit_conv_flt_max(inWdims[1], kernelH, i_h_min);
2831 ssize_t f_h_len = libjit_conv_flt_len(f_h_min, f_h_max);
2832 const int8_t *inpPtrH =
2833 inW + (i_h_min + f_h_min) * inWdims[2] * inWdims[3];
2834
2835 // For each output width.
2836 ssize_t i_w_min = -(ssize_t)padL;
2837 for (size_t o_w = 0; o_w < outWdims[2]; o_w++, i_w_min += strideW) {
2838
2839 // Effective kernel width limits.
2840 ssize_t f_w_min = libjit_conv_flt_min(i_w_min);
2841 ssize_t f_w_max = libjit_conv_flt_max(inWdims[2], kernelW, i_w_min);
2842 ssize_t f_w_len = libjit_conv_flt_len(f_w_min, f_w_max);
2843 const int8_t *inpPtr = inpPtrH + (i_w_min + f_w_min) * inWdims[3];
2844
2845 // For each output channel.
2846 for (size_t o_c = 0; o_c < outWdims[3]; o_c++) {
2847
2848 // Initialize sum.
2849 int32_t sum = 0;
2850
2851 // For each kernel height.
2852 for (size_t f_h = 0; f_h < f_h_len; f_h++) {
2853
2854 // For each kernel width.
2855 for (size_t f_w = 0; f_w < f_w_len; f_w++) {
2856
2857 // Accumulate along the kernel width.
2858 sum += (*inpPtr) - inOffset;
2859 inpPtr += inWdims[3];
2860 }
2861
2862 // Advance input pointer for next kernel height.
2863 inpPtr = inpPtr - f_w_len * inWdims[3] + inWdims[2] * inWdims[3];
2864 }
2865
2866 // Normalize and store.
2867 if (countIncludePads) {
2868 sum = libjit_scale<int32_t>(sum, outPre, outPost, outScale,
2869 outOffset);
2870 *outW++ = libjit_clip_i8(sum);
2871 } else {
2872 int32_t area = f_h_len * f_w_len;
2873 if (area == 0) {
2874 *outW++ = outOffset;
2875 } else {
2876 sum = libjit_scale<int32_t>(sum, outPre, outPost, outScale, 0);
2877 sum = libjit_div_round_i32(sum, area) + outOffset;
2878 *outW++ = libjit_clip_i8(sum);
2879 }
2880 }
2881
2882 // Advance input pointer for next output channel.
2883 inpPtr = inpPtr - f_h_len * inWdims[2] * inWdims[3] + 1;
2884 }
2885 }
2886 }
2887
2888 // Advance input pointer for next batch.
2889 inW += inWdims[1] * inWdims[2] * inWdims[3];
2890 }
2891}
2892
2893void libjit_adaptive_avg_pool_f(const float *inW, float *outW,
2894 const dim_t *inWdims, const dim_t *outWdims) {
2895// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
2896#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
2897#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
2898
2899 // For each input in the batch:
2900 for (dim_t n = 0; n < outWdims[0]; n++) {
2901 // For each layer in the output tensor:
2902 for (dim_t z = 0; z < inWdims[3]; z++) {
2903 // For each value in the output tensor:
2904 for (dim_t ax = 0; ax < outWdims[1]; ax++) {
2905
2906 dim_t x = START_IND(ax, outWdims[1], inWdims[1]);
2907 dim_t kH = END_IND(ax, outWdims[1], inWdims[1]) - x;
2908
2909 for (dim_t ay = 0; ay < outWdims[2]; ay++) {
2910
2911 dim_t y = START_IND(ay, outWdims[2], inWdims[2]);
2912 dim_t kW = END_IND(ay, outWdims[2], inWdims[2]) - y;
2913
2914 float sum = 0;
2915 for (dim_t fx = 0; fx < kH; fx++) {
2916 for (dim_t fy = 0; fy < kW; fy++) {
2917 dim_t ox = x + fx;
2918 dim_t oy = y + fy;
2919
2920 sum += inW[libjit_getXYZW(inWdims, n, ox, oy, z)];
2921 }
2922 }
2923 outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = (sum / kW / kH);
2924 } // W
2925 } // H
2926 } // C
2927 } // N
2928#undef START_IND
2929#undef END_IND
2930}
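
// A worked example of the index mapping above (illustrative only): pooling a
// 5-element axis down to 2 outputs yields the windows [0, 3) and [2, 5):
//   START_IND(0, 2, 5) = floor(0 * 5 / 2) = 0; END_IND(0, 2, 5) = ceil(1 * 5 / 2) = 3
//   START_IND(1, 2, 5) = floor(1 * 5 / 2) = 2; END_IND(1, 2, 5) = ceil(2 * 5 / 2) = 5
// Adjacent windows may overlap, and together they always cover the axis.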
2931
2932void libjit_avg_pool_grad_f(float *inG, const float *outG, const dim_t *inGdims,
2933 const dim_t *outWdims, dim_t *kernels,
2934 dim_t *strides, dim_t *pads,
2935 bool countIncludePads) {
2936 dim_t pad_t = pads[0];
2937 dim_t pad_l = pads[1];
2938 dim_t stride_h = strides[0];
2939 dim_t stride_w = strides[1];
2940 dim_t kernel_h = kernels[0];
2941 dim_t kernel_w = kernels[1];
2942 float rawKernelArea = kernel_h * kernel_w;
2943
2944 // NHWC format is assumed
2945 for (dim_t n = 0; n < outWdims[0]; n++) {
2946 for (dim_t z = 0; z < outWdims[3]; z++) {
2947 // Clear inG
2948 for (dim_t x = 0; x < inGdims[1]; x++) {
2949 for (dim_t y = 0; y < inGdims[2]; y++) {
2950 inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
2951 }
2952 }
2953
2954 sdim_t x = -(sdim_t)pad_t;
2955 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2956 sdim_t y = -(sdim_t)pad_l;
2957 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2958 float kernelArea = rawKernelArea;
2959
2960 if (!countIncludePads) {
2961 sdim_t pad_x = (-x > 0 ? -x : 0) +
2962 ((x + sdim_t(kernel_h) - sdim_t(inGdims[1])) > 0
2963 ? (x + sdim_t(kernel_h) - sdim_t(inGdims[1]))
2964 : 0);
2965 sdim_t pad_y = (-y > 0 ? -y : 0) +
2966 ((y + sdim_t(kernel_w) - sdim_t(inGdims[2])) > 0
2967 ? (y + sdim_t(kernel_w) - sdim_t(inGdims[2]))
2968 : 0);
2969 kernelArea = rawKernelArea - pad_x * kernel_w - pad_y * kernel_h +
2970 pad_x * pad_y;
2971 }
2972
2973 assert(kernelArea != 0 && "KernelArea shouldn't be 0");
2974 float df = outG[libjit_getXYZW(outWdims, n, ax, ay, z)] / kernelArea;
2975 for (dim_t kx = 0; kx < kernel_h; kx++) {
2976 for (dim_t ky = 0; ky < kernel_w; ky++) {
2977 sdim_t ox = x + kx;
2978 sdim_t oy = y + ky;
2979 if (ox < 0 || oy < 0 || ox >= (sdim_t)inGdims[1] ||
2980 oy >= (sdim_t)inGdims[2]) {
2981 continue;
2982 }
2983 inG[libjit_getXYZW(inGdims, n, (dim_t)ox, (dim_t)oy, z)] += df;
2984 }
2985 }
2986 } // W
2987 } // H
2988 } // C
2989 } // N
2990}
2991
2992int8_t libjit_element_quantize_kernel_i8(dim_t idx, const float *inW,
2993 float scale, int32_t offset) {
2994 int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
2995 return libjit_clip_i8(result);
2996}
2997
2998int16_t libjit_element_quantize_kernel_i16(dim_t idx, const float *inW,
2999 float scale, int32_t offset) {
3000 int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
3001 return libjit_clip_i16(result);
3002}
3003
3004int32_t libjit_element_quantize_kernel_i32(dim_t idx, const float *inW,
3005 float scale, int32_t offset) {
3006 int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
3007 return result;
3008}
3009
3010float libjit_element_dequantize_kernel_i8(dim_t idx, const int8_t *inW,
3011 float scale, int32_t offset) {
3012 return scale * (inW[idx] - offset);
3013}
3014
3015float libjit_element_dequantize_kernel_i16(dim_t idx, const int16_t *inW,
3016 float scale, int32_t offset) {
3017 return scale * (inW[idx] - offset);
3018}
3019
3020float libjit_element_dequantize_kernel_i32(dim_t idx, const int32_t *inW,
3021 float scale, int32_t offset) {
3022 return scale * (inW[idx] - offset);
3023}
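
// A minimal round-trip sketch (illustrative only): quantization maps
// real -> round(real / scale + offset), and dequantization inverts it.
static void libjit_example_quantize_round_trip() {
  const float in[] = {2.0f};
  int8_t q = libjit_element_quantize_kernel_i8(0, in, 0.5f, -1);
  // q == 3, since nearbyintf(2.0 / 0.5 + (-1)) == 3.
  const int8_t qs[] = {q};
  float back = libjit_element_dequantize_kernel_i8(0, qs, 0.5f, -1);
  (void)back; // back == 2.0f, since 0.5 * (3 - (-1)) == 2.0.
}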
3024
3025int8_t libjit_element_rescale_kernel_i8(dim_t idx, const int8_t *inW,
3026 int32_t outOffset, int32_t inOffset,
3027 int32_t pre, int32_t post,
3028 int32_t scale) {
3029 int32_t s =
3030 libjit_scale<int32_t>(inW[idx] - inOffset, pre, post, scale, outOffset);
3031 return libjit_clip_i8(s);
3032}
3033
3034int16_t libjit_element_rescale_kernel_i16(dim_t idx, const int16_t *inW,
3035 int32_t outOffset, int32_t inOffset,
3036 int32_t pre, int32_t post,
3037 int32_t scale) {
3038 int32_t s =
3039 libjit_scale<int64_t>(inW[idx] - inOffset, pre, post, scale, outOffset);
3040 return libjit_clip_i16(s);
3041}
3042
3043int32_t libjit_element_rescale_kernel_i32(dim_t idx, const int32_t *inW,
3044 int32_t outOffset, int32_t inOffset,
3045 int32_t pre, int32_t post,
3046 int32_t scale) {
3047 int32_t s =
3048 libjit_scale<int64_t>(inW[idx] - inOffset, pre, post, scale, outOffset);
3049 return s;
3050}
3051
3052void libjit_softmax_f(const float *inW, float *outW, const dim_t *idim,
3053 const dim_t *odim) {
3054 for (dim_t n = 0; n < idim[0]; n++) {
3055 float max = inW[libjit_getXY(idim, n, 0)];
3056
3057 // Find Max.
3058 for (dim_t i = 1; i < idim[1]; i++) {
3059 max = MAX(max, inW[libjit_getXY(idim, n, i)]);
3060 }
3061
3062 float sum = 0;
3063
3064 // Compute exp.
3065 for (dim_t i = 0; i < idim[1]; i++) {
3066 float e = expf(inW[libjit_getXY(idim, n, i)] - max);
3067 sum += e;
3068 outW[libjit_getXY(odim, n, i)] = e;
3069 }
3070
3071 // Normalize the output.
3072 for (dim_t i = 0; i < idim[1]; i++) {
3073 outW[libjit_getXY(odim, n, i)] = outW[libjit_getXY(odim, n, i)] / sum;
3074 }
3075 } // N
3076}
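
// Why the max is subtracted above (a worked note): softmax is shift
// invariant, since exp(x_i - c) / sum_j exp(x_j - c) equals the softmax of x
// for any constant c. Choosing c = max keeps every argument of expf <= 0, so
// the exponentials never overflow even for large inputs.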
3077
3078void libjit_softmax_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
3079 const uint32_t *expData, int32_t outputOffset,
3080 uint32_t invScale, uint32_t integerPart,
3081 uint32_t invScalePoint) {
  for (dim_t j = 0; j < dims[0]; j++) {
3083 uint32_t sum = 0;
3084 int8_t max = std::numeric_limits<int8_t>::min();
3085 uint32_t division;
3086 int point, size;
3087
3088 // Find max value.
    for (dim_t i = 0; i < dims[1]; i++) {
3090 max = MAX(max, *inW);
3091 inW++;
3092 }
3093 inW -= dims[1];
3094
3095 // Compute the sum of exponentials.
    for (dim_t i = 0; i < dims[1]; i++) {
3097 sum += (expData[*inW++ + 255 - max] >> (integerPart - 1));
3098 }
3099 inW -= dims[1];
3100
    // Compute (1 / outputScale) * (1 / sum), where sum is computed above.
    // Align the binary point of the two operands before dividing.
3103 if ((32 - integerPart) >= (32 - invScalePoint)) {
3104 division = ((uint64_t)invScale * (1 << (32 - invScalePoint))) /
3105 (sum >> (invScalePoint - integerPart));
3106 size = (32 - invScalePoint);
3107 } else {
3108 division = ((uint64_t)(invScale >> (integerPart - invScalePoint))) *
3109 (1 << (32 - integerPart)) / sum;
3110 size = (32 - integerPart);
3111 }
3112
3113 point = size + 31;
3114 // Multiply with exp and bring the result into the right range.
    for (dim_t i = 0; i < dims[1]; i++) {
3116 uint32_t index = *inW++ + 255 - max;
3117 uint64_t mul = (uint64_t)division * (uint64_t)expData[index];
3118 int32_t res = (int32_t)(mul >> point) + outputOffset;
3119 *outW++ = MAX(MIN(res, 127), -128);
3120 }
3121 }
3122}
3123
3124void libjit_softmax_grad_f_u(float *inG, float *outW, const size_t *selectedW,
3125 const dim_t *idim, const dim_t *selectdim) {
3126 libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
3127}
3128
3129void libjit_softmax_grad_f_i32(float *inG, float *outW,
3130 const int32_t *selectedW, const dim_t *idim,
3131 const dim_t *selectdim) {
3132 libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
3133}
3134
3135void libjit_topk_f_u(float *values, size_t *indices, const float *input,
3136 void *scratch, dim_t k, dim_t n, dim_t size) {
3137 libjit_topk(values, indices, input, scratch, k, n, size);
3138}
3139
3140void libjit_topk_f_i32(float *values, int32_t *indices, const float *input,
3141 void *scratch, dim_t k, dim_t n, dim_t size) {
3142 libjit_topk(values, indices, input, scratch, k, n, size);
3143}
3144
3145void libjit_topk_i8_u(int8_t *values, size_t *indices, const int8_t *input,
3146 void *scratch, dim_t k, dim_t n, dim_t size) {
3147 libjit_topk(values, indices, input, scratch, k, n, size);
3148}
3149
3150void libjit_topk_i8_i32(int8_t *values, int32_t *indices, const int8_t *input,
3151 void *scratch, dim_t k, dim_t n, dim_t size) {
3152 libjit_topk(values, indices, input, scratch, k, n, size);
3153}
3154
3155void libjit_transpose_i8(const int8_t *inW, int8_t *outW, const dim_t *idim,
3156 const dim_t *odim, const dim_t *shuffle,
3157 dim_t numDims) {
3158 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
3159}
3160
3161void libjit_transpose_f(const float *inW, float *outW, const dim_t *idim,
3162 const dim_t *odim, const dim_t *shuffle,
3163 dim_t numDims) {
3164 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
3165}
3166
3167void libjit_transpose_u(const int64_t *inW, int64_t *outW, const dim_t *idim,
3168 const dim_t *odim, const dim_t *shuffle,
3169 dim_t numDims) {
3170 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
3171}
3172
3173void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim,
3174 const dim_t *odim, const dim_t *shuffle,
3175 dim_t numDims) {
3176 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
3177}
3178
3179void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
3180 dim_t axis, dim_t numDims) {
3181 libjit_flip_generic(inW, outW, dims, axis, numDims);
3182}
3183
3184void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims,
3185 dim_t axis, dim_t numDims) {
3186 libjit_flip_generic(inW, outW, dims, axis, numDims);
3187}
3188
3189void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims,
3190 dim_t axis, dim_t numDims) {
3191 libjit_flip_generic(inW, outW, dims, axis, numDims);
3192}
3193
3194void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims,
3195 dim_t axis, dim_t numDims) {
3196 libjit_flip_generic(inW, outW, dims, axis, numDims);
3197}
3198
3199void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis,
3200 dim_t numDims) {
3201 libjit_flip_generic(inW, outW, dims, axis, numDims);
3202}
3203
3204void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis,
3205 dim_t numDims) {
3206 libjit_flip_generic(inW, outW, dims, axis, numDims);
3207}
3208
3209void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset,
3210 dim_t *tensorDim, dim_t *sliceDim,
3211 dim_t numDimsTensor, dim_t numDimsSlice,
3212 dim_t offsetDim, dim_t count, dim_t axis) {
3213 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
3214 numDimsTensor, numDimsSlice, offsetDim, count, axis);
3215}
3216
3217void libjit_insert_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
3218 dim_t *tensorDim, dim_t *sliceDim,
3219 dim_t numDimsTensor, dim_t numDimsSlice,
3220 dim_t offsetDim, dim_t count, dim_t axis) {
3221 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
3222 numDimsTensor, numDimsSlice, offsetDim, count, axis);
3223}
3224
3225void libjit_extract_tensor_f(float *tensor, float *slice, dim_t *offset,
3226 dim_t *tensorDim, dim_t *sliceDim,
3227 dim_t numDimsTensor, dim_t numDimsSlice,
3228 dim_t offsetDim) {
3229 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
3230 numDimsTensor, numDimsSlice, offsetDim);
3231}
3232
3233void libjit_extract_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
3234 dim_t *tensorDim, dim_t *sliceDim,
3235 dim_t numDimsTensor, dim_t numDimsSlice,
3236 dim_t offsetDim) {
3237 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
3238 numDimsTensor, numDimsSlice, offsetDim);
3239}
3240
3241void libjit_extract_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
3242 dim_t *tensorDim, dim_t *sliceDim,
3243 dim_t numDimsTensor, dim_t numDimsSlice,
3244 dim_t offsetDim) {
3245 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
3246 numDimsTensor, numDimsSlice, offsetDim);
3247}
3248
3249void libjit_insert_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
3250 dim_t *tensorDim, dim_t *sliceDim,
3251 dim_t numDimsTensor, dim_t numDimsSlice,
3252 dim_t offsetDim, dim_t count, dim_t axis) {
3253 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
3254 numDimsTensor, numDimsSlice, offsetDim, count, axis);
3255}
3256
3257void libjit_extract_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
3258 dim_t *tensorDim, dim_t *sliceDim,
3259 dim_t numDimsTensor, dim_t numDimsSlice,
3260 dim_t offsetDim) {
3261 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
3262 numDimsTensor, numDimsSlice, offsetDim);
3263}
3264
3265void libjit_insert_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
3266 dim_t *tensorDim, dim_t *sliceDim,
3267 dim_t numDimsTensor, dim_t numDimsSlice,
3268 dim_t offsetDim, dim_t count, dim_t axis) {
3269 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
3270 numDimsTensor, numDimsSlice, offsetDim, count, axis);
3271}
3272
3273void libjit_insert_tensor_b(int8_t *tensor, int8_t *slice, dim_t *offset,
3274 dim_t *tensorDim, dim_t *sliceDim,
3275 dim_t numDimsTensor, dim_t numDimsSlice,
3276 dim_t offsetDim, dim_t count, dim_t axis) {
3277 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
3278 numDimsTensor, numDimsSlice, offsetDim, count, axis);
3279}
3280
3281void libjit_space_to_depth_f(const float *inTensor, float *outTensor,
3282 dim_t blockSize, const dim_t *inDims,
3283 const dim_t *outDims) {
3284 libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
3285 outDims);
3286}
3287
3288void libjit_space_to_depth_i8(const int8_t *inTensor, int8_t *outTensor,
3289 dim_t blockSize, const dim_t *inDims,
3290 const dim_t *outDims) {
3291 libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
3292 outDims);
3293}
3294
/// Function to dump a tensor in text format to the console.
3296__attribute__((noinline)) void libjit_dump_tensor_console(uint8_t *tensor,
3297 dim_t *tensorDim,
3298 dim_t numDimsTensor,
3299 dim_t elemKind,
3300 const char *name) {
3301 printf("%s\n", name);
  /// This definition should match the definition in Glow.
3303 enum class ElemKind : unsigned char {
3304 FloatTy, // 32-bit float type (float)
3305 Float16Ty, // 16-bit float type (half, fp16)
3306 BFloat16Ty, // 16-bit float type (bfloat16)
3307 Int8QTy, // 8-bit quantized type (int8_t)
3308 UInt8QTy, // unsigned 8-bit quantized type (uint8_t)
3309 Int16QTy, // 16-bit quantized type (int16_t)
3310 Int32QTy, // 32-bit quantized type (int32_t)
3311 Int32ITy, // 32-bit index type (int32_t)
3312 Int64ITy, // 64-bit index type (int64_t)
3313 UInt8FusedQTy, // 8-bit quantized type with fused scale/offset (uint8_t)
3314 BoolTy, // Bool type (bool)
3315 };
3316 // Dump the content of a tensor.
3317 switch ((ElemKind)elemKind) {
3318 case ElemKind::FloatTy:
3319 libjit_dump_tensor_console_impl((float *)tensor, tensorDim, numDimsTensor);
3320 break;
3321 case ElemKind::Int64ITy:
3322 libjit_dump_tensor_console_impl((dim_t *)tensor, tensorDim, numDimsTensor);
3323 break;
3324 case ElemKind::Int8QTy:
3325 libjit_dump_tensor_console_impl((int8_t *)tensor, tensorDim, numDimsTensor);
3326 break;
3327 case ElemKind::Int32QTy:
3328 libjit_dump_tensor_console_impl((int32_t *)tensor, tensorDim,
3329 numDimsTensor);
3330 break;
3331 default:
3332 printf("Dumping this type of payload is not supported: %zu\n",
3333 (size_t)elemKind);
3334 break;
3335 }
3336 puts("");
3337}
3338
/// Function to dump a tensor in binary format to a file using the raw tensor
3340/// data pointer \p tensor, the tensor data size \p tensorSize (in bytes) and
3341/// the file name \p filename. A text header \p header will also be dumped.
3342__attribute__((noinline)) void libjit_dump_tensor_bin(uint8_t *tensor,
3343 size_t tensorSize,
3344 const char *filename,
3345 const char *header) {
3346 FILE *fh = fopen(filename, "wb");
3347 if (!fh) {
3348 printf("ERROR opening file: '%s'!\n"
3349 "File name might be too long!\n",
3350 filename);
3351 return;
3352 }
3353 // Dump header.
3354 fprintf(fh, "%s", header);
3355 // Dump tensor data.
3356 size_t size = fwrite(tensor, 1, tensorSize, fh);
3357 assert((size == tensorSize) && "Error dumping tensor to file!");
3358 (void)size;
3359 fclose(fh);
3360}
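
// Note: the file produced above is simply the header text followed
// immediately by the raw tensor bytes; no separator or length field is
// written, so a reader is expected to know (or parse from the header
// itself) where the binary payload starts.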
3361
/// Functions to dump a tensor in text format to a file using the raw tensor
/// data pointer \p tensor, the number of tensor elements \p tensorElemSize
/// and the file name \p filename. A text header \p header will also
3365/// be dumped.
3366#define DEFINE_DUMP_TENSOR_TXT_KERNEL(type, suffix) \
3367 __attribute__((noinline)) void libjit_dump_tensor_txt_##suffix( \
3368 uint8_t *tensor, size_t tensorElemSize, const char *filename, \
3369 const char *header) { \
3370 libjit_dump_tensor_txt_impl((type *)tensor, tensorElemSize, filename, \
3371 header); \
3372 }
3373DEFINE_DUMP_TENSOR_TXT_KERNEL(float, f)
3374DEFINE_DUMP_TENSOR_TXT_KERNEL(int8_t, i8)
3375DEFINE_DUMP_TENSOR_TXT_KERNEL(int16_t, i16)
3376DEFINE_DUMP_TENSOR_TXT_KERNEL(int32_t, i32)
3377DEFINE_DUMP_TENSOR_TXT_KERNEL(int64_t, u)
3378DEFINE_DUMP_TENSOR_TXT_KERNEL(bool, b)
3379#undef DEFINE_DUMP_TENSOR_TXT_KERNEL
3380
3381void libjit_write_timestamp(uint64_t *tensor, dim_t offset) {
  // We use the C++ steady clock here to avoid issues with gettimeofday.
  // Issue #2397 covers migrating this to a libc approach, but if you hit
  // issues with missing C++ symbols at runtime, check there first.
3385 uint64_t ts = std::chrono::duration_cast<std::chrono::microseconds>(
3386 std::chrono::steady_clock::now().time_since_epoch())
3387 .count();
3388 memcpy(tensor + offset, &ts, sizeof(uint64_t));
3389}
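
// Note: std::chrono::steady_clock has an unspecified epoch (it is only
// guaranteed to be monotonic), so the timestamps written above are
// meaningful relative to one another, not as absolute wall-clock times.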
3390
/// Copy kernels with element type conversion.
3392void libjit_convertTo_f_b(float *dstPtr, const bool *srcPtr, const dim_t *dims,
3393 dim_t numDims) {
3394 libjit_copy_kernel_with_conversion<float, bool>(dstPtr, srcPtr, dims,
3395 numDims);
3396}
3397
3398void libjit_convertTo_b_f(bool *dstPtr, const float *srcPtr, const dim_t *dims,
3399 dim_t numDims) {
3400 libjit_copy_kernel_with_conversion<bool, float>(dstPtr, srcPtr, dims,
3401 numDims);
3402}
3403
3404void libjit_convertTo_f_i32(float *dstPtr, const int32_t *srcPtr,
3405 const dim_t *dims, dim_t numDims) {
3406 libjit_copy_kernel_with_conversion<float, int32_t>(dstPtr, srcPtr, dims,
3407 numDims);
3408}
3409
3410void libjit_convertTo_i32_u(int32_t *dstPtr, const int64_t *srcPtr,
3411 const dim_t *dims, dim_t numDims) {
3412 libjit_copy_kernel_with_conversion<int32_t, int64_t>(dstPtr, srcPtr, dims,
3413 numDims);
3414}
3415
3416void libjit_convertTo_u_i32(int64_t *dstPtr, const int32_t *srcPtr,
3417 const dim_t *dims, dim_t numDims) {
3418 libjit_copy_kernel_with_conversion<int64_t, int32_t>(dstPtr, srcPtr, dims,
3419 numDims);
3420}
3421
3422void libjit_convertTo_i32_b(int32_t *dstPtr, const bool *srcPtr,
3423 const dim_t *dims, dim_t numDims) {
3424 libjit_copy_kernel_with_conversion<int32_t, bool>(dstPtr, srcPtr, dims,
3425 numDims);
3426}
3427
3428void libjit_convertTo_i32_f(int32_t *dstPtr, const float *srcPtr,
3429 const dim_t *dims, dim_t numDims) {
3430 libjit_copy_kernel_with_conversion<int32_t, float>(dstPtr, srcPtr, dims,
3431 numDims);
3432}
3433
3434/// Update min/max values \p compInfo and histogram \p existingHistogram with
3435/// data collected from tensor \p inputTensor.
3436/// Note: code ported from Profile.cpp: generateTensorHistogram
3437__attribute__((noinline)) void
3438libjit_quantization_profile(float *inputTensor, dim_t tensorSize,
3439 float *compInfo, float *existingHistogram,
3440 dim_t *histDim) {
3441 dim_t nBins = histDim[0];
3442
3443 // Min/max computed from previous runs. If this is the first run, compInfo is
  // expected to be initialized as follows:
3445 // compInfo[0]: std::numeric_limits<float>::max()
3446 // compInfo[1]: std::numeric_limits<float>::lowest()
3447 float min = compInfo[0];
3448 float max = compInfo[1];
3449
3450 // Min/max value for entire current input tensor.
3451 float minInput;
3452 float maxInput;
3453 find_min_max_f(inputTensor, tensorSize, minInput, maxInput);
3454
3455 // Update the global min/max.
3456 float newMin = MIN(minInput, min);
3457 float newMax = MAX(maxInput, max);
3458 compInfo[0] = newMin;
3459 compInfo[1] = newMax;
3460
  // If the input histogram is empty then return.
3462 if (nBins == 0) {
3463 return;
3464 }
3465
3466 // Initial profile.
3467 if (check_all_zeros(existingHistogram, nBins) == 1) {
3468 min = minInput;
3469 max = maxInput;
3470 }
3471
  // If the min/max range changed, the histogram must be rescaled.
3473 if (newMin < min || newMax > max) {
3474 float destBinWidth = (newMax - newMin) / nBins;
3475 float srcBinWidth = (max - min) / nBins;
3476 float scaledHistogram[nBins];
3477 for (dim_t i = 0; i < nBins; ++i) {
3478 scaledHistogram[i] = 0.0f;
3479 }
3480
3481 for (dim_t i = 0; i < nBins; ++i) {
3482 if (existingHistogram[i] == 0)
3483 continue;
3484
3485 float srcBinBegin = min + srcBinWidth * i;
3486 dim_t destBin = (srcBinBegin - newMin) / destBinWidth;
3487 float destBinEnd = newMin + destBinWidth * (destBin + 1);
3488
3489 float srcBinEnd = srcBinBegin + srcBinWidth;
3490 dim_t destBinToVerify = (srcBinEnd - newMin) / destBinWidth;
      // Make sure each source bin is mapped to at most 2 destination bins;
      // the redistribution percentage below relies on this invariant.
3493 assert(destBinToVerify <= destBin + 2);
3494 (void)destBinToVerify;
3495
3496 // Calculate how much we need to redistribute.
3497 uint64_t dstBinCnt = static_cast<uint64_t>(
3498 MIN(static_cast<float>(round((destBinEnd - srcBinBegin) /
3499 srcBinWidth * existingHistogram[i])),
3500 existingHistogram[i]));
3501
3502 dim_t newBin = get_bin(nBins, destBinWidth, newMin, srcBinBegin);
3503 scaledHistogram[newBin] += dstBinCnt;
3504
3505 if (dstBinCnt < existingHistogram[i]) {
3506 dim_t newBin =
3507 get_bin(nBins, destBinWidth, newMin, srcBinBegin + destBinWidth);
3508 scaledHistogram[newBin] += existingHistogram[i] - dstBinCnt;
3509 }
3510 }
3511
3512 // Copy scaled histogram back to the existing histogram.
3513 for (dim_t i = 0, e = nBins; i < e; ++i) {
3514 existingHistogram[i] = scaledHistogram[i];
3515 }
3516
3517 // Update global min and max.
3518 min = newMin;
3519 max = newMax;
3520 }
3521
3522 // Update the histogram with the values of the current input tensor.
3523 float binWidth = (max - min) / nBins;
3524 for (dim_t i = 0, e = tensorSize; i < e; ++i) {
3525 dim_t newBin = get_bin(nBins, binWidth, min, inputTensor[i]);
3526 existingHistogram[newBin]++;
3527 }
3528}
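
// Worked example of the rescaling above (illustrative numbers): suppose
// nBins = 4, the previous range was [0, 4) (srcBinWidth = 1) and the new
// range is [0, 8) (destBinWidth = 2). The old bins [0,1), [1,2), [2,3) and
// [3,4) all fall into the lower half of the new range, so their counts land
// in the new bins [0,2) and [2,4) (split proportionally via dstBinCnt when a
// source bin straddles a boundary), while the upper two new bins start empty
// until fresh samples land there.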
3529
3530__attribute__((noinline)) void
3531libjit_nms_u(uint64_t *indices, uint64_t *numDetected, const float *boxTensor,
3532 const dim_t *boxTensorDims, dim_t boxTensorDimSize,
3533 const float *scoresTensor, const dim_t *scoresTensorDims,
3534 dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
3535 dim_t resultTensorDimSize, unsigned centerPointBox,
3536 unsigned maxOutputBoxesPerClass, float iouThreshold,
3537 float scoreThreshold, bool isV4) {
3538 libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
3539 boxTensorDimSize, scoresTensor, scoresTensorDims,
3540 scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
3541 centerPointBox, maxOutputBoxesPerClass, iouThreshold,
3542 scoreThreshold, isV4);
3543}
3544
3545__attribute__((noinline)) void
3546libjit_nms_i32(int32_t *indices, int32_t *numDetected, const float *boxTensor,
3547 const dim_t *boxTensorDims, dim_t boxTensorDimSize,
3548 const float *scoresTensor, const dim_t *scoresTensorDims,
3549 dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
3550 dim_t resultTensorDimSize, unsigned centerPointBox,
3551 unsigned maxOutputBoxesPerClass, float iouThreshold,
3552 float scoreThreshold, bool isV4) {
3553 libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
3554 boxTensorDimSize, scoresTensor, scoresTensorDims,
3555 scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
3556 centerPointBox, maxOutputBoxesPerClass, iouThreshold,
3557 scoreThreshold, isV4);
3558}
3559
3560/// FFT Radix2 DIT (Decimation In Time) implementation for Complex data.
3561/// The \p input and \p output buffers have 2 * \p fftLength float
3562/// samples corresponding to \p fftLength complex samples with real and
3563/// imaginary parts interleaved: real[0], imag[0], real[1], imag[1], ...
3564/// The lookup tables \p twiddleFactors and \p bitReverseIndices are
3565/// generated at compile time. The boolean flag \p inPlace decides whether
/// the FFT computation is done in-place (that is, in the \p input buffer
/// without writing to the \p output buffer) or out-of-place (written to
/// the \p output buffer).
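/// For example, with \p fftLength = 8 the usual bit-reversal permutation is
/// {0, 4, 2, 6, 1, 5, 3, 7}; note the in-place path below swaps each pair
/// only once by requiring idx < bitRevIdx.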
3569void libjit_fft_complex_f(float *output, float *input,
3570 const float *twiddleFactors,
3571 const int32_t *bitReverseIndices, unsigned fftLength,
3572 bool inPlace) {
3573
3574 // Bit Reverse Reordering.
3575 if (inPlace) {
3576 for (dim_t idx = 0; idx < fftLength; idx++) {
3577 int32_t bitRevIdx = bitReverseIndices[idx];
3578 if (int32_t(idx) < bitRevIdx) {
3579 // Swap complex pair.
3580 float real = input[2 * idx + 0];
3581 float imag = input[2 * idx + 1];
3582 input[2 * idx + 0] = input[2 * bitRevIdx + 0];
3583 input[2 * idx + 1] = input[2 * bitRevIdx + 1];
3584 input[2 * bitRevIdx + 0] = real;
3585 input[2 * bitRevIdx + 1] = imag;
3586 }
3587 }
3588 } else {
3589 for (dim_t idx = 0; idx < fftLength; idx++) {
3590 int32_t bitRevIdx = bitReverseIndices[idx];
3591 output[2 * idx + 0] = input[2 * bitRevIdx + 0];
3592 output[2 * idx + 1] = input[2 * bitRevIdx + 1];
3593 }
3594 }
3595
3596 // FFT output pointer.
3597 float *bitRevOut = inPlace ? input : output;
3598
3599 // Number of FFT stages.
3600 dim_t stageNum = std::log2((double)fftLength);
3601
3602 // Number of radix2 butterfly groups for 1st stage.
3603 dim_t groupNum = fftLength / 2;
3604
3605 // Number of radix2 butterflies per group for 1st stage.
3606 dim_t groupButterNum = 1;
3607
3608 // Stage loop.
3609 for (dim_t stageIdx = 0; stageIdx < stageNum; stageIdx++) {
3610
3611 // Butterfly input/output pointers.
3612 float *inp1Ptr = bitRevOut + 0 * groupButterNum;
3613 float *inp2Ptr = bitRevOut + 2 * groupButterNum;
3614
3615 // Butterfly group loop.
3616 for (dim_t groupIdx = 0; groupIdx < groupNum; groupIdx++) {
3617
3618 // Twiddle factors pointer.
3619 const float *twPtr = twiddleFactors;
3620
3621 // Butterfly loop within group.
3622 for (dim_t groupButterIdx = 0; groupButterIdx < groupButterNum;
3623 groupButterIdx++) {
3624
3625 // Radix 2 butterfly.
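        // Read pattern: each pointer is post-incremented to read the real
        // part and post-decremented while reading the imaginary part, so it
        // ends up back at the start of its complex pair (ready for the
        // write-back, or for the twiddle stride below).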
3626 float inp0_re = *inp1Ptr++;
3627 float inp0_im = *inp1Ptr--;
3628 float inp1_re = *inp2Ptr++;
3629 float inp1_im = *inp2Ptr--;
3630
3631 float tw_re = *twPtr++;
3632 float tw_im = *twPtr--;
3633 twPtr += (2 * groupNum);
3634
3635 float inp1_tw_mult_re = inp1_re * tw_re - inp1_im * tw_im;
3636 float inp1_tw_mult_im = inp1_re * tw_im + inp1_im * tw_re;
3637
3638 *inp1Ptr++ = inp0_re + inp1_tw_mult_re;
3639 *inp1Ptr++ = inp0_im + inp1_tw_mult_im;
3640 *inp2Ptr++ = inp0_re - inp1_tw_mult_re;
3641 *inp2Ptr++ = inp0_im - inp1_tw_mult_im;
3642 }
3643
3644 inp1Ptr += 2 * groupButterNum;
3645 inp2Ptr += 2 * groupButterNum;
3646 }
3647
3648 // Update parameters for next stage.
3649 groupNum >>= 1;
3650 groupButterNum <<= 1;
3651 }
3652}
3653
3654/// FFT Radix2 DIT (Decimation In Time) implementation for Real data.
/// The implementation uses an fftLength/2 FFT for Complex data followed
/// by a step that maps the complex FFT onto the real FFT using a set of
/// complex weights \p complexToRealWeights A[k] defined as:
3658/// A[k] = 1/2 * (1 - j*exp(-j*2*pi*k/N)) for k = 0 .. N/4-1
3659/// The \p input buffer has \p fftLength float values corresponding
3660/// to \p fftLength real samples. Since the FFT of a real signal
3661/// has conjugate symmetry, the \p output buffer only contains
3662/// 2 * (fftLength/2+1) = fftLength + 2 float values corresponding
3663/// to fftLength/2+1 complex samples with real and imaginary parts
3664/// interleaved: real[0], imag[0], real[1], imag[1], ...
3665/// The lookup tables \p twiddleFactors and \p bitReverseIndices are
/// generated at compile time as if they were generated for an N/2
3667/// complex FFT. The boolean flag \p inPlace decides whether the FFT
/// computation is done in-place (that is, in the \p input buffer
/// without writing to the \p output buffer) or out-of-place (written to
/// the \p output buffer).
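/// Note that the mapping below starts reading \p complexToRealWeights at
/// offset 2 (that is, at A[1]): the k = 0 term is covered by the
/// closed-form expressions for X[0] and X[N/2], so A[0] is skipped.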
3671void libjit_fft_real_f(float *output, float *input, const float *twiddleFactors,
3672 const int32_t *bitReverseIndices,
3673 const float *complexToRealWeights, unsigned fftLength,
3674 bool inPlace) {
3675
3676 // Perform N/2 complex FFT (in-place or out-of-place).
3677 // G[k] with k = 0 .. N/2-1.
3678 libjit_fft_complex_f(output, input, twiddleFactors, bitReverseIndices,
3679 fftLength / 2, inPlace);
3680
3681 // Complex to Real FFT mapping (in-place).
3682 // X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k])
3683 // for k = 0 .. N/2 with the convention G[N/2] = G[0].
3684 // Particular cases:
3685 // real(X[0]) = real(G[0]) + imag(G[0])
3686 // imag(X[0]) = 0
3687 // real(X[N/2]) = real(G[0]) - imag(G[0])
3688 // imag(X[N/2]) = 0
3689 // X[N/4] = conj(G[N/4])
3690
3691 const float *Ak = complexToRealWeights + 2;
3692 float *ptr = inPlace ? input : output;
3693 float *ptr0 = &ptr[0];
  float *ptr1 = &ptr[2 * (fftLength / 2) + 1];
3695 float inp0_re = *ptr0++;
3696 float inp0_im = *ptr0--;
3697 *ptr0++ = inp0_re + inp0_im;
3698 *ptr0++ = 0;
3699 *ptr1-- = 0;
3700 *ptr1-- = inp0_re - inp0_im;
3701
3702 for (dim_t k = 1; k < fftLength / 4; k++) {
3703
3704 float inp0_re = *ptr0++;
3705 float inp0_im = *ptr0--;
3706 float inp1_im = *ptr1--;
3707 float inp1_re = *ptr1++;
3708
3709 float Ak_re = *Ak++;
3710 float Ak_im = *Ak++;
3711
3712 float dif_re = inp0_re - inp1_re;
3713 float sum_im = inp0_im + inp1_im;
3714 float prod0 = dif_re * Ak_re - sum_im * Ak_im;
3715 float prod1 = dif_re * Ak_im + sum_im * Ak_re;
3716
3717 *ptr0++ = +prod0 + inp1_re;
3718 *ptr0++ = +prod1 - inp1_im;
3719 *ptr1-- = +prod1 - inp0_im;
3720 *ptr1-- = -prod0 + inp0_re;
3721 }
3722
3723 if (fftLength >= 4) {
3724 *ptr1 = -*ptr1;
3725 }
3726}
3727
3728/// Compute the spectrogram for the given 1D mono audio signal \p input.
3729/// The input windows are weighted using the \p window function and the
3730/// FFT LUTs \p twiddleFactors and \p bitReverseIndices are computed at
3731/// compile-time. More details in Graph.h about the AudioSpectrogram node.
3732void libjit_audio_spectrogram_f(
3733 void *winOutScratch, void *fftOutScratch, float *spectrogram,
3734 const float *input, const float *window, const float *twiddleFactors,
3735 const int32_t *bitReverseIndices, const float *complexToRealWeights,
3736 const dim_t *spectrogramDims, const dim_t inputLength,
3737 const dim_t windowSize, const dim_t windowStride,
3738 const bool magnitudeSquared) {
3739
3740 dim_t winNum = spectrogramDims[0];
3741 dim_t specLen = spectrogramDims[1];
3742 dim_t fftLen = (specLen - 1) * 2;
3743
3744 // Scratch buffers.
3745 float *winOut = (float *)winOutScratch;
3746 float *fftOut = (float *)fftOutScratch;
3747 memset(winOut, 0, fftLen * sizeof(float));
3748
3749 // Compute the spectrogram.
3750 for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3751
3752 // Windowing.
3753 for (dim_t n = 0; n < windowSize; n++) {
3754 winOut[n] = input[winIdx * windowStride + n] * window[n];
3755 }
3756
3757 // Compute spectrum (perform FFT for real data).
3758 libjit_fft_real_f(fftOut, winOut, twiddleFactors, bitReverseIndices,
3759 complexToRealWeights, fftLen, false /* inPlace */);
3760
3761 // Compute spectrum magnitude/power.
3762 for (dim_t k = 0; k < specLen; k++) {
3763 float real = fftOut[2 * k + 0];
3764 float imag = fftOut[2 * k + 1];
3765 float power = real * real + imag * imag;
3766 if (magnitudeSquared) {
3767 *spectrogram++ = power;
3768 } else {
3769 *spectrogram++ = std::sqrt(power);
3770 }
3771 }
3772 }
3773}
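
// Dimension bookkeeping for the spectrogram above (illustrative numbers):
// since specLen = fftLen / 2 + 1, a spectrogram width of 257 implies
// fftLen = 512, and a window of, say, 480 samples is zero-padded to 512 by
// the memset before each FFT. winNum comes from the compiled dims and is
// expected to equal floor((inputLength - windowSize) / windowStride) + 1.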
3774
3775/// Compute the MFCC (Mel Frequency Cepstral Coefficient) for the given
3776/// \p spectrogram power. The lookup tables \p melWeights, \p melRanges
3777/// and \p dctMat are computed at compile-time. More details in Graph.h
3778/// about the MFCC node.
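/// As consumed below, \p melRanges is expected to hold 2 * filterBankCount
/// entries (an inclusive [start, stop] spectrogram-bin pair per mel filter),
/// \p melWeights to concatenate the per-filter weights for those ranges,
/// and \p dctMat to be a row-major numCoefficients x filterBankCount matrix.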
3779void libjit_mfcc_f(void *scratch, float *coefficients, const float *spectrogram,
3780 const float *melWeights, const int32_t *melRanges,
3781 const float *dctMat, const dim_t *coefficientsDims,
3782 const dim_t *spectrogramDims, const dim_t filterBankCount) {
3783
3784 // Scratch buffer.
3785 float *melBuff = (float *)scratch;
3786
3787 // Perform MFCC for all the windows.
3788 dim_t winNum = spectrogramDims[0];
3789 dim_t winSize = spectrogramDims[1];
3790 dim_t numCoefficients = coefficientsDims[1];
3791 for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3792
3793 // Pointers backup for this window.
3794 const float *melWeightsPtr = melWeights;
3795 const int32_t *melRangesPtr = melRanges;
3796 const float *dctMatPtr = dctMat;
3797
3798 // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
3799 // assume the spectrogram is a power value and not a magnitude.
3800 for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
3801
3802 int32_t freqIdxStart = *melRangesPtr++;
3803 int32_t freqIdxStop = *melRangesPtr++;
3804
3805 // Compute Mel Power.
3806 float melPwr = 0.0f;
3807 for (int32_t freqIdx = freqIdxStart; freqIdx <= freqIdxStop; freqIdx++) {
3808 melPwr += std::sqrt(spectrogram[freqIdx]) * (*melWeightsPtr++);
3809 }
3810
3811 // Take logarithm in-place (avoid log(0)).
3812 melBuff[melIdx] = (melPwr == 0.0)
3813 ? logf(std::numeric_limits<float>::min())
3814 : logf(melPwr);
3815 }
3816
3817 // Compute DCT transform.
3818 for (dim_t k = 0; k < numCoefficients; k++) {
3819 float dctOut = 0.0f;
3820 for (dim_t n = 0; n < filterBankCount; n++) {
3821 dctOut += (*dctMatPtr++) * melBuff[n];
3822 }
3823 *coefficients++ = dctOut;
3824 }
3825
3826 // Go to next spectrogram window.
3827 spectrogram += winSize;
3828 }
3829}
3830
3831//===----------------------------------------------------------------------===//
3832// TFLiteDetectionPostProcess
3833//===----------------------------------------------------------------------===//
3834static int32_t partition(int32_t *arr, int32_t low, int32_t high,
3835 float *values) {
3836 float pivot = values[high];
3837 int32_t i = (low - 1);
3838 float swap_float;
3839 int32_t swap_int;
3840
3841 for (int32_t j = low; j <= high - 1; j++) {
3842 if (values[j] > pivot) {
3843 i++;
3844
3845 swap_float = values[i];
3846 values[i] = values[j];
3847 values[j] = swap_float;
3848
3849 swap_int = arr[i];
3850 arr[i] = arr[j];
3851 arr[j] = swap_int;
3852 }
3853 }
3854
3855 swap_float = values[i + 1];
3856 values[i + 1] = values[high];
3857 values[high] = swap_float;
3858
3859 swap_int = arr[i + 1];
3860 arr[i + 1] = arr[high];
3861 arr[high] = swap_int;
3862
3863 return (i + 1);
3864}
3865
3866static void partial_sort(int32_t *arr, int32_t i, int32_t j, int32_t k,
3867 float *values) {
3868 int32_t p;
3869 if (i < j) {
3870 p = partition(arr, i, j, values);
3871
3872 partial_sort(arr, i, p - 1, k, values);
3873
3874 if (p < k - 1)
3875 partial_sort(arr, p + 1, j, k, values);
3876 }
3877}
3878
3879static void iota(int32_t *first, int32_t *last, int32_t value) {
3880 while (first != last) {
3881 *first++ = value;
3882 value++;
3883 }
3884}
3885
3886static void decreasing_partial_arg_sort(float *values, int32_t num_values,
3887 int32_t num_to_sort, int32_t *indices,
3888 float *aux_values) {
3889 iota(indices, indices + num_values, 0);
3890
3891 memcpy(aux_values, values, sizeof(float) * num_values);
3892
3893 partial_sort(indices, 0, num_values - 1, num_to_sort, aux_values);
3894}
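
// Taken together, the helpers above implement a descending partial argsort:
// after the call, indices[0 .. num_to_sort-1] hold the positions of the
// largest values in decreasing order, using a quickselect-style partition
// that recurses to the right of a pivot only while the pivot lands before
// position num_to_sort - 1. For example, values = {0.1f, 0.9f, 0.5f} with
// num_to_sort = 2 leaves indices starting with {1, 2}.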
3895
3896static void select_detection_above_score_threshold(
3897 float *scores, int32_t num_scores, float threshold, float *keep_values,
3898 int32_t *keep_indices, int32_t *num_indices) {
3899 int32_t idx = 0;
3900 for (int32_t i = 0; i < num_scores; i++) {
3901 if (scores[i] >= threshold) {
3902 keep_indices[idx] = i;
3903 keep_values[idx] = scores[i];
3904 idx++;
3905 }
3906 }
3907 *num_indices = idx;
3908}
3909
3910/// Compute the IOU (Intersection Over Union) metric between two boxes. Each
3911/// of box1 and box2 is a vector with 4 floating-point values with the box
3912/// coordinates in the following format: [ymin, xmin, ymax, xmax].
3913static float tflite_compute_iou(float *box1, float *box2) {
3914
3915 // Compute the areas of the two boxes.
3916 float box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1]);
3917 float box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1]);
3918
  // If the box coordinates are invalid, return 0.
3920 if (box1Area <= 0 || box2Area <= 0) {
3921 return 0.0f;
3922 }
3923
3924 // Determine the coordinates of the intersection rectangle.
3925 float iYmin = MAX(box1[0], box2[0]);
3926 float iXmin = MAX(box1[1], box2[1]);
3927 float iYmax = MIN(box1[2], box2[2]);
3928 float iXmax = MIN(box1[3], box2[3]);
3929
3930 // Compute the area of the intersection rectangle.
3931 float iArea = MAX(0.0f, iXmax - iXmin) * MAX(0.0f, iYmax - iYmin);
3932
  // Compute the area of the union of the two boxes.
3934 float uArea = box1Area + box2Area - iArea;
3935
3936 // Compute the Intersection Over Union metric.
3937 return iArea / uArea;
3938}
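
// Worked example: box1 = [0, 0, 2, 2] (area 4) and box2 = [1, 1, 3, 3]
// (area 4) intersect in [1, 1, 2, 2] (area 1), giving an IOU of
// 1 / (4 + 4 - 1) = 1/7 ~ 0.14.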
3939
3940static void tflite_helper(float *boxesPtr, int32_t num_boxes,
                          float nms_score_threshold, float nms_iou_threshold,
3942 float *class_scores, int32_t num_scores,
3943 int32_t *selected, int32_t *num_selected,
3944 int32_t max_detections, int32_t *keep_indices,
3945 float *keep_scores, int32_t *sorted_indices_helper) {
3946
3947 *num_selected = 0;
3948
3949 int32_t num_scores_kept;
3950 select_detection_above_score_threshold(class_scores, num_boxes,
3951 nms_score_threshold, keep_scores,
3952 keep_indices, &num_scores_kept);
3953
3954 decreasing_partial_arg_sort(keep_scores, num_scores_kept, num_scores_kept,
3955 sorted_indices_helper, (float *)selected);
3956
3957 int32_t num_boxes_kept = num_scores_kept;
3958 int32_t output_size = MIN(num_boxes_kept, max_detections);
3959
3960 int32_t num_active_candidate = num_boxes_kept;
3961
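  // keep_scores is no longer needed once the argsort above has copied it
  // into its auxiliary buffer, so its storage is reused here as a byte map
  // of still-active candidates.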
3962 uint8_t *active_box_candidate = (uint8_t *)keep_scores;
3963
3964 for (int32_t row = 0; row < num_boxes_kept; row++) {
3965 active_box_candidate[row] = 1;
3966 }
3967
3968 for (int32_t i = 0; i < num_boxes_kept; i++) {
3969 if (num_active_candidate == 0 || *num_selected >= output_size)
3970 break;
3971 if (active_box_candidate[i] == 1) {
3972 selected[*num_selected] = keep_indices[sorted_indices_helper[i]];
3973 (*num_selected)++;
3974 active_box_candidate[i] = 0;
3975 num_active_candidate--;
3976 } else {
3977 continue;
3978 }
3979
3980 for (int32_t j = i + 1; j < num_boxes_kept; ++j) {
3981 if (active_box_candidate[j] == 1) {
3982
3983 float *box1 = boxesPtr + 4 * keep_indices[sorted_indices_helper[i]];
3984 float *box2 = boxesPtr + 4 * keep_indices[sorted_indices_helper[j]];
3985 float iou = tflite_compute_iou(box1, box2);
3986
        if (iou > nms_iou_threshold) {
3988 active_box_candidate[j] = 0;
3989 num_active_candidate--;
3990 }
3991 }
3992 }
3993 }
3994}
3995
3996void libjit_tflite_detection_post_process_f(
3997 float *boxes, float *scores, float *anchors, float *detectionBoxes,
3998 int32_t *detectionClasses, float *detectionScores, int32_t *numDetections,
3999 int8_t *scratch, int32_t numBoxes, int32_t numTotalClasses,
4000 int32_t numClasses, int32_t maxDetections, int32_t maxClassesPerDetection,
4001 int32_t maxDetectionsPerClass, float iouThreshold, float scoreThreshold,
4002 float xScaleInv, float yScaleInv, float hScaleInv, float wScaleInv,
4003 bool regularNMS) {
4004
4005 // Decode the box coordinates in-place using the anchors.
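  // Anchors are stored as [ycenter, xcenter, h, w] and the raw boxes as
  // [dy, dx, dh, dw] deltas, consistent with TFLite's CenterSize box coding;
  // the corners written back are [ymin, xmin, ymax, xmax].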
4006 for (int32_t i = 0; i < numBoxes; i++) {
4007
4008 float *box = &boxes[i * 4];
4009 float *anchor = &anchors[i * 4];
4010
4011 float ycenter = box[0] * yScaleInv * anchor[2] + anchor[0];
4012 float xcenter = box[1] * xScaleInv * anchor[3] + anchor[1];
4013
4014 float half_h = 0.5f * expf(box[2] * hScaleInv) * anchor[2];
4015 float half_w = 0.5f * expf(box[3] * wScaleInv) * anchor[3];
4016
4017 box[0] = ycenter - half_h;
4018 box[1] = xcenter - half_w;
4019 box[2] = ycenter + half_h;
4020 box[3] = xcenter + half_w;
4021 }
4022
4023 int32_t max_categories_per_anchor = maxClassesPerDetection;
4024 int32_t num_categories_per_anchor =
4025 MIN(max_categories_per_anchor, numClasses);
4026 int32_t label_offset = numTotalClasses - numClasses;
4027
4028 if (regularNMS) {
4029 int32_t num_detections_per_class = maxDetectionsPerClass;
4030
4031 float *class_scores = (float *)(scratch);
4032 scratch += numBoxes * sizeof(float);
4033
4034 int32_t *box_indices_after_regular_nms = (int32_t *)(scratch);
4035 scratch += (numBoxes + maxDetections) * sizeof(int32_t);
4036
4037 float *scores_after_regular_nms = (float *)(scratch);
4038 scratch += (numBoxes + maxDetections) * sizeof(float);
4039
4040 int32_t size_of_sorted_indices = 0;
4041
4042 int32_t *sorted_indices = (int32_t *)(scratch);
4043 scratch += (numBoxes + maxDetections) * sizeof(int32_t);
4044
4045 float *sorted_values = (float *)(scratch);
4046 scratch += MIN(numBoxes, maxDetectionsPerClass) * sizeof(float);
4047
4048 int32_t *selected = (int32_t *)scratch;
4049 scratch += numBoxes * sizeof(int32_t);
4050
4051 int32_t *keep_indices = (int32_t *)(scratch);
4052 scratch += numBoxes * sizeof(int32_t);
4053
4054 float *keep_scores = (float *)(scratch);
4055 scratch += numBoxes * sizeof(float);
4056
4057 int32_t *sorted_indices_helper = (int32_t *)scratch;
4058 scratch += numBoxes * sizeof(int32_t);
4059
4060 for (int32_t col = 0; col < numClasses; col++) {
4061 for (int32_t row = 0; row < numBoxes; row++) {
4062 class_scores[row] =
4063 *(scores + row * numTotalClasses + col + label_offset);
4064 }
4065
4066 int32_t num_selected;
4067 tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, class_scores,
4068 numBoxes, selected, &num_selected, num_detections_per_class,
4069 keep_indices, keep_scores, sorted_indices_helper);
4070
4071 int32_t output_index = size_of_sorted_indices;
4072 for (int32_t i = 0; i < num_selected; i++) {
4073 int32_t selected_index = selected[i];
4074 box_indices_after_regular_nms[output_index] =
4075 (selected_index * numTotalClasses + col + label_offset);
4076 scores_after_regular_nms[output_index] = class_scores[selected_index];
4077 output_index++;
4078 }
4079
4080 int32_t num_indices_to_sort = MIN(output_index, maxDetections);
4081
4082 decreasing_partial_arg_sort(scores_after_regular_nms, output_index,
4083 num_indices_to_sort, sorted_indices,
4084 keep_scores);
4085
4086 for (int32_t row = 0; row < num_indices_to_sort; row++) {
4087 int32_t temp = sorted_indices[row];
4088 sorted_indices[row] = box_indices_after_regular_nms[temp];
4089 sorted_values[row] = scores_after_regular_nms[temp];
4090 }
4091
4092 for (int32_t row = 0; row < num_indices_to_sort; row++) {
4093 box_indices_after_regular_nms[row] = sorted_indices[row];
4094 scores_after_regular_nms[row] = sorted_values[row];
4095 }
4096
4097 size_of_sorted_indices = num_indices_to_sort;
4098 }
4099
4100 for (int32_t output_box_index = 0;
4101 output_box_index < size_of_sorted_indices; output_box_index++) {
4102
4103 int32_t anchor_index =
4104 box_indices_after_regular_nms[output_box_index] / numTotalClasses;
4105 int32_t class_index = box_indices_after_regular_nms[output_box_index] -
4106 anchor_index * numTotalClasses - label_offset;
4107 float selected_score = scores_after_regular_nms[output_box_index];
4108 float *box = boxes + anchor_index * 4;
4109
4110 *detectionBoxes++ = *box++;
4111 *detectionBoxes++ = *box++;
4112 *detectionBoxes++ = *box++;
4113 *detectionBoxes++ = *box++;
4114 *detectionClasses++ = class_index;
4115 *detectionScores++ = selected_score;
4116 }
4117
4118 *numDetections = size_of_sorted_indices;
4119 } else {
4120 float *max_scores = (float *)scratch;
4121 scratch += numBoxes * sizeof(float);
4122
4123 int32_t *sorted_classes_indices = (int32_t *)scratch;
4124 scratch += numBoxes * MIN(maxDetections, numClasses) * sizeof(int32_t);
4125
4126 int32_t *selected = (int32_t *)scratch;
4127 scratch += numBoxes * sizeof(int32_t);
4128
4129 int32_t *keep_indices = (int32_t *)(scratch);
4130 scratch += numBoxes * sizeof(int32_t);
4131
4132 float *keep_scores = (float *)(scratch);
4133 scratch += numBoxes * sizeof(float);
4134
4135 int32_t *sorted_indices_helper = (int32_t *)scratch;
4136 scratch += numBoxes * sizeof(int32_t);
4137
4138 for (int32_t row = 0; row < numBoxes; row++) {
4139 float *box_scores = scores + row * numTotalClasses + label_offset;
4140 int32_t *class_indices =
4141 sorted_classes_indices + row * num_categories_per_anchor;
4142
4143 decreasing_partial_arg_sort(box_scores, numClasses,
4144 num_categories_per_anchor, keep_indices,
4145 keep_scores);
4146
4147 for (int32_t i = 0; i < num_categories_per_anchor; i++) {
4148 class_indices[i] = keep_indices[i];
4149 }
4150
4151 max_scores[row] = box_scores[class_indices[0]];
4152 }
4153
4154 int32_t selected_size = 0;
4155 tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, max_scores,
4156 numBoxes, selected, &selected_size, maxDetections,
4157 keep_indices, keep_scores, sorted_indices_helper);
4158
4159 int32_t num_detections = 0;
4160 for (int32_t i = 0; i < selected_size; i++) {
4161
4162 int32_t selected_index = selected[i];
4163 float *box = boxes + selected_index * 4;
4164 float *box_scores =
4165 scores + selected_index * numTotalClasses + label_offset;
4166 int32_t *class_indices =
4167 sorted_classes_indices + selected_index * num_categories_per_anchor;
4168
4169 for (int32_t col = 0; (col < num_categories_per_anchor) &&
4170 (num_detections <= selected_size);
4171 ++col) {
4172 *detectionBoxes++ = box[0];
4173 *detectionBoxes++ = box[1];
4174 *detectionBoxes++ = box[2];
4175 *detectionBoxes++ = box[3];
4176 *detectionClasses++ = class_indices[col];
4177 *detectionScores++ = box_scores[class_indices[col]];
4178 num_detections++;
4179 }
4180 }
4181
4182 *numDetections = selected_size;
4183 }
4184}
4185
4186//===----------------------------------------------------------------------===//
4187// Instrumentation Callbacks
4188//===----------------------------------------------------------------------===//
4189#ifdef GLOW_LIBJIT_EXTERNAL_FUNCTIONS
4190/// Glow IR instrumentation external callbacks.
4191void glow_instrument_before(int id, int kind, int opInp, int opOut,
4192 uint8_t **opAddr, int *opSize);
4193void glow_instrument_after(int id, int kind, int opInp, int opOut,
4194 uint8_t **opAddr, int *opSize);
4195
4196__attribute__((noinline)) void libjit_instrument_before(int id, int kind,
4197 int opInp, int opOut,
4198 uint8_t **opAddr,
4199 int *opSize) {
4200 glow_instrument_before(id, kind, opInp, opOut, opAddr, opSize);
4201}
4202
4203__attribute__((noinline)) void libjit_instrument_after(int id, int kind,
4204 int opInp, int opOut,
4205 uint8_t **opAddr,
4206 int *opSize) {
4207 glow_instrument_after(id, kind, opInp, opOut, opAddr, opSize);
4208}
4209#endif
4210
4211} // extern "C"
4212