/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

#include "libjit_defs.h"

namespace {
// Initialize the convolution output frame for slice \p N with the bias \p
// biasW.
void libjit_conv_init_output_with_bias(dim_t N, float *outW, const float *biasW,
                                       const dim_t *outWdims,
                                       const dim_t *biasWdims) {
  // For each (x,y) step in the output tensor:
  for (dim_t ax = 0; ax < outWdims[1]; ax++) {
    for (dim_t ay = 0; ay < outWdims[2]; ay++) {
      // For each output channel:
      for (dim_t d = 0; d < outWdims[3]; d++) {
        // Store the results to the output buffer.
        float bias = biasW[d];
        auto outIdx = libjit_getXYZW(outWdims, N, ax, ay, d);
        outW[outIdx] = bias;
      } // For each depth in the output.
    } // For each Y in the output.
  } // For each X in the output.
}
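
// Note: all 4-D kernels in this file address NHWC tensors through
// libjit_getXYZW, which flattens coordinates into a row-major index. A
// minimal sketch of the index math, assuming the definition in
// libjit_defs.h:
//
//   // idx = ((x * dims[1] + y) * dims[2] + z) * dims[3] + w
//   dim_t dims[4] = {2, 4, 4, 8}; // e.g. N=2, H=4, W=4, C=8
//   assert(libjit_getXYZW(dims, 1, 2, 3, 5) ==
//          ((1 * 4 + 2) * 4 + 3) * 8 + 5); // == 221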

/// Generic template for quantized conv2d. The template allows choosing the
/// element type and bias type.
template <typename ElemTy, typename BiasElemTy>
void libjit_quantized_conv2d_generic(
    ElemTy *outW, const ElemTy *inW, const ElemTy *filterW,
    const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernelSizes,
    const dim_t *strides, const dim_t *pads, dim_t group, int32_t outOffset,
    int32_t inOffset, int32_t filterOffset, int32_t biasOffset, int32_t biasPre,
    int32_t biasPost, int32_t biasScale, int32_t outPre, int32_t outPost,
    int32_t outScale, unsigned depthUnroll, const dim_t *dilation,
    int32_t actType, const int32_t *actArgs) {
  dim_t inChannels = inWdims[3];
  dim_t outChannels = outWdims[3];
  dim_t inCperG = inChannels / group;
  dim_t outCperG = outChannels / group;
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernelSizes[0];
  dim_t kernel_w = kernelSizes[1];
  // For each input in the batch:
  for (size_t n = 0; n < inWdims[0]; n++) {
    // For each group of input channels:
    for (size_t g = 0; g < group; g++) {

      // For each output channel in the group. Process 'depthUnroll' output
      // channels together.
      for (size_t d = g * outCperG; d < (g + 1) * outCperG; d += depthUnroll) {
        // For each convolution 'jump' in the input tensor:
        ssize_t x = -(ssize_t)pad_t;
        for (size_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
          ssize_t y = -(ssize_t)pad_l;
          for (size_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
            int32_t sum[depthUnroll];

            for (unsigned i = 0; i < depthUnroll; i++) {
              // Scale the bias to match the scale of the matrix
              // multiplication.
              sum[i] = libjit_scale<int32_t>((int32_t)biasW[d + i] - biasOffset,
                                             biasPre, biasPost, biasScale, 0);
            }

            // For each element in the convolution-filter:
            for (size_t fx = 0; fx < kernel_h; fx++) {
              for (size_t fy = 0; fy < kernel_w; fy++) {
                ssize_t ox = x + fx * dilation[0];
                ssize_t oy = y + fy * dilation[1];

                // Ignore accesses outside the input tensor (due to padding).
                if (ox < 0 || oy < 0 || ox >= (ssize_t)inWdims[1] ||
                    oy >= (ssize_t)inWdims[2]) {
                  continue;
                }

                // Calculate the indices into the Filter and Input buffers.
                size_t inIdx = libjit_getXYZW(inWdims, n, (size_t)ox,
                                              (size_t)oy, g * inCperG);
                size_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0);
                size_t sliceSize =
                    filterWdims[1] * filterWdims[2] * filterWdims[3];

                // Perform the innermost loop of the convolution using 4 vector
                // registers.
                for (size_t fd = 0; fd < inCperG; fd++) {
                  int32_t in = inW[inIdx + fd] - inOffset;
                  for (unsigned i = 0; i < MIN(4, depthUnroll); i++) {
                    sum[i] += (filterW[filterIdx + (sliceSize * i) + fd] -
                               filterOffset) *
                              in;
                  }
                }

                // And perform the innermost loop again with 4 more registers.
                if (depthUnroll > 4) {
                  for (size_t fd = 0; fd < inCperG; fd++) {
                    int32_t in = inW[inIdx + fd] - inOffset;
                    for (unsigned i = 4; i < MIN(8, depthUnroll); i++) {
                      sum[i] += (filterW[filterIdx + (sliceSize * i) + fd] -
                                 filterOffset) *
                                in;
                    }
                  }
                }
              }
            }

            for (unsigned i = 0; i < depthUnroll; i++) {
              // Scale the result back to the expected destination scale.
              int32_t scaledSum = libjit_scale<int32_t>(sum[i], outPre, outPost,
                                                        outScale, outOffset);
              scaledSum =
                  libjit_activation_i32(scaledSum, outOffset, actType, actArgs);
              outW[libjit_getXYZW(outWdims, n, ax, ay, d + i)] =
                  libjit_clip_i8(scaledSum);
            }
          } // W
        } // H
      } // C
    } // G
  } // N
}
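
// A note on libjit_scale: the quantized kernels above and below rescale
// 32-bit accumulators with integer-only arithmetic. A hedged sketch of the
// fixed-point formula, assuming the round-to-nearest definition in
// libjit_defs.h:
//
//   // scaled = (((input >> pre) * scale + (1 << (post - 1))) >> post)
//   //          + offset
//   int32_t input = 6400, pre = 0, post = 8, scale = 64, offset = 0;
//   // (6400 * 64 + 128) >> 8 == 1600, i.e. the value is scaled by 64 / 256.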

/// Generic template for channelwise quantized conv2d. The template allows
/// choosing the element type and bias type.
template <typename ElemTy, typename BiasElemTy>
void libjit_channelwise_quantized_conv2d_generic(
    ElemTy *outW, const ElemTy *inW, const ElemTy *filterW,
    const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels,
    const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation,
    int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr,
    int32_t *biasOffsetsPtr, const int32_t *biasPrePtr,
    const int32_t *biasPostPtr, const int32_t *biasScalePtr,
    const int32_t *outPrePtr, const int32_t *outPostPtr,
    const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) {

  dim_t inChannels = inWdims[3];
  dim_t outChannels = outWdims[3];
  dim_t inCperG = inChannels / group;
  dim_t outCperG = outChannels / group;
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernels[0];
  dim_t kernel_w = kernels[1];

  // For each input in the batch:
  for (dim_t n = 0; n < inWdims[0]; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // Get the channelwise quantization parameters.
        int32_t filterOffset = filterOffsetsPtr[d];
        int32_t biasOffset = biasOffsetsPtr[d];
        int32_t biasPre = biasPrePtr[d];
        int32_t biasPost = biasPostPtr[d];
        int32_t biasScale = biasScalePtr[d];
        int32_t outPre = outPrePtr[d];
        int32_t outPost = outPostPtr[d];
        int32_t outScale = outScalePtr[d];

        // For each convolution 'jump' in the input tensor:
        sdim_t x = -(sdim_t)pad_t;
        for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
          sdim_t y = -(sdim_t)pad_l;
          for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {

            // Scale the bias to match the scale of the matrix multiplication.
            int32_t sum =
                libjit_scale<int32_t>((int32_t)biasW[d] - biasOffset, biasPre,
                                      biasPost, biasScale, 0);

            // For each element in the convolution-filter:
            for (dim_t fx = 0; fx < kernel_h; fx++) {
              for (dim_t fy = 0; fy < kernel_w; fy++) {
                sdim_t ox = x + fx * dilation[0];
                sdim_t oy = y + fy * dilation[1];

                // Ignore accesses outside the input tensor (due to padding).
                if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
                    oy >= (sdim_t)inWdims[2]) {
                  continue;
                }

                // Calculate the indices into the Filter and Input buffers.
                dim_t inIdx = libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy,
                                             g * inCperG);
                dim_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0);

                // Accumulate along the filter depth.
                for (dim_t fd = 0; fd < inCperG; fd++) {
                  sum += (filterW[filterIdx + fd] - filterOffset) *
                         (inW[inIdx + fd] - inOffset);
                }
              }
            }

            // Scale the result back to the expected destination scale.
            int32_t scaledSum = libjit_scale<int32_t>(sum, outPre, outPost,
                                                      outScale, outOffset);
            scaledSum =
                libjit_activation_i32(scaledSum, outOffset, actType, actArgs);
            outW[libjit_getXYZW(outWdims, n, ax, ay, d)] =
                libjit_clip_i8(scaledSum);
          } // W
        } // H
      } // C
    } // G
  } // N
}
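
// Worked example of the per-channel bias rescale above (hypothetical values,
// assuming the libjit_scale sketch given earlier): for a channel d with
// biasW[d] = 100, biasOffset = 0, biasPre = 0, biasPost = 8, biasScale = 64,
// the accumulator starts at (100 * 64 + 128) >> 8 == 25, i.e. the bias is
// re-expressed in the accumulator's scale before any products are added.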

/// Generic template for channelwise quantized conv3d. The template allows
/// choosing the element type and bias type.
template <typename ElemTy, typename BiasElemTy>
void libjit_channelwise_quantized_conv3d_generic(
    ElemTy *outW, const ElemTy *inW, const ElemTy *filterW,
    const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels,
    const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation,
    int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr,
    int32_t *biasOffsetsPtr, const int32_t *biasPrePtr,
    const int32_t *biasPostPtr, const int32_t *biasScalePtr,
    const int32_t *outPrePtr, const int32_t *outPostPtr,
    const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) {

  dim_t inChannels = inWdims[4];
  dim_t outChannels = outWdims[4];
  dim_t inCperG = inChannels / group;
  dim_t outCperG = outChannels / group;

  dim_t pad_near = pads[0];
  dim_t pad_top = pads[2];
  dim_t pad_left = pads[4];

  dim_t stride_t = strides[0];
  dim_t stride_h = strides[1];
  dim_t stride_w = strides[2];

  dim_t kernel_t = kernels[0];
  dim_t kernel_h = kernels[1];
  dim_t kernel_w = kernels[2];

  // Dilation is not supported for conv3d: the parameter is accepted for
  // interface uniformity but ignored below.
  (void)dilation;

  // For each input in the batch:
  for (dim_t n = 0; n < inWdims[0]; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // Get the channelwise quantization parameters.
        int32_t filterOffset = filterOffsetsPtr[d];
        int32_t biasOffset = biasOffsetsPtr[d];
        int32_t biasPre = biasPrePtr[d];
        int32_t biasPost = biasPostPtr[d];
        int32_t biasScale = biasScalePtr[d];
        int32_t outPre = outPrePtr[d];
        int32_t outPost = outPostPtr[d];
        int32_t outScale = outScalePtr[d];

        // For each convolution 'jump' in the input tensor:
        sdim_t t = -(sdim_t)pad_near;
        for (dim_t at = 0; at < outWdims[1]; t += stride_t, at++) {
          sdim_t x = -(sdim_t)pad_top;
          for (dim_t ax = 0; ax < outWdims[2]; x += stride_h, ax++) {
            sdim_t y = -(sdim_t)pad_left;
            for (dim_t ay = 0; ay < outWdims[3]; y += stride_w, ay++) {

              // Scale the bias to match the scale of the matrix
              // multiplication.
              int32_t sum =
                  libjit_scale<int32_t>((int32_t)biasW[d] - biasOffset, biasPre,
                                        biasPost, biasScale, 0);

              // For each element in the convolution-filter:
              for (dim_t ft = 0; ft < kernel_t; ft++) {
                for (dim_t fx = 0; fx < kernel_h; fx++) {
                  for (dim_t fy = 0; fy < kernel_w; fy++) {
                    sdim_t ot = t + ft;
                    sdim_t ox = x + fx;
                    sdim_t oy = y + fy;

                    // Ignore accesses outside the input tensor (due to
                    // padding).
                    if (ot < 0 || ox < 0 || oy < 0 ||
                        ot >= (sdim_t)inWdims[1] || ox >= (sdim_t)inWdims[2] ||
                        oy >= (sdim_t)inWdims[3]) {
                      continue;
                    }

                    // Calculate the indices into the Filter and Input buffers.
                    dim_t inIdx =
                        libjit_getXYZWQ(inWdims, n, (dim_t)ot, (dim_t)ox,
                                        (dim_t)oy, g * inCperG);
                    dim_t filterIdx =
                        libjit_getXYZWQ(filterWdims, d, ft, fx, fy, 0);

                    // Accumulate along the filter depth.
                    for (dim_t fd = 0; fd < inCperG; fd++) {
                      sum += (filterW[filterIdx + fd] - filterOffset) *
                             (inW[inIdx + fd] - inOffset);
                    }
                  }
                }
              }

              // Scale the result back to the expected destination scale.
              int32_t scaledSum = libjit_scale<int32_t>(sum, outPre, outPost,
                                                        outScale, outOffset);
              scaledSum = libjit_activation_i32(scaledSum, outOffset, actType,
                                                actArgs);
              outW[libjit_getXYZWQ(outWdims, n, at, ax, ay, d)] =
                  libjit_clip_i8(scaledSum);
            } // W
          } // H
        } // T
      } // C
    } // G
  } // N
}
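
// The 3D kernel addresses NTHWC tensors through libjit_getXYZWQ. A minimal
// sketch of the 5-D row-major index math, assuming the definition in
// libjit_defs.h:
//
//   // idx = (((n * dims[1] + t) * dims[2] + h) * dims[3] + w) * dims[4] + c
//   dim_t dims[5] = {1, 2, 4, 4, 8}; // N, T, H, W, C
//   assert(libjit_getXYZWQ(dims, 0, 1, 2, 3, 5) ==
//          (((0 * 2 + 1) * 4 + 2) * 4 + 3) * 8 + 5); // == 221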
} // namespace

extern "C" {
void libjit_conv2d_f(float *outW, const float *inW, const float *filterW,
                     const float *biasW, const dim_t *outWdims,
                     const dim_t *inWdims, const dim_t *filterWdims,
                     const dim_t *biasWdims, const dim_t *kernelSizes,
                     const dim_t *strides, const dim_t *pads, dim_t group,
                     unsigned depthUnroll, const dim_t *dilation,
                     int32_t actType, const float *actArgs) {
  dim_t inChannels = inWdims[3];
  dim_t outChannels = outWdims[3];
  dim_t inCperG = inChannels / group;
  dim_t outCperG = outChannels / group;

  // The output dims are already computed from all of the pads, so here we only
  // need the top and left pads to control the starting position.
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernelSizes[0];
  dim_t kernel_w = kernelSizes[1];
  // The size of the input-channel tile. A high channel count allows for SIMD
  // parallelism but creates register pressure. A low channel count reduces the
  // memory pressure and lets things fit in cache, but requires additional
  // compute (horizontal add) to sum the values in the block. This value is a
  // compromise between the two.
  constexpr unsigned cbSize = 512;
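  // Tiling sketch (hypothetical sizes): with inCperG = 1280 and cbSize = 512,
  // the channel dimension is processed in blocks of 512, 512 and 256, i.e.
  // ceil(1280 / 512) == 3 iterations of the 'cb' loop below; each block's
  // partial sums are accumulated into the output between iterations.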

  // For each input in the batch:
  for (dim_t n = 0; n < inWdims[0]; n++) {

    // Initialize the output frame for the N'th slice with the bias.
    // Later we will accumulate values into this slice.
    libjit_conv_init_output_with_bias(n, outW, biasW, outWdims, biasWdims);

    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {

      // Process the body of the loop in tiles of "channel-block".
      for (dim_t cb = 0; cb < inCperG; cb += cbSize) {

        // For each output channel in the group. Process 'depthUnroll' output
        // channels together.
        for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d += depthUnroll) {

          // For each element in the convolution-filter:
          for (dim_t fx = 0; fx < kernel_h; fx++) {
            for (dim_t fy = 0; fy < kernel_w; fy++) {

              // Flag that signals whether this is the last accumulation
              // iteration, after which it is time to apply the activation.
              bool lastSumIter = (fx == (kernel_h - 1)) &&
                                 (fy == (kernel_w - 1)) &&
                                 ((cb + cbSize) >= inCperG);

              // For each convolution 'jump' in the input tensor:
              for (dim_t outx = 0; outx < outWdims[1]; outx++) {
                for (dim_t outy = 0; outy < outWdims[2]; outy++) {

                  // Process 'depthUnroll' output pixels at once. Each scalar
                  // here represents the convolution sum for one (x,y) point in
                  // the output. We process the same pixel for different output
                  // channel (D) values. The compiler should perform scalar
                  // replacement of aggregates and split this tiny array into
                  // registers.
                  float sum[depthUnroll];
                  for (unsigned i = 0; i < depthUnroll; i++) {
                    sum[i] = 0;
                  }

                  // Calculate the specific input x,y that we process in this
                  // iteration.
                  sdim_t inx =
                      (sdim_t)outx * stride_h - pad_t + fx * dilation[0];
                  sdim_t iny =
                      (sdim_t)outy * stride_w - pad_l + fy * dilation[1];

                  // Ignore accesses outside the input tensor (due to padding).
                  if (inx < 0 || iny < 0 || inx >= (sdim_t)inWdims[1] ||
                      iny >= (sdim_t)inWdims[2]) {
                    // If this is the last accumulation iteration and we skip
                    // it, we still need to apply the activation to the value
                    // already stored in the output.
                    if (actType && lastSumIter) {
                      for (unsigned i = 0; i < depthUnroll; i++) {
                        dim_t outIdx =
                            libjit_getXYZW(outWdims, n, outx, outy, d + i);
                        outW[outIdx] =
                            libjit_activation_f(outW[outIdx], actType, actArgs);
                      }
                    }
                    continue;
                  }

                  // Calculate the indices into the Filter and Input buffers.
                  dim_t inIdx = libjit_getXYZW(inWdims, n, (dim_t)inx,
                                               (dim_t)iny, g * inCperG);
                  dim_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0);
                  dim_t sliceSize =
                      filterWdims[1] * filterWdims[2] * filterWdims[3];

                  // Perform the heart of the convolution on up to 4 output
                  // depths at a time to reduce register pressure.
                  for (dim_t fd = cb, e = MIN(cb + cbSize, inCperG); fd < e;
                       fd++) {
                    float in = inW[inIdx + fd];
                    for (unsigned i = 0; i < MIN(4, depthUnroll); i++) {
                      sum[i] += filterW[filterIdx + (sliceSize * i) + fd] * in;
                    }
                  }

                  // And run the innermost loop again for the second group of
                  // depth slices:
                  if (depthUnroll > 4) {
                    for (dim_t fd = cb, e = MIN(cb + cbSize, inCperG); fd < e;
                         fd++) {
                      float in = inW[inIdx + fd];
                      for (unsigned i = 4; i < MIN(8, depthUnroll); i++) {
                        sum[i] +=
                            filterW[filterIdx + (sliceSize * i) + fd] * in;
                      }
                    }
                  }

                  // Store the results to the output buffer.
                  for (unsigned i = 0; i < depthUnroll; i++) {
                    dim_t outIdx =
                        libjit_getXYZW(outWdims, n, outx, outy, d + i);
                    float sumIter = outW[outIdx] + sum[i];
                    if (actType && lastSumIter) {
                      sumIter = libjit_activation_f(sumIter, actType, actArgs);
                    }
                    outW[outIdx] = sumIter;
                  }
                }
              }
            } // For each Y in the filter.
          } // For each X in the filter.
        } // For each D (the depth, or the output channel).
      } // For each block in the input channel.
    } // For each group in the input channel.
  } // For each N, the sample in the batch.
}
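
// Usage sketch for libjit_conv2d_f (hypothetical shapes and buffers; the
// layouts follow the NHWC / DHWC conventions used above, and actType == 0 is
// assumed to mean "no activation"):
//
//   dim_t inDims[4] = {1, 32, 32, 8};    // input  N, H, W, C
//   dim_t outDims[4] = {1, 32, 32, 16};  // output N, H, W, C
//   dim_t filterDims[4] = {16, 3, 3, 8}; // filter D, Kh, Kw, C
//   dim_t biasDims[1] = {16};
//   dim_t kernels[2] = {3, 3}, strides[2] = {1, 1};
//   dim_t pads[2] = {1, 1}, dilation[2] = {1, 1};
//   libjit_conv2d_f(out, in, filter, bias, outDims, inDims, filterDims,
//                   biasDims, kernels, strides, pads, /*group*/ 1,
//                   /*depthUnroll*/ 8, dilation, /*actType*/ 0, nullptr);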

void libjit_conv2d_i8_i32(
    int8_t *outW, const int8_t *inW, const int8_t *filterW,
    const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernelSizes,
    const dim_t *strides, const dim_t *pads, dim_t group, int32_t outOffset,
    int32_t inOffset, int32_t filterOffset, int32_t biasOffset, int32_t biasPre,
    int32_t biasPost, int32_t biasScale, int32_t outPre, int32_t outPost,
    int32_t outScale, unsigned depthUnroll, const dim_t *dilation,
    int32_t actType, const int32_t *actArgs) {
  libjit_quantized_conv2d_generic<int8_t, int32_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernelSizes, strides, pads, group, outOffset, inOffset, filterOffset,
      biasOffset, biasPre, biasPost, biasScale, outPre, outPost, outScale,
      depthUnroll, dilation, actType, actArgs);
}

void libjit_conv2d_i8_i8(int8_t *outW, const int8_t *inW, const int8_t *filterW,
                         const int8_t *biasW, const dim_t *outWdims,
                         const dim_t *inWdims, const dim_t *filterWdims,
                         const dim_t *biasWdims, const dim_t *kernelSizes,
                         const dim_t *strides, const dim_t *pads, dim_t group,
                         int32_t outOffset, int32_t inOffset,
                         int32_t filterOffset, int32_t biasOffset,
                         int32_t biasPre, int32_t biasPost, int32_t biasScale,
                         int32_t outPre, int32_t outPost, int32_t outScale,
                         unsigned depthUnroll, const dim_t *dilation,
                         int32_t actType, const int32_t *actArgs) {
  libjit_quantized_conv2d_generic<int8_t, int8_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernelSizes, strides, pads, group, outOffset, inOffset, filterOffset,
      biasOffset, biasPre, biasPost, biasScale, outPre, outPost, outScale,
      depthUnroll, dilation, actType, actArgs);
}

void libjit_channelwise_quantized_conv2d_i8_i32(
    int8_t *outW, const int8_t *inW, const int8_t *filterW,
    const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels,
    const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation,
    int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr,
    int32_t *biasOffsetsPtr, const int32_t *biasPrePtr,
    const int32_t *biasPostPtr, const int32_t *biasScalePtr,
    const int32_t *outPrePtr, const int32_t *outPostPtr,
    const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) {
  libjit_channelwise_quantized_conv2d_generic<int8_t, int32_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernels, strides, pads, group, dilation, outOffset, inOffset,
      filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
      outPrePtr, outPostPtr, outScalePtr, actType, actArgs);
}

void libjit_channelwise_quantized_conv2d_i8_i8(
    int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW,
    const dim_t *outWdims, const dim_t *inWdims, const dim_t *filterWdims,
    const dim_t *biasWdims, const dim_t *kernels, const dim_t *strides,
    const dim_t *pads, dim_t group, const dim_t *dilation, int32_t outOffset,
    int32_t inOffset, int32_t *filterOffsetsPtr, int32_t *biasOffsetsPtr,
    const int32_t *biasPrePtr, const int32_t *biasPostPtr,
    const int32_t *biasScalePtr, const int32_t *outPrePtr,
    const int32_t *outPostPtr, const int32_t *outScalePtr, int32_t actType,
    const int32_t *actArgs) {
  libjit_channelwise_quantized_conv2d_generic<int8_t, int8_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernels, strides, pads, group, dilation, outOffset, inOffset,
      filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
      outPrePtr, outPostPtr, outScalePtr, actType, actArgs);
}

void libjit_channelwise_quantized_conv3d_i8_i32(
    int8_t *outW, const int8_t *inW, const int8_t *filterW,
    const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims,
    const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels,
    const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation,
    int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr,
    int32_t *biasOffsetsPtr, const int32_t *biasPrePtr,
    const int32_t *biasPostPtr, const int32_t *biasScalePtr,
    const int32_t *outPrePtr, const int32_t *outPostPtr,
    const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) {
  libjit_channelwise_quantized_conv3d_generic<int8_t, int32_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernels, strides, pads, group, dilation, outOffset, inOffset,
      filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
      outPrePtr, outPostPtr, outScalePtr, actType, actArgs);
}

void libjit_channelwise_quantized_conv3d_i8_i8(
    int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW,
    const dim_t *outWdims, const dim_t *inWdims, const dim_t *filterWdims,
    const dim_t *biasWdims, const dim_t *kernels, const dim_t *strides,
    const dim_t *pads, dim_t group, const dim_t *dilation, int32_t outOffset,
    int32_t inOffset, int32_t *filterOffsetsPtr, int32_t *biasOffsetsPtr,
    const int32_t *biasPrePtr, const int32_t *biasPostPtr,
    const int32_t *biasScalePtr, const int32_t *outPrePtr,
    const int32_t *outPostPtr, const int32_t *outScalePtr, int32_t actType,
    const int32_t *actArgs) {
  libjit_channelwise_quantized_conv3d_generic<int8_t, int8_t>(
      outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims,
      kernels, strides, pads, group, dilation, outOffset, inOffset,
      filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
      outPrePtr, outPostPtr, outScalePtr, actType, actArgs);
}

void libjit_conv_transpose_f(float *outW, const float *inW,
                             const float *filterW, const float *biasW,
                             const dim_t *outWdims, const dim_t *inWdims,
                             const dim_t *filterWdims, const dim_t *biasWdims,
                             const dim_t *kernels, const dim_t *strides,
                             const dim_t *pads, dim_t group,
                             const dim_t *dilation) {
  // NHWC format is assumed.
  dim_t p = sizeof(float);
  memset(outW, 0, outWdims[0] * outWdims[1] * outWdims[2] * outWdims[3] * p);

  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernels[0];
  dim_t kernel_w = kernels[1];
  dim_t outCperG = outWdims[3] / group;
  dim_t inCperG = inWdims[3] / group;

  // For each input in the batch:
  for (dim_t n = 0; n < inWdims[0]; n++) {

    // Initialize the outputs with the bias.
    libjit_conv_init_output_with_bias(n, outW, biasW, outWdims, biasWdims);

    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) {
        ssize_t x = -(ssize_t)pad_t;
        for (dim_t bx = 0; bx < inWdims[1]; bx++, x += stride_h) {
          ssize_t y = -(ssize_t)pad_l;
          for (dim_t by = 0; by < inWdims[2]; by++, y += stride_w) {
            // The input value that is scattered into the output.
            float grad = inW[libjit_getXYZW(inWdims, n, bx, by, d)];

            for (dim_t kx = 0; kx < kernel_h; kx++) {
              for (dim_t ky = 0; ky < kernel_w; ky++) {
                ssize_t ax = x + kx * dilation[0];
                ssize_t ay = y + ky * dilation[1];

                if (ax < 0 || ay < 0 || ax >= (ssize_t)outWdims[1] ||
                    ay >= (ssize_t)outWdims[2]) {
                  continue;
                }

                for (dim_t c = 0; c < outCperG; c++) {
                  dim_t outIndex = libjit_getXYZW(
                      outWdims, n, (dim_t)ax, (dim_t)ay, (g * outCperG + c));
                  dim_t inIndex = libjit_getXYZW(filterWdims, c, kx, ky, d);
                  outW[outIndex] += filterW[inIndex] * grad;
                }
              }
            }
          } // W
        } // H
      } // C
    } // G
  } // N
}
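
// Note: conv-transpose is implemented above as a scatter. Each input pixel
// in[n, bx, by, d] contributes
//
//   out[n, bx * stride_h - pad_t + kx * dilation[0],
//          by * stride_w - pad_l + ky * dilation[1], g * outCperG + c]
//       += filter[c, kx, ky, d] * in[n, bx, by, d]
//
// for every in-bounds (kx, ky, c), which is the transpose of the gather
// performed by libjit_conv2d_f.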

void libjit_convolution_grad_f(float *inG, const float *outG, const float *inW,
                               float *filterG, float *biasG,
                               const float *filterW, const dim_t *outGdims,
                               const dim_t *inWdims, const dim_t *filterGdims,
                               const dim_t *kernels, const dim_t *strides,
                               const dim_t *pads, dim_t group,
                               const dim_t *dilation) {
  // NHWC format is assumed.
  // Clear inG, filterG, and biasG.
  dim_t p = sizeof(float);
  memset(inG, 0, inWdims[0] * inWdims[1] * inWdims[2] * inWdims[3] * p);
  memset(filterG, 0,
         filterGdims[0] * filterGdims[1] * filterGdims[2] * filterGdims[3] * p);
  memset(biasG, 0, outGdims[3] * p);

  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernels[0];
  dim_t kernel_w = kernels[1];
  dim_t inCperG = inWdims[3] / group;
  dim_t outCperG = outGdims[3] / group;

  // For each input in the batch:
  for (dim_t n = 0; n < outGdims[0]; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
        ssize_t x = -(ssize_t)pad_t;
        for (dim_t bx = 0; bx < outGdims[1]; bx++, x += stride_h) {
          ssize_t y = -(ssize_t)pad_l;
          for (dim_t by = 0; by < outGdims[2]; by++, y += stride_w) {
            float grad = outG[libjit_getXYZW(outGdims, n, bx, by, d)];

            for (dim_t kx = 0; kx < kernel_h; kx++) {
              for (dim_t ky = 0; ky < kernel_w; ky++) {
                ssize_t ax = x + kx * dilation[0];
                ssize_t ay = y + ky * dilation[1];

                if (ax < 0 || ay < 0 || ax >= (ssize_t)inWdims[1] ||
                    ay >= (ssize_t)inWdims[2]) {
                  continue;
                }

                for (dim_t c = 0; c < inCperG; c++) {
                  inG[libjit_getXYZW(inWdims, n, (dim_t)ax, (dim_t)ay,
                                     g * inCperG + c)] +=
                      filterW[libjit_getXYZW(filterGdims, d, kx, ky, c)] * grad;
                  filterG[libjit_getXYZW(filterGdims, d, kx, ky, c)] +=
                      inW[libjit_getXYZW(inWdims, n, (dim_t)ax, (dim_t)ay,
                                         g * inCperG + c)] *
                      grad;
                }
              }
            }

            biasG[d] += grad;
          } // W
        } // H
      } // C
    } // G
  } // N
}
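
// Note on the gradient kernel above: for each output-gradient element
// grad = outG[n, bx, by, d] it accumulates
//
//   inG[n, ax, ay, g * inCperG + c] += filterW[d, kx, ky, c] * grad
//   filterG[d, kx, ky, c]           += inW[n, ax, ay, g * inCperG + c] * grad
//   biasG[d]                        += grad
//
// with ax = bx * stride_h - pad_t + kx * dilation[0] (and likewise for ay),
// i.e. the standard backward pass of the NHWC convolution.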
}