1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include <assert.h> |
17 | #include <math.h> |
18 | #include <stddef.h> |
19 | #include <stdint.h> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | #include <string.h> |
23 | #include <sys/types.h> |
24 | |
25 | #include "libjit_defs.h" |
26 | |
27 | namespace { |
// Initialize the convolution output frame for slice \p N with the bias
// \p biasW.
30 | void libjit_conv_init_output_with_bias(dim_t N, float *outW, const float *biasW, |
31 | const dim_t *outWdims, |
32 | const dim_t *biasWdims) { |
33 | // For each (x,y) step in the output tensor: |
34 | for (dim_t ax = 0; ax < outWdims[1]; ax++) { |
35 | for (dim_t ay = 0; ay < outWdims[2]; ay++) { |
36 | // For each output channel: |
37 | for (dim_t d = 0; d < outWdims[3]; d++) { |
38 | // Store the results to the output buffer. |
39 | float bias = biasW[d]; |
40 | auto outIdx = libjit_getXYZW(outWdims, N, ax, ay, d); |
41 | outW[outIdx] = bias; |
42 | } // For each depth in the output. |
43 | } // For each Y in the output. |
44 | } // For each X in the output. |
45 | } |
46 | |
/// Generic template for quantized conv2d. The template allows choosing the
/// element type and bias type.
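///
/// Sketch of the arithmetic (assuming the usual affine quantization
/// convention): each output element is accumulated in int32 as
///   sum = scaledBias + sum((filter - filterOffset) * (input - inOffset))
/// over the filter window, then rescaled to the output domain via
/// libjit_scale(outPre, outPost, outScale, outOffset), passed through the
/// optional activation and clipped to int8.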
49 | template <typename ElemTy, typename BiasElemTy> |
50 | void libjit_quantized_conv2d_generic( |
51 | ElemTy *outW, const ElemTy *inW, const ElemTy *filterW, |
52 | const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims, |
53 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernelSizes, |
54 | const dim_t *strides, const dim_t *pads, dim_t group, int32_t outOffset, |
55 | int32_t inOffset, int32_t filterOffset, int32_t biasOffset, int32_t biasPre, |
56 | int32_t biasPost, int32_t biasScale, int32_t outPre, int32_t outPost, |
57 | int32_t outScale, unsigned depthUnroll, const dim_t *dilation, |
58 | int32_t actType, const int32_t *actArgs) { |
59 | dim_t inChannels = inWdims[3]; |
60 | dim_t outChannels = outWdims[3]; |
61 | dim_t inCperG = inChannels / group; |
62 | dim_t outCperG = outChannels / group; |
63 | dim_t pad_t = pads[0]; |
64 | dim_t pad_l = pads[1]; |
65 | dim_t stride_h = strides[0]; |
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernelSizes[0];
  dim_t kernel_w = kernelSizes[1];
69 | // For each input in the batch: |
70 | for (size_t n = 0; n < inWdims[0]; n++) { |
71 | // For each group of input channels: |
72 | for (size_t g = 0; g < group; g++) { |
73 | |
74 | // For each output channel in the group. Process 'depthUnroll' output |
75 | // layers together. |
76 | for (size_t d = g * outCperG; d < (g + 1) * outCperG; d += depthUnroll) { |
77 | // For each convolution 'jump' in the input tensor: |
78 | ssize_t x = -(ssize_t)pad_t; |
79 | for (size_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) { |
80 | ssize_t y = -(ssize_t)pad_l; |
81 | for (size_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) { |
82 | int32_t sum[depthUnroll]; |
83 | |
84 | for (unsigned i = 0; i < depthUnroll; i++) { |
85 | // Scale the bias to match the scale of the matrix multiplication. |
86 | sum[i] = libjit_scale<int32_t>((int32_t)biasW[d + i] - biasOffset, |
87 | biasPre, biasPost, biasScale, 0); |
88 | } |
89 | |
90 | // For each element in the convolution-filter: |
91 | for (size_t fx = 0; fx < kernel_h; fx++) { |
92 | for (size_t fy = 0; fy < kernel_w; fy++) { |
93 | ssize_t ox = x + fx * dilation[0]; |
94 | ssize_t oy = y + fy * dilation[1]; |
95 | |
                // Ignore access outside the input tensor (due to padding).
97 | if (ox < 0 || oy < 0 || ox >= (ssize_t)inWdims[1] || |
98 | oy >= (ssize_t)inWdims[2]) { |
99 | continue; |
100 | } |
101 | |
102 | // Calculate the indices into the Filter and Input buffers. |
103 | size_t inIdx = libjit_getXYZW(inWdims, n, (size_t)ox, |
104 | (size_t)oy, g * inCperG); |
105 | size_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0); |
106 | size_t sliceSize = |
107 | filterWdims[1] * filterWdims[2] * filterWdims[3]; |
108 | |
109 | // Perform the innermost loop of the convolution using 4 vector |
110 | // registers. |
111 | for (size_t fd = 0; fd < inCperG; fd++) { |
112 | int32_t in = inW[inIdx + fd] - inOffset; |
113 | for (unsigned i = 0; i < MIN(4, depthUnroll); i++) { |
114 | sum[i] += (filterW[filterIdx + (sliceSize * i) + fd] - |
115 | filterOffset) * |
116 | in; |
117 | } |
118 | } |
119 | |
120 | // And perform the innermost loop again with 4 more registers. |
121 | if (depthUnroll > 4) |
122 | for (size_t fd = 0; fd < inCperG; fd++) { |
123 | int32_t in = inW[inIdx + fd] - inOffset; |
124 | for (unsigned i = 4; i < MIN(8, depthUnroll); i++) { |
125 | sum[i] += (filterW[filterIdx + (sliceSize * i) + fd] - |
126 | filterOffset) * |
127 | in; |
128 | } |
129 | } |
130 | } |
131 | } |
132 | |
133 | for (unsigned i = 0; i < depthUnroll; i++) { |
134 | // Scale the result back to the expected destination scale. |
135 | int32_t scaledSum = libjit_scale<int32_t>(sum[i], outPre, outPost, |
136 | outScale, outOffset); |
137 | scaledSum = |
138 | libjit_activation_i32(scaledSum, outOffset, actType, actArgs); |
139 | outW[libjit_getXYZW(outWdims, n, ax, ay, d + i)] = |
140 | libjit_clip_i8(scaledSum); |
141 | } |
142 | } // W |
143 | } // H |
144 | } // C |
145 | } // G |
146 | } // N |
147 | } |
148 | |
149 | /// Generic template for channelwise quantized conv2d. The template allows |
150 | /// choosing the element type and bias type. |
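/// Unlike the uniform variant above, the filter offset and the bias/output
/// rescaling parameters are read per output channel from the *Ptr arrays.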
151 | template <typename ElemTy, typename BiasElemTy> |
152 | void libjit_channelwise_quantized_conv2d_generic( |
153 | ElemTy *outW, const ElemTy *inW, const ElemTy *filterW, |
154 | const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims, |
155 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels, |
156 | const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation, |
157 | int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr, |
158 | int32_t *biasOffsetsPtr, const int32_t *biasPrePtr, |
159 | const int32_t *biasPostPtr, const int32_t *biasScalePtr, |
160 | const int32_t *outPrePtr, const int32_t *outPostPtr, |
161 | const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) { |
162 | |
163 | dim_t inChannels = inWdims[3]; |
164 | dim_t outChannels = outWdims[3]; |
165 | dim_t inCperG = inChannels / group; |
166 | dim_t outCperG = outChannels / group; |
167 | dim_t pad_t = pads[0]; |
168 | dim_t pad_l = pads[1]; |
169 | dim_t stride_h = strides[0]; |
170 | dim_t stride_w = strides[1]; |
171 | dim_t kernel_h = kernels[0]; |
172 | dim_t kernel_w = kernels[1]; |
173 | |
174 | // For each input in the batch: |
175 | for (dim_t n = 0; n < inWdims[0]; n++) { |
176 | // For each group of input channels: |
177 | for (dim_t g = 0; g < group; g++) { |
178 | // For each output channel in the group: |
179 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
180 | |
        // Get the channelwise quantization parameters.
182 | int32_t filterOffset = filterOffsetsPtr[d]; |
183 | int32_t biasOffset = biasOffsetsPtr[d]; |
184 | int32_t biasPre = biasPrePtr[d]; |
185 | int32_t biasPost = biasPostPtr[d]; |
186 | int32_t biasScale = biasScalePtr[d]; |
187 | int32_t outPre = outPrePtr[d]; |
188 | int32_t outPost = outPostPtr[d]; |
189 | int32_t outScale = outScalePtr[d]; |
190 | |
191 | // For each convolution 'jump' in the input tensor: |
192 | sdim_t x = -(sdim_t)pad_t; |
193 | for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) { |
194 | sdim_t y = -(sdim_t)pad_l; |
195 | for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) { |
196 | |
197 | // Scale the bias to match the scale of the matrix multiplication. |
198 | int32_t sum = |
199 | libjit_scale<int32_t>((int32_t)biasW[d] - biasOffset, biasPre, |
200 | biasPost, biasScale, 0); |
201 | |
202 | // For each element in the convolution-filter: |
203 | for (dim_t fx = 0; fx < kernel_h; fx++) { |
204 | for (dim_t fy = 0; fy < kernel_w; fy++) { |
205 | sdim_t ox = x + fx * dilation[0]; |
206 | sdim_t oy = y + fy * dilation[1]; |
207 | |
208 | // Ignore access outside the input tensor (due to padding). |
209 | if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] || |
210 | oy >= (sdim_t)inWdims[2]) { |
211 | continue; |
212 | } |
213 | |
214 | // Calculate the indices into the Filter and Input buffers. |
215 | dim_t inIdx = libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, |
216 | g * inCperG); |
217 | dim_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0); |
218 | |
219 | // Accumulate along the filter depth. |
220 | for (dim_t fd = 0; fd < inCperG; fd++) { |
221 | sum += (filterW[filterIdx + fd] - filterOffset) * |
222 | (inW[inIdx + fd] - inOffset); |
223 | } |
224 | } |
225 | } |
226 | |
227 | // Scale the result back to the expected destination scale. |
228 | int32_t scaledSum = libjit_scale<int32_t>(sum, outPre, outPost, |
229 | outScale, outOffset); |
230 | scaledSum = |
231 | libjit_activation_i32(scaledSum, outOffset, actType, actArgs); |
232 | outW[libjit_getXYZW(outWdims, n, ax, ay, d)] = |
233 | libjit_clip_i8(scaledSum); |
234 | } // W |
235 | } // H |
236 | } // C |
237 | } // G |
238 | } // N |
239 | } |
240 | |
241 | /// Generic template for channelwise quantized conv3d. The template allows |
242 | /// choosing the element type and bias type. |
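/// Note that the \p dilation parameter is currently ignored by this
/// implementation.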
243 | template <typename ElemTy, typename BiasElemTy> |
244 | void libjit_channelwise_quantized_conv3d_generic( |
245 | ElemTy *outW, const ElemTy *inW, const ElemTy *filterW, |
246 | const BiasElemTy *biasW, const dim_t *outWdims, const dim_t *inWdims, |
247 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels, |
248 | const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation, |
249 | int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr, |
250 | int32_t *biasOffsetsPtr, const int32_t *biasPrePtr, |
251 | const int32_t *biasPostPtr, const int32_t *biasScalePtr, |
252 | const int32_t *outPrePtr, const int32_t *outPostPtr, |
253 | const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) { |
254 | |
255 | dim_t inChannels = inWdims[4]; |
256 | dim_t outChannels = outWdims[4]; |
257 | dim_t inCperG = inChannels / group; |
258 | dim_t outCperG = outChannels / group; |
259 | |
260 | dim_t pad_near = pads[0]; |
261 | dim_t pad_top = pads[2]; |
262 | dim_t pad_left = pads[4]; |
263 | |
264 | dim_t stride_t = strides[0]; |
265 | dim_t stride_h = strides[1]; |
266 | dim_t stride_w = strides[2]; |
267 | |
268 | dim_t kernel_t = kernels[0]; |
269 | dim_t kernel_h = kernels[1]; |
270 | dim_t kernel_w = kernels[2]; |
271 | |
272 | (void)dilation; |
273 | |
274 | // For each input in the batch: |
275 | for (dim_t n = 0; n < inWdims[0]; n++) { |
276 | // For each group of input channels: |
277 | for (dim_t g = 0; g < group; g++) { |
278 | // For each output channel in the group: |
279 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
280 | |
        // Get the channelwise quantization parameters.
282 | int32_t filterOffset = filterOffsetsPtr[d]; |
283 | int32_t biasOffset = biasOffsetsPtr[d]; |
284 | int32_t biasPre = biasPrePtr[d]; |
285 | int32_t biasPost = biasPostPtr[d]; |
286 | int32_t biasScale = biasScalePtr[d]; |
287 | int32_t outPre = outPrePtr[d]; |
288 | int32_t outPost = outPostPtr[d]; |
289 | int32_t outScale = outScalePtr[d]; |
290 | |
291 | // For each convolution 'jump' in the input tensor: |
292 | sdim_t t = -sdim_t(pad_near); |
293 | for (dim_t at = 0; at < outWdims[1]; t += stride_t, at++) { |
294 | sdim_t x = -sdim_t(pad_top); |
295 | for (dim_t ax = 0; ax < outWdims[2]; x += stride_h, ax++) { |
296 | sdim_t y = -sdim_t(pad_left); |
297 | for (dim_t ay = 0; ay < outWdims[3]; y += stride_w, ay++) { |
298 | |
299 | // Scale the bias to match the scale of the matrix multiplication. |
300 | int32_t sum = |
301 | libjit_scale<int32_t>((int32_t)biasW[d] - biasOffset, biasPre, |
302 | biasPost, biasScale, 0); |
303 | |
304 | // For each element in the convolution-filter: |
305 | for (dim_t ft = 0; ft < kernel_t; ft++) { |
306 | for (dim_t fx = 0; fx < kernel_h; fx++) { |
307 | for (dim_t fy = 0; fy < kernel_w; fy++) { |
308 | sdim_t ot = t + ft; |
309 | sdim_t ox = x + fx; |
310 | sdim_t oy = y + fy; |
311 | |
                    // Ignore access outside the input tensor (due to
                    // padding).
314 | if (ot < 0 || ox < 0 || oy < 0 || |
315 | ot >= (sdim_t)inWdims[1] || ox >= (sdim_t)inWdims[2] || |
316 | oy >= (sdim_t)inWdims[3]) { |
317 | continue; |
318 | } |
319 | |
320 | // Calculate the indices into the Filter and Input buffers. |
321 | dim_t inIdx = |
322 | libjit_getXYZWQ(inWdims, n, (dim_t)ot, (dim_t)ox, |
323 | (dim_t)oy, g * inCperG); |
324 | dim_t filterIdx = |
325 | libjit_getXYZWQ(filterWdims, d, ft, fx, fy, 0); |
326 | |
327 | // Accumulate along the filter depth. |
328 | for (dim_t fd = 0; fd < inCperG; fd++) { |
329 | sum += (filterW[filterIdx + fd] - filterOffset) * |
330 | (inW[inIdx + fd] - inOffset); |
331 | } |
332 | } |
333 | } |
334 | } |
335 | |
336 | // Scale the result back to the expected destination scale. |
337 | int32_t scaledSum = libjit_scale<int32_t>(sum, outPre, outPost, |
338 | outScale, outOffset); |
339 | scaledSum = |
340 | libjit_activation_i32(scaledSum, outOffset, actType, actArgs); |
341 | outW[libjit_getXYZWQ(outWdims, n, at, ax, ay, d)] = |
342 | libjit_clip_i8(scaledSum); |
343 | } // W |
344 | } // H |
345 | } // T |
346 | } // C |
347 | } // G |
348 | } // N |
349 | } |
350 | } // namespace |
351 | |
352 | extern "C" { |
353 | void libjit_conv2d_f(float *outW, const float *inW, const float *filterW, |
354 | const float *biasW, const dim_t *outWdims, |
355 | const dim_t *inWdims, const dim_t *filterWdims, |
356 | const dim_t *biasWdims, const dim_t *kernelSizes, |
357 | const dim_t *strides, const dim_t *pads, dim_t group, |
358 | unsigned depthUnroll, const dim_t *dilation, |
359 | int32_t actType, const float *actArgs) { |
360 | dim_t inChannels = inWdims[3]; |
361 | dim_t outChannels = outWdims[3]; |
362 | dim_t inCperG = inChannels / group; |
363 | dim_t outCperG = outChannels / group; |
364 | |
  // The output dimensions are already computed from all of the pads, so here
  // we only need the top and left pads to control the starting position.
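  // (With the usual convention this means, for example, that
  //  outH == (inH + pad_top + pad_bottom - dilation_h * (kernel_h - 1) - 1) /
  //          stride_h + 1, and similarly for the width.)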
368 | dim_t pad_t = pads[0]; |
369 | dim_t pad_l = pads[1]; |
370 | dim_t stride_h = strides[0]; |
371 | dim_t stride_w = strides[1]; |
372 | dim_t kernel_h = kernelSizes[0]; |
373 | dim_t kernel_w = kernelSizes[1]; |
  // The size of the input-channel tile. A high channel count allows for SIMD
  // parallelism but creates register pressure. A low channel count reduces
  // the memory pressure and allows things to fit in cache, but requires
  // additional compute (horizontal add) to sum the values in the block. This
  // value is a compromise between the two.
379 | constexpr unsigned cbSize = 512; |
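  // For example, with inCperG == 1280 the channel-block loop below runs three
  // tiles: [0, 512), [512, 1024) and [1024, 1280).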
380 | |
381 | // For each input in the batch: |
382 | for (dim_t n = 0; n < inWdims[0]; n++) { |
383 | |
384 | // Initialize the output frame for the N'th slice with the bias. |
385 | // Later we will accumulate values into this slice. |
386 | libjit_conv_init_output_with_bias(n, outW, biasW, outWdims, biasWdims); |
387 | |
388 | // For each group of input channels: |
389 | for (dim_t g = 0; g < group; g++) { |
390 | |
391 | // Process the body of the loop in tiles of "channel-block". |
392 | for (dim_t cb = 0; cb < inCperG; cb += cbSize) { |
393 | |
394 | // For each output channel in the group. Process 'depthUnroll' output |
395 | // layers together. |
396 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d += depthUnroll) { |
397 | |
398 | // For each element in the convolution-filter: |
399 | for (dim_t fx = 0; fx < kernel_h; fx++) { |
400 | for (dim_t fy = 0; fy < kernel_w; fy++) { |
401 | |
              // Flag to signal whether this is the last iteration, in which
              // we finalize the accumulation and apply the activation.
404 | bool lastSumIter = (fx == (kernel_h - 1)) && |
405 | (fy == (kernel_w - 1)) && |
406 | ((cb + cbSize) >= inCperG); |
407 | |
408 | // For each convolution 'jump' in the input tensor: |
409 | for (dim_t outx = 0; outx < outWdims[1]; outx++) { |
410 | for (dim_t outy = 0; outy < outWdims[2]; outy++) { |
411 | |
412 | // Process 'depthUnroll' output pixels at once. Each scalar |
413 | // here represents the convolution sum for one (x,y) point in |
414 | // the output. We process the same pixel for different output |
415 | // channel (D) values. The compiler should perform scalar |
416 | // replacement of aggregates and split this tiny array to |
417 | // registers. |
418 | float sum[depthUnroll]; |
419 | for (unsigned i = 0; i < depthUnroll; i++) { |
420 | sum[i] = 0; |
421 | } |
422 | |
423 | // Calculate the specific input x,y that we process in this |
424 | // iteration. |
425 | sdim_t inx = |
426 | (sdim_t)outx * stride_h - pad_t + fx * dilation[0]; |
427 | sdim_t iny = |
428 | (sdim_t)outy * stride_w - pad_l + fy * dilation[1]; |
429 | |
                  // Ignore access outside the input tensor (due to padding).
431 | if (inx < 0 || iny < 0 || inx >= (sdim_t)inWdims[1] || |
432 | iny >= (sdim_t)inWdims[2]) { |
                    // If this is the last iteration and we skip it, we still
                    // apply the activation.
435 | if (actType && lastSumIter) { |
436 | for (unsigned i = 0; i < depthUnroll; i++) { |
437 | dim_t outIdx = |
438 | libjit_getXYZW(outWdims, n, outx, outy, d + i); |
439 | outW[outIdx] = |
440 | libjit_activation_f(outW[outIdx], actType, actArgs); |
441 | } |
442 | } |
443 | continue; |
444 | } |
445 | |
446 | // Calculate the indices into the Filter and Input buffers. |
447 | dim_t inIdx = libjit_getXYZW(inWdims, n, (dim_t)inx, |
448 | (dim_t)iny, g * inCperG); |
449 | dim_t filterIdx = libjit_getXYZW(filterWdims, d, fx, fy, 0); |
450 | dim_t sliceSize = |
451 | filterWdims[1] * filterWdims[2] * filterWdims[3]; |
452 | |
453 | // Perform the heart of the convolution, 4 elements at a time |
454 | // to reduce register pressure. |
455 | for (dim_t fd = cb, e = MIN(cb + cbSize, inCperG); fd < e; |
456 | fd++) { |
457 | float in = inW[inIdx + fd]; |
458 | for (unsigned i = 0; i < MIN(4, depthUnroll); i++) { |
459 | sum[i] += filterW[filterIdx + (sliceSize * i) + fd] * in; |
460 | } |
461 | } |
462 | |
463 | // And run the innermost loop again for the second group of |
464 | // depth slices: |
465 | if (depthUnroll > 4) { |
466 | for (dim_t fd = cb, e = MIN(cb + cbSize, inCperG); fd < e; |
467 | fd++) { |
468 | float in = inW[inIdx + fd]; |
469 | for (unsigned i = 4; i < MIN(8, depthUnroll); i++) { |
470 | sum[i] += |
471 | filterW[filterIdx + (sliceSize * i) + fd] * in; |
472 | } |
473 | } |
474 | } |
475 | |
476 | // Store the results to the output buffer. |
477 | for (unsigned i = 0; i < depthUnroll; i++) { |
478 | dim_t outIdx = |
479 | libjit_getXYZW(outWdims, n, outx, outy, d + i); |
480 | float sumIter = outW[outIdx] + sum[i]; |
481 | if (actType && lastSumIter) { |
482 | sumIter = libjit_activation_f(sumIter, actType, actArgs); |
483 | } |
484 | outW[outIdx] = sumIter; |
485 | } |
486 | } |
487 | } |
488 | } // For each Y in the filter. |
489 | } // For each X in the filter. |
490 | } // For each D (the depth, or the output channel). |
491 | } // For each block in the input channel. |
492 | } // For each group in the input channel. |
493 | } // For each N, the sample in the batch. |
494 | } |
495 | |
496 | void libjit_conv2d_i8_i32( |
497 | int8_t *outW, const int8_t *inW, const int8_t *filterW, |
498 | const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims, |
499 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernelSizes, |
500 | const dim_t *strides, const dim_t *pads, dim_t group, int32_t outOffset, |
501 | int32_t inOffset, int32_t filterOffset, int32_t biasOffset, int32_t biasPre, |
502 | int32_t biasPost, int32_t biasScale, int32_t outPre, int32_t outPost, |
503 | int32_t outScale, unsigned depthUnroll, const dim_t *dilation, |
504 | int32_t actType, const int32_t *actArgs) { |
505 | libjit_quantized_conv2d_generic<int8_t, int32_t>( |
506 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
507 | kernelSizes, strides, pads, group, outOffset, inOffset, filterOffset, |
508 | biasOffset, biasPre, biasPost, biasScale, outPre, outPost, outScale, |
509 | depthUnroll, dilation, actType, actArgs); |
510 | } |
511 | |
512 | void libjit_conv2d_i8_i8(int8_t *outW, const int8_t *inW, const int8_t *filterW, |
513 | const int8_t *biasW, const dim_t *outWdims, |
514 | const dim_t *inWdims, const dim_t *filterWdims, |
515 | const dim_t *biasWdims, const dim_t *kernelSizes, |
516 | const dim_t *strides, const dim_t *pads, dim_t group, |
517 | int32_t outOffset, int32_t inOffset, |
518 | int32_t filterOffset, int32_t biasOffset, |
519 | int32_t biasPre, int32_t biasPost, int32_t biasScale, |
520 | int32_t outPre, int32_t outPost, int32_t outScale, |
521 | unsigned depthUnroll, const dim_t *dilation, |
522 | int32_t actType, const int32_t *actArgs) { |
523 | libjit_quantized_conv2d_generic<int8_t, int8_t>( |
524 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
525 | kernelSizes, strides, pads, group, outOffset, inOffset, filterOffset, |
526 | biasOffset, biasPre, biasPost, biasScale, outPre, outPost, outScale, |
527 | depthUnroll, dilation, actType, actArgs); |
528 | } |
529 | |
530 | void libjit_channelwise_quantized_conv2d_i8_i32( |
531 | int8_t *outW, const int8_t *inW, const int8_t *filterW, |
532 | const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims, |
533 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels, |
534 | const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation, |
535 | int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr, |
536 | int32_t *biasOffsetsPtr, const int32_t *biasPrePtr, |
537 | const int32_t *biasPostPtr, const int32_t *biasScalePtr, |
538 | const int32_t *outPrePtr, const int32_t *outPostPtr, |
539 | const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) { |
540 | libjit_channelwise_quantized_conv2d_generic<int8_t, int32_t>( |
541 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
542 | kernels, strides, pads, group, dilation, outOffset, inOffset, |
543 | filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr, |
544 | outPrePtr, outPostPtr, outScalePtr, actType, actArgs); |
545 | } |
546 | |
547 | void libjit_channelwise_quantized_conv2d_i8_i8( |
548 | int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW, |
549 | const dim_t *outWdims, const dim_t *inWdims, const dim_t *filterWdims, |
550 | const dim_t *biasWdims, const dim_t *kernels, const dim_t *strides, |
551 | const dim_t *pads, dim_t group, const dim_t *dilation, int32_t outOffset, |
552 | int32_t inOffset, int32_t *filterOffsetsPtr, int32_t *biasOffsetsPtr, |
553 | const int32_t *biasPrePtr, const int32_t *biasPostPtr, |
554 | const int32_t *biasScalePtr, const int32_t *outPrePtr, |
555 | const int32_t *outPostPtr, const int32_t *outScalePtr, int32_t actType, |
556 | const int32_t *actArgs) { |
557 | libjit_channelwise_quantized_conv2d_generic<int8_t, int8_t>( |
558 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
559 | kernels, strides, pads, group, dilation, outOffset, inOffset, |
560 | filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr, |
561 | outPrePtr, outPostPtr, outScalePtr, actType, actArgs); |
562 | } |
563 | |
564 | void libjit_channelwise_quantized_conv3d_i8_i32( |
565 | int8_t *outW, const int8_t *inW, const int8_t *filterW, |
566 | const int32_t *biasW, const dim_t *outWdims, const dim_t *inWdims, |
567 | const dim_t *filterWdims, const dim_t *biasWdims, const dim_t *kernels, |
568 | const dim_t *strides, const dim_t *pads, dim_t group, const dim_t *dilation, |
569 | int32_t outOffset, int32_t inOffset, int32_t *filterOffsetsPtr, |
570 | int32_t *biasOffsetsPtr, const int32_t *biasPrePtr, |
571 | const int32_t *biasPostPtr, const int32_t *biasScalePtr, |
572 | const int32_t *outPrePtr, const int32_t *outPostPtr, |
573 | const int32_t *outScalePtr, int32_t actType, const int32_t *actArgs) { |
574 | libjit_channelwise_quantized_conv3d_generic<int8_t, int32_t>( |
575 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
576 | kernels, strides, pads, group, dilation, outOffset, inOffset, |
577 | filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr, |
578 | outPrePtr, outPostPtr, outScalePtr, actType, actArgs); |
579 | } |
580 | |
581 | void libjit_channelwise_quantized_conv3d_i8_i8( |
582 | int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW, |
583 | const dim_t *outWdims, const dim_t *inWdims, const dim_t *filterWdims, |
584 | const dim_t *biasWdims, const dim_t *kernels, const dim_t *strides, |
585 | const dim_t *pads, dim_t group, const dim_t *dilation, int32_t outOffset, |
586 | int32_t inOffset, int32_t *filterOffsetsPtr, int32_t *biasOffsetsPtr, |
587 | const int32_t *biasPrePtr, const int32_t *biasPostPtr, |
588 | const int32_t *biasScalePtr, const int32_t *outPrePtr, |
589 | const int32_t *outPostPtr, const int32_t *outScalePtr, int32_t actType, |
590 | const int32_t *actArgs) { |
591 | libjit_channelwise_quantized_conv3d_generic<int8_t, int8_t>( |
592 | outW, inW, filterW, biasW, outWdims, inWdims, filterWdims, biasWdims, |
593 | kernels, strides, pads, group, dilation, outOffset, inOffset, |
594 | filterOffsetsPtr, biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr, |
595 | outPrePtr, outPostPtr, outScalePtr, actType, actArgs); |
596 | } |
597 | |
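/// Transposed convolution ("deconvolution") in NHWC layout. Each input
/// element scatters its value, weighted by the filter, into a
/// kernel_h x kernel_w window of the output.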
598 | void libjit_conv_transpose_f(float *outW, const float *inW, |
599 | const float *filterW, const float *biasW, |
600 | const dim_t *outWdims, const dim_t *inWdims, |
601 | const dim_t *filterWdims, const dim_t *biasWdims, |
602 | const dim_t *kernels, const dim_t *strides, |
603 | const dim_t *pads, dim_t group, |
604 | const dim_t *dilation) { |
605 | // NHWC format is assumed |
606 | dim_t p = sizeof(float); |
607 | memset(outW, 0, outWdims[0] * outWdims[1] * outWdims[2] * outWdims[3] * p); |
608 | |
609 | dim_t pad_t = pads[0]; |
610 | dim_t pad_l = pads[1]; |
611 | dim_t stride_h = strides[0]; |
612 | dim_t stride_w = strides[1]; |
613 | dim_t kernel_h = kernels[0]; |
614 | dim_t kernel_w = kernels[1]; |
615 | dim_t outCperG = outWdims[3] / group; |
616 | dim_t inCperG = inWdims[3] / group; |
617 | |
618 | // For each input in the batch: |
619 | for (dim_t n = 0; n < inWdims[0]; n++) { |
620 | |
621 | // Initialize the outputs with the bias. |
622 | libjit_conv_init_output_with_bias(n, outW, biasW, outWdims, biasWdims); |
623 | |
624 | // For each group of input channels: |
625 | for (dim_t g = 0; g < group; g++) { |
626 | for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) { |
627 | ssize_t x = -(ssize_t)pad_t; |
628 | for (dim_t bx = 0; bx < inWdims[1]; bx++, x += stride_h) { |
629 | ssize_t y = -(ssize_t)pad_l; |
630 | for (dim_t by = 0; by < inWdims[2]; by++, y += stride_w) { |
631 | float grad = inW[libjit_getXYZW(inWdims, n, bx, by, d)]; |
632 | |
633 | for (dim_t kx = 0; kx < kernel_h; kx++) { |
634 | for (dim_t ky = 0; ky < kernel_w; ky++) { |
635 | ssize_t ax = x + kx * dilation[0]; |
636 | ssize_t ay = y + ky * dilation[1]; |
637 | |
638 | if (ax < 0 || ay < 0 || ax >= (ssize_t)outWdims[1] || |
639 | ay >= (ssize_t)outWdims[2]) { |
640 | continue; |
641 | } |
642 | |
643 | for (dim_t c = 0; c < outCperG; c++) { |
644 | dim_t outIndex = libjit_getXYZW( |
645 | outWdims, n, (dim_t)ax, (dim_t)ay, (g * outCperG + c)); |
                  dim_t filterIndex =
                      libjit_getXYZW(filterWdims, c, kx, ky, d);
                  outW[outIndex] += filterW[filterIndex] * grad;
648 | } |
649 | } |
650 | } |
651 | } // W |
652 | } // H |
653 | } // C |
654 | } // G |
655 | } // N |
656 | } |
657 | |
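/// Gradient of the 2D convolution with respect to the input, the filter and
/// the bias, in NHWC layout. inG, filterG and biasG are cleared and then
/// accumulated into.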
658 | void libjit_convolution_grad_f(float *inG, const float *outG, const float *inW, |
659 | float *filterG, float *biasG, |
660 | const float *filterW, const dim_t *outGdims, |
661 | const dim_t *inWdims, const dim_t *filterGdims, |
662 | const dim_t *kernels, const dim_t *strides, |
663 | const dim_t *pads, dim_t group, |
664 | const dim_t *dilation) { |
665 | // NHWC format is assumed |
666 | // Clear inG, filterG, and biasG |
667 | dim_t p = sizeof(float); |
668 | memset(inG, 0, inWdims[0] * inWdims[1] * inWdims[2] * inWdims[3] * p); |
669 | memset(filterG, 0, |
670 | filterGdims[0] * filterGdims[1] * filterGdims[2] * filterGdims[3] * p); |
671 | memset(biasG, 0, outGdims[3] * p); |
672 | |
673 | dim_t pad_t = pads[0]; |
674 | dim_t pad_l = pads[1]; |
675 | dim_t stride_h = strides[0]; |
676 | dim_t stride_w = strides[1]; |
677 | dim_t kernel_h = kernels[0]; |
678 | dim_t kernel_w = kernels[1]; |
679 | dim_t inCperG = inWdims[3] / group; |
680 | dim_t outCperG = outGdims[3] / group; |
681 | |
682 | // For each input in the batch: |
683 | for (dim_t n = 0; n < outGdims[0]; n++) { |
684 | // For each group of input channels: |
685 | for (dim_t g = 0; g < group; g++) { |
686 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
687 | ssize_t x = -(ssize_t)pad_t; |
688 | for (dim_t bx = 0; bx < outGdims[1]; bx++, x += stride_h) { |
689 | ssize_t y = -(ssize_t)pad_l; |
690 | for (dim_t by = 0; by < outGdims[2]; by++, y += stride_w) { |
691 | float grad = outG[libjit_getXYZW(outGdims, n, bx, by, d)]; |
692 | |
693 | for (dim_t kx = 0; kx < kernel_h; kx++) { |
694 | for (dim_t ky = 0; ky < kernel_w; ky++) { |
695 | ssize_t ax = x + kx * dilation[0]; |
696 | ssize_t ay = y + ky * dilation[1]; |
697 | |
698 | if (ax < 0 || ay < 0 || ax >= (ssize_t)inWdims[1] || |
699 | ay >= (ssize_t)inWdims[2]) { |
700 | continue; |
701 | } |
702 | |
703 | for (dim_t c = 0; c < inCperG; c++) { |
704 | inG[libjit_getXYZW(inWdims, n, (dim_t)ax, (dim_t)ay, |
705 | g * inCperG + c)] += |
706 | filterW[libjit_getXYZW(filterGdims, d, kx, ky, c)] * grad; |
707 | filterG[libjit_getXYZW(filterGdims, d, kx, ky, c)] += |
708 | inW[libjit_getXYZW(inWdims, n, (dim_t)ax, (dim_t)ay, |
709 | g * inCperG + c)] * |
710 | grad; |
711 | } |
712 | } |
713 | } |
714 | |
715 | biasG[d] += grad; |
716 | } // W |
717 | } // H |
718 | } // C |
719 | } // G |
720 | } // N |
721 | } |
722 | } |
723 | |