#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>
#include <string.h>

#include <fxdiv.h>

#include <nnpack.h>
#include <nnpack/macros.h>
#include <nnpack/utils.h>
#include <nnpack/system.h>

#include <nnpack/hwinfo.h>
#include <nnpack/activations.h>
#include <nnpack/validation.h>

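/*
 * Kernel transform stage of the tiled (Winograd/FFT) fast-convolution path.
 * pthreadpool invokes compute_kernel_transform() over a 2D grid of
 * (output-channel subblock, input channel) indices; each invocation transforms
 * the kernel_size.height x kernel_size.width filter of every output channel in
 * the subblock for one input channel, scattering the resulting tuples into the
 * blocked kernel_transform buffer that the tuple GEMM below consumes.
 */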
struct NNP_CACHE_ALIGN kernel_transform_context {
	nnp_transform_2d_with_offset transform_function;
	const float* kernel;
	void* kernel_transform;

	size_t tuple_size;
	size_t input_channels;
	size_t input_channels_block_size;
	size_t output_channels;
	struct nnp_size kernel_size;
};

static void compute_kernel_transform(
	const struct kernel_transform_context context[restrict static 1],
	size_t output_channels_subblock_start, size_t input_channels_block_offset,
	size_t output_channels_subblock_size, size_t input_channels_block_increment)
{
	const size_t tuple_size = context->tuple_size;
	const size_t input_channels = context->input_channels;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const size_t output_channels = context->output_channels;
	const struct nnp_size kernel_size = context->kernel_size;

	const float (*kernel)[input_channels][kernel_size.width * kernel_size.height] =
		(const float(*)[input_channels][kernel_size.width * kernel_size.height]) context->kernel;
	void* kernel_transform = context->kernel_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
		const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset;
		transform_function(
			kernel[output_channel][input_channels_block_offset],
			kernel_transform +
				(output_channels_subblock_start * input_channels_block_size + input_channels_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_size,
			kernel_size.width,
			input_channels_block_size * output_channels * tuple_size,
			kernel_size.height, kernel_size.width, 0, 0);
	}
}

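/*
 * Input transform stage: each invocation transforms one input tile of one
 * input channel. The tile's (y, x) grid position is recovered from its linear
 * index via an fxdiv division by tiles_x_count, and doz() (difference-or-zero)
 * clamps the read window against the implicit zero padding at image borders.
 */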
struct NNP_CACHE_ALIGN input_transform_context {
	const float* input;
	void* input_transform;
	nnp_transform_2d_with_offset transform_function;

	const size_t tuple_size;
	const size_t tiles_count;
	const struct fxdiv_divisor_size_t tiles_x_count;
	const size_t input_channels_block_start;
	const size_t input_channels_block_size;
	const struct nnp_size input_size;
	const size_t input_padding_left;
	const size_t input_padding_top;
	const struct nnp_size input_tile;
	const struct nnp_size input_tile_step;
};

static void compute_input_transform(
	const struct input_transform_context context[restrict static 1],
	size_t input_channels_block_offset, size_t tiles_subblock_start,
	size_t input_channels_block_range, size_t tiles_subblock_size)
{
	const size_t tuple_size = context->tuple_size;
	const size_t tiles_count = context->tiles_count;
	const struct fxdiv_divisor_size_t tiles_x_count = context->tiles_x_count;
	const size_t input_channels_block_start = context->input_channels_block_start;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const struct nnp_size input_size = context->input_size;
	const size_t input_padding_left = context->input_padding_left;
	const size_t input_padding_top = context->input_padding_top;
	const struct nnp_size input_tile = context->input_tile;
	const struct nnp_size input_tile_step = context->input_tile_step;

	const float (*input)[input_size.height][input_size.width] =
		(const float(*)[input_size.height][input_size.width]) context->input;
	void* input_transform = context->input_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t input_channel = input_channels_block_start + input_channels_block_offset;
	for (size_t tiles_subblock_offset = 0; tiles_subblock_offset < tiles_subblock_size; tiles_subblock_offset += 1) {
		const size_t tile = tiles_subblock_start + tiles_subblock_offset;
		const struct fxdiv_result_size_t tile_xy = fxdiv_divide_size_t(tile, tiles_x_count);
		const size_t tile_x = tile_xy.remainder;
		const size_t tile_y = tile_xy.quotient;

		const size_t output_x = tile_x * input_tile_step.width;
		const size_t output_y = tile_y * input_tile_step.height;

		const size_t input_x = min(doz(output_x, input_padding_left), input_size.width);
		const size_t input_y = min(doz(output_y, input_padding_top), input_size.height);

		const size_t row_offset = doz(input_padding_top, output_y);
		const size_t row_count = min(input_size.height - input_y, input_tile.height - row_offset);
		const size_t column_offset = doz(input_padding_left, output_x);
		const size_t column_count = min(input_size.width - input_x, input_tile.width - column_offset);

		transform_function(
			&input[input_channel][input_y][input_x],
			input_transform + (tiles_subblock_start * input_channels_block_size + input_channels_block_offset * tiles_subblock_size + tiles_subblock_offset) * tuple_size,
			input_size.width,
			input_channels_block_size * tiles_count * tuple_size,
			row_count, column_count, row_offset, column_offset);
	}
}

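/*
 * Output transform stage: inverse-transforms the accumulated tuples back into
 * the spatial domain, adds the per-channel bias, and stores the (possibly
 * truncated, at the right/bottom edges) output tile at (output_y, output_x).
 * The input offset arithmetic mirrors the blocked layout written by
 * compute_tuple_multiplication() below.
 */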
struct NNP_CACHE_ALIGN output_transform_context {
	nnp_transform_2d_with_bias transform_function;
	float* output;
	const void* output_transform;
	const float* bias;

	size_t tuple_size;
	size_t tiles_count;
	struct fxdiv_divisor_size_t tiles_x_count;
	struct fxdiv_divisor_size_t tiles_block_max;
	size_t output_channels;
	struct nnp_size output_size;
	struct nnp_size output_tile;
};

static void compute_output_transform(
	const struct output_transform_context context[restrict static 1],
	size_t output_channels_subblock_start, size_t tiles_subblock_start,
	size_t output_channels_subblock_size, size_t tiles_subblock_size)
{
	const size_t tuple_size = context->tuple_size;
	const size_t tiles_count = context->tiles_count;
	const struct fxdiv_divisor_size_t tiles_x_count = context->tiles_x_count;
	const struct fxdiv_divisor_size_t tiles_block_max = context->tiles_block_max;
	const size_t output_channels = context->output_channels;
	const struct nnp_size output_size = context->output_size;
	const struct nnp_size output_tile = context->output_tile;

	const size_t tiles_block_start = fxdiv_round_down_size_t(tiles_subblock_start, tiles_block_max);
	const size_t tiles_block_size = min(tiles_count - tiles_block_start, tiles_block_max.value);

	float (*output)[output_size.height][output_size.width] =
		(float(*)[output_size.height][output_size.width]) context->output;
	const void* output_transform = context->output_transform;
	const float* bias = context->bias;
	nnp_transform_2d_with_bias transform_function = context->transform_function;

	for (size_t tiles_subblock_offset = 0; tiles_subblock_offset < tiles_subblock_size; tiles_subblock_offset += 1) {
		const size_t tile = tiles_subblock_start + tiles_subblock_offset;
		const struct fxdiv_result_size_t tile_xy = fxdiv_divide_size_t(tile, tiles_x_count);
		const size_t tile_x = tile_xy.remainder;
		const size_t tile_y = tile_xy.quotient;

		const size_t output_x = tile_x * output_tile.width;
		const size_t output_y = tile_y * output_tile.height;

		for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
			const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset;
			transform_function(
				output_transform +
					(tiles_block_start * output_channels + output_channels_subblock_start * tiles_block_size + ((tiles_subblock_start - tiles_block_start) + tiles_subblock_offset) * output_channels_subblock_size + output_channels_subblock_offset) * tuple_size,
				&output[output_channel][output_y][output_x],
				&bias[output_channel],
				tiles_count * output_channels * tuple_size,
				output_size.width,
				min(output_tile.height, output_size.height - output_y),
				min(output_tile.width, output_size.width - output_x));
		}
	}
}

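/*
 * Tuple multiplication stage: a blocked GEMM over transformed tiles. The
 * "fast" micro-kernel only handles full mr-x-nr subblocks, so it is used while
 * whole tile subblocks remain; the "full" variant takes explicit subblock
 * dimensions and mops up the remainder.
 */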
struct NNP_CACHE_ALIGN tuple_multiplication_context {
	size_t tuple_elements;
	size_t tuple_size;
	size_t tiles_subblock_max;
	size_t input_channels_block_size;
	size_t input_channels_block_start;
	size_t output_channels;
	size_t output_channels_subblock_max;
	size_t output_channels_block_start;

	const void* input_transform;
	const void* kernel_transform;
	void* output_transform;

	nnp_fast_tuple_gemm_function fast_gemm;
	nnp_full_tuple_gemm_function full_gemm;
};

static void compute_tuple_multiplication(
	const struct tuple_multiplication_context context[restrict static 1],
	size_t tiles_block_start, size_t output_channels_subblock_start,
	size_t tiles_block_size, size_t output_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t tuple_size = context->tuple_size;
	const size_t tiles_subblock_max = context->tiles_subblock_max;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const size_t input_channels_block_start = context->input_channels_block_start;
	const size_t output_channels = context->output_channels;
	const size_t output_channels_subblock_max = context->output_channels_subblock_max;
	const size_t output_channels_block_start = context->output_channels_block_start;

	const void* input_transform = context->input_transform +
		tiles_block_start * input_channels_block_size * tuple_size;
	const void* kernel_transform = context->kernel_transform +
		(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size * tuple_size;
	void* output_transform = context->output_transform +
		(tiles_block_start * output_channels + (output_channels_block_start + output_channels_subblock_start) * tiles_block_size) * tuple_size;

	if (output_channels_subblock_size == output_channels_subblock_max) {
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (tiles_block_size >= tiles_subblock_max) {
			tiles_block_size -= tiles_subblock_max;

			fast_gemm(
				input_channels_block_size, input_channels_block_start,
				input_transform, kernel_transform, output_transform,
				output_channels_subblock_size * tuple_elements);

			input_transform += tiles_subblock_max * input_channels_block_size * tuple_size;
			output_transform += tiles_subblock_max * output_channels_subblock_size * tuple_size;
		}
	}

	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (tiles_block_size != 0) {
		const size_t tiles_subblock_size = min(tiles_block_size, tiles_subblock_max);
		tiles_block_size -= tiles_subblock_size;

		full_gemm(
			tiles_subblock_size, output_channels_subblock_size,
			input_channels_block_size, input_channels_block_start,
			input_transform, kernel_transform, output_transform,
			output_channels_subblock_size * tuple_elements);

		input_transform += tiles_subblock_max * input_channels_block_size * tuple_size;
		output_transform += tiles_subblock_max * output_channels_subblock_size * tuple_size;
	}
}

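/*
 * Implicit-GEMM path. compute_kernel_packing() repacks the
 * [output_channels x reduction_size] kernel matrix into a blocked layout:
 * within a reduction block, the values of an output-channel subblock are
 * stored contiguously for each reduction index, matching the access pattern
 * of the sgemm micro-kernels.
 */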
struct NNP_CACHE_ALIGN kernel_packing_context {
	const float* kernel;
	float* packed_kernel;

	size_t reduction_size;
	size_t reduction_block_start;
	size_t reduction_block_size;
};

static void compute_kernel_packing(
	const struct kernel_packing_context context[restrict static 1],
	size_t output_channels_subblock_start, size_t reduction_block_offset,
	size_t output_channels_subblock_size, size_t reduction_block_range)
{
	const size_t reduction_size = context->reduction_size;
	const size_t reduction_block_start = context->reduction_block_start;
	const size_t reduction_block_size = context->reduction_block_size;

	const float* kernel = context->kernel +
		output_channels_subblock_start * reduction_size + reduction_block_offset;
	float* packed_kernel = context->packed_kernel +
		output_channels_subblock_start * reduction_block_size + reduction_block_offset * output_channels_subblock_size;

	for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
		packed_kernel[output_channels_subblock_offset] = kernel[output_channels_subblock_offset * reduction_size];
	}
}

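/*
 * Input packing for the implicit-GEMM path (an im2col into an L3-sized
 * block). Each reduction index decomposes into (input channel, kernel y,
 * kernel x) and each output index into (output y, output x) via fxdiv;
 * out-of-bounds taps, i.e. reads from the implicit padding, are packed as
 * zeroes. Note that input_y/input_x may wrap around as unsigned values,
 * which the bounds check below treats as out-of-range.
 */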
struct NNP_CACHE_ALIGN input_packing_context {
	const float* input;
	float* packed_input;

	size_t simd_width;
	size_t reduction_block_start;
	size_t reduction_block_size;
	size_t output_image_block_start;
	struct nnp_size input_size;
	size_t input_padding_top;
	size_t input_padding_left;
	struct fxdiv_divisor_size_t kernel_elements;
	struct fxdiv_divisor_size_t kernel_width;
	struct fxdiv_divisor_size_t output_width;
	struct nnp_size output_subsampling;
};

static void compute_input_packing(
	const struct input_packing_context context[restrict static 1],
	size_t reduction_block_offset, size_t output_image_subblock_start,
	size_t reduction_block_range, size_t output_image_subblock_size)
{
	const size_t simd_width = context->simd_width;
	const size_t reduction_block_start = context->reduction_block_start;
	const size_t reduction_block_size = context->reduction_block_size;
	const size_t output_image_block_start = context->output_image_block_start;
	const struct nnp_size input_size = context->input_size;
	const size_t input_padding_top = context->input_padding_top;
	const size_t input_padding_left = context->input_padding_left;
	const struct fxdiv_divisor_size_t kernel_elements = context->kernel_elements;
	const struct fxdiv_divisor_size_t kernel_width = context->kernel_width;
	const struct fxdiv_divisor_size_t output_width = context->output_width;
	const struct nnp_size output_subsampling = context->output_subsampling;

	const float (*input)[input_size.height][input_size.width] =
		(const float(*)[input_size.height][input_size.width]) context->input;
	float* packed_input = context->packed_input;

	const size_t output_image_subblock_stride = round_up_by_power_of_2(output_image_subblock_size, simd_width);

	const size_t reduction_index = reduction_block_start + reduction_block_offset;
	const struct fxdiv_result_size_t reduction_index_divmod = fxdiv_divide_size_t(reduction_index, kernel_elements);
	const size_t input_channel = reduction_index_divmod.quotient;
	const struct fxdiv_result_size_t kernel_xy = fxdiv_divide_size_t(reduction_index_divmod.remainder, kernel_width);
	const size_t kernel_y = kernel_xy.quotient;
	const size_t kernel_x = kernel_xy.remainder;

	for (size_t output_image_subblock_offset = 0; output_image_subblock_offset < output_image_subblock_size; output_image_subblock_offset += 1) {
		const size_t output_image_index = output_image_block_start + output_image_subblock_start + output_image_subblock_offset;
		const struct fxdiv_result_size_t output_xy = fxdiv_divide_size_t(output_image_index, output_width);
		const size_t output_y = output_xy.quotient;
		const size_t output_x = output_xy.remainder;

		const size_t input_y = output_y * output_subsampling.height + kernel_y - input_padding_top;
		const size_t input_x = output_x * output_subsampling.width + kernel_x - input_padding_left;

		const size_t packed_index = output_image_subblock_start * reduction_block_size +
			reduction_block_offset * output_image_subblock_stride + output_image_subblock_offset;
		if ((input_x < input_size.width) && (input_y < input_size.height)) {
			packed_input[packed_index] = input[input_channel][input_y][input_x];
		} else {
			packed_input[packed_index] = 0.0f;
		}
	}
}

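/*
 * Matrix multiplication stage of the implicit-GEMM path: multiplies a packed
 * kernel block by a packed input block, accumulating directly into the output
 * (with rows strided by output_image_size), again with fast/full micro-kernel
 * variants.
 */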
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	const float* packed_kernel;
	const float* packed_input;
	float* output;

	size_t reduction_block_start;
	size_t reduction_block_size;
	size_t output_image_size;
	size_t output_image_block_start;
	size_t output_image_subblock_max;
	size_t output_channels_subblock_max;
};

static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t output_channels_block_start, size_t output_image_subblock_start,
	size_t output_channels_block_size, size_t output_image_subblock_size)
{
	const size_t reduction_block_start = context->reduction_block_start;
	const size_t reduction_block_size = context->reduction_block_size;
	const size_t output_image_size = context->output_image_size;
	const size_t output_image_block_start = context->output_image_block_start;
	const size_t output_image_subblock_max = context->output_image_subblock_max;
	const size_t output_channels_subblock_max = context->output_channels_subblock_max;

	const float* packed_kernel = context->packed_kernel +
		output_channels_block_start * reduction_block_size;
	const float* packed_input = context->packed_input +
		output_image_subblock_start * reduction_block_size;
	float* output = context->output +
		output_channels_block_start * output_image_size + output_image_block_start + output_image_subblock_start;

	if (output_image_subblock_size == output_image_subblock_max) {
		const nnp_fast_sgemm_function fast_gemm = nnp_hwinfo.sgemm.only_mr_x_nr;
		while (output_channels_block_size >= output_channels_subblock_max) {
			output_channels_block_size -= output_channels_subblock_max;

			fast_gemm(
				reduction_block_size, reduction_block_start,
				packed_kernel, packed_input, output,
				output_image_size);

			packed_kernel += reduction_block_size * output_channels_subblock_max;
			output += output_image_size * output_channels_subblock_max;
		}
	}

	const nnp_full_sgemm_function full_gemm = nnp_hwinfo.sgemm.upto_mr_x_nr;
	while (output_channels_block_size != 0) {
		const size_t output_channels_subblock_size = min(output_channels_block_size, output_channels_subblock_max);
		output_channels_block_size -= output_channels_subblock_size;

		full_gemm(
			output_channels_subblock_size, output_image_subblock_size,
			reduction_block_size, reduction_block_start,
			packed_kernel, packed_input, output,
			output_image_size);

		packed_kernel += reduction_block_size * output_channels_subblock_max;
		output += output_image_size * output_channels_subblock_max;
	}
}

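/*
 * Direct path for 1x1 kernels with unit stride: such a convolution is just a
 * matrix multiplication of the [output_channels x input_channels] kernel with
 * the [input_channels x image_elements] input, computed here with the conv1x1
 * micro-kernels after zero-initializing the output block.
 */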
struct NNP_CACHE_ALIGN direct_convolution_context {
	const float* input;
	const float* kernel;
	float* output;

	size_t image_elements;
	size_t input_channels;
	size_t input_channels_block_max;
	size_t output_channels_block_max;

	nnp_fast_conv_function fast_conv;
	nnp_full_conv_function full_conv;
};

static void compute_direct_convolution(
	const struct direct_convolution_context context[restrict static 1],
	size_t output_channels_block_start, size_t output_channels_block_size)
{
	const size_t image_elements = context->image_elements;
	const size_t input_channels = context->input_channels;
	const size_t input_channels_block_max = context->input_channels_block_max;
	const size_t output_channels_block_max = context->output_channels_block_max;

	const float* input = context->input;
	const float* kernel = context->kernel + output_channels_block_start * input_channels;
	float* output = context->output + output_channels_block_start * image_elements;

	memset(output, 0, sizeof(float) * output_channels_block_size * image_elements);

	size_t input_channels_unprocessed = input_channels;
	if (output_channels_block_size == output_channels_block_max) {
		const nnp_fast_conv_function fast_conv = context->fast_conv;
		while (input_channels_unprocessed >= input_channels_block_max) {
			input_channels_unprocessed -= input_channels_block_max;

			fast_conv(
				input_channels, image_elements,
				input, kernel, output);

			input += input_channels_block_max * image_elements;
			kernel += input_channels_block_max;
		}
	}

	const nnp_full_conv_function full_conv = context->full_conv;
	while (input_channels_unprocessed != 0) {
		const size_t input_channels_block_size = min(input_channels_unprocessed, input_channels_block_max);
		input_channels_unprocessed -= input_channels_block_size;

		full_conv(
			input_channels_block_size, output_channels_block_size,
			input_channels, image_elements,
			input, kernel, output);

		input += input_channels_block_max * image_elements;
		kernel += input_channels_block_max;
	}
}

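/*
 * Driver for the tiled Winograd/FFT algorithms. All strategies share the same
 * workspace protocol: with workspace_buffer == NULL and a non-NULL
 * workspace_size, the call is a size query and returns after reporting the
 * requirement; with a caller-provided buffer, the reported size is validated;
 * with both NULL, memory is allocated and released internally.
 */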
static enum nnp_status compute_fast_convolution_inference(
	const bool fourier_transform,
	const enum nnp_convolution_transform_strategy transform_strategy,
	const size_t transform_element_size,
	const size_t input_channels,
	const size_t output_channels,
	const struct nnp_size tile_size,
	const struct nnp_size input_size,
	const struct nnp_padding input_padding,
	const struct nnp_size kernel_size,
	const struct nnp_size output_size,
	const struct nnp_size output_subsampling,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	const nnp_transform_2d_with_offset input_transform_function,
	const nnp_transform_2d_with_offset kernel_transform_function,
	const nnp_transform_2d_with_bias output_transform_function,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	void* memory_block = NULL;
	size_t memory_size = 0;
	const size_t simd_width = nnp_hwinfo.simd_width;
	const size_t tuple_elements = (fourier_transform ? simd_width * 2 : simd_width);
	const size_t tuple_size = tuple_elements * transform_element_size;
	const size_t tile_elements = tile_size.height * tile_size.width;
	const size_t tuple_count = tile_elements / tuple_elements;

	const struct nnp_size output_tile_size = {
		.width = (tile_size.width - kernel_size.width) / output_subsampling.width + 1,
		.height = (tile_size.height - kernel_size.height) / output_subsampling.height + 1
	};
	const struct nnp_size tile_step = {
		.width = tile_size.width - kernel_size.width + 1,
		.height = tile_size.height - kernel_size.height + 1
	};

	const size_t tiles_y_count = divide_round_up(output_size.height, output_tile_size.height);
	const size_t tiles_x_count = divide_round_up(output_size.width, output_tile_size.width);
	const size_t tiles_count = tiles_x_count * tiles_y_count;

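	/*
	 * Reading of the blocking arithmetic below (an interpretation, not stated
	 * in the source): input-channel blocks are sized so that an mr-wide panel
	 * of input tuples plus an nr-wide panel of kernel tuples fit in L1, tile
	 * blocks so that the input working set fits in L2, and output-channel
	 * blocks so that the kernel working set fits in L3.
	 */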
	/* Calculate cache blocking parameters */
	const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / tuple_size;
	const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / tuple_size;
	const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / tuple_size;

	const size_t tiles_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.mr : nnp_hwinfo.sxgemm.mr);
	const size_t output_channels_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.nr : nnp_hwinfo.sxgemm.nr);

	const size_t input_channels_block_max =
		round_down(cache_elements_l1 / (tiles_subblock_max + output_channels_subblock_max), 2);
	const size_t tiles_block_max =
		round_down(cache_elements_l2 / input_channels_block_max, tiles_subblock_max);
	const size_t output_channels_block_max =
		round_down(cache_elements_l3 / input_channels_block_max, output_channels_subblock_max);

	const size_t transform_tile_size = tile_elements * transform_element_size;
	const size_t input_transform_size = tiles_count * min(input_channels, input_channels_block_max) * transform_tile_size;
	const size_t output_transform_size = tiles_count * output_channels * transform_tile_size;
	switch (transform_strategy) {
		case nnp_convolution_transform_strategy_compute:
		case nnp_convolution_transform_strategy_reuse:
		{
			memory_size = input_transform_size + output_transform_size;
			const size_t kernel_transform_size = output_channels * min(input_channels, input_channels_block_max) * transform_tile_size;
			if (transform_strategy == nnp_convolution_transform_strategy_compute) {
				memory_size += kernel_transform_size;
			}
			if (workspace_buffer == NULL) {
				if (workspace_size == NULL) {
					memory_block = allocate_memory(memory_size);
					if (memory_block == NULL) {
						return nnp_status_out_of_memory;
					}
				} else {
					*workspace_size = memory_size;
					return nnp_status_success;
				}
			} else {
				if (*workspace_size < memory_size) {
					return nnp_status_insufficient_buffer;
				}
				memory_block = workspace_buffer;
			}

			void* input_transform = memory_block;
			void* output_transform = memory_block + input_transform_size;
			void* kernel_transform = memory_block + input_transform_size + output_transform_size;

			for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) {
				const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);

				if (transform_strategy == nnp_convolution_transform_strategy_compute) {
					NNP_KERNEL_TRANSFORM_START(profile)
					struct kernel_transform_context kernel_transform_context = {
						.transform_function = kernel_transform_function,
						.kernel = kernel + input_channels_block_start * kernel_size.height * kernel_size.width,
						.kernel_transform = kernel_transform,
						.tuple_size = tuple_size,
						.input_channels = input_channels,
						.input_channels_block_size = input_channels_block_size,
						.output_channels = output_channels,
						.kernel_size = kernel_size,
					};
					pthreadpool_parallelize_2d_tile_2d(threadpool,
						(pthreadpool_task_2d_tile_2d_t) compute_kernel_transform,
						&kernel_transform_context,
						output_channels, input_channels_block_size,
						output_channels_subblock_max, 1,
						PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					NNP_KERNEL_TRANSFORM_END(profile)
				} else {
					kernel_transform = (void*) kernel + input_channels_block_start * output_channels * transform_tile_size;
				}

				NNP_INPUT_TRANSFORM_START(profile)
				struct input_transform_context input_transform_context = {
					.input = input,
					.input_transform = input_transform,
					.transform_function = input_transform_function,
					.tuple_size = tuple_size,
					.tiles_count = tiles_count,
					.tiles_x_count = fxdiv_init_size_t(tiles_x_count),
					.input_channels_block_start = input_channels_block_start,
					.input_channels_block_size = input_channels_block_size,
					.input_size = input_size,
					.input_padding_left = input_padding.left,
					.input_padding_top = input_padding.top,
					.input_tile = tile_size,
					.input_tile_step = tile_step,
				};
				pthreadpool_parallelize_2d_tile_2d(threadpool,
					(pthreadpool_task_2d_tile_2d_t) compute_input_transform,
					&input_transform_context,
					input_channels_block_size, tiles_count,
					1, tiles_subblock_max,
					PTHREADPOOL_FLAG_DISABLE_DENORMALS);
				NNP_INPUT_TRANSFORM_END(profile)

				NNP_BLOCK_MULTIPLICATION_START(profile)
				for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) {
					nnp_full_tuple_gemm_function full_gemm_function;
					nnp_fast_tuple_gemm_function fast_gemm_function;
					if (fourier_transform) {
						if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) {
							fast_gemm_function = nnp_hwinfo.cxgemm.s4cX_conjb_only_mr_x_nr;
							full_gemm_function = nnp_hwinfo.cxgemm.s4cX_conjb_upto_mr_x_nr;
						} else {
							fast_gemm_function = nnp_hwinfo.cxgemm.cX_conjb_only_mr_x_nr;
							full_gemm_function = nnp_hwinfo.cxgemm.cX_conjb_upto_mr_x_nr;
						}
					} else {
						if NNP_LIKELY(transform_element_size == sizeof(float)) {
							fast_gemm_function = nnp_hwinfo.sxgemm.only_mr_x_nr;
							full_gemm_function = nnp_hwinfo.sxgemm.upto_mr_x_nr;
						} else {
							#if NNP_BACKEND_ARM
								fast_gemm_function = nnp_hwinfo.hxgemm.only_mr_x_nr;
								full_gemm_function = nnp_hwinfo.hxgemm.upto_mr_x_nr;
							#endif /* NNP_BACKEND_ARM */
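							/*
							 * transform_element_size != sizeof(float) is only set on the
							 * wt8x8_fp16 path, which is compiled for NNP_BACKEND_ARM only,
							 * so on other backends this branch should be unreachable and
							 * the gemm function pointers are always initialized above.
							 */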
						}
					}
					for (size_t output_channels_block_start = 0; output_channels_block_start < output_channels; output_channels_block_start += output_channels_block_max) {
						const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max);
						struct tuple_multiplication_context tuple_multiplication_context = {
							.tuple_elements = tuple_elements,
							.tuple_size = tuple_size,
							.tiles_subblock_max = tiles_subblock_max,
							.input_channels_block_start = input_channels_block_start,
							.input_channels_block_size = input_channels_block_size,
							.output_channels = output_channels,
							.output_channels_subblock_max = output_channels_subblock_max,
							.output_channels_block_start = output_channels_block_start,
							.input_transform = input_transform +
								tuple_index * tiles_count * input_channels_block_size * tuple_size,
							.kernel_transform = kernel_transform +
								tuple_index * output_channels * input_channels_block_size * tuple_size,
							.output_transform = output_transform +
								tuple_index * tiles_count * output_channels * tuple_size,
							.fast_gemm = fast_gemm_function,
							.full_gemm = full_gemm_function,
						};
						pthreadpool_parallelize_2d_tile_2d(threadpool,
							(pthreadpool_task_2d_tile_2d_t) compute_tuple_multiplication,
							&tuple_multiplication_context,
							tiles_count, output_channels_block_size,
							tiles_block_max, output_channels_subblock_max,
							PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					}
				}
				NNP_BLOCK_MULTIPLICATION_END(profile)
			}
			NNP_OUTPUT_TRANSFORM_START(profile)
			struct output_transform_context output_transform_context = {
				.transform_function = output_transform_function,
				.output = output,
				.output_transform = output_transform,
				.bias = bias,
				.tuple_size = tuple_size,
				.tiles_count = tiles_count,
				.tiles_x_count = fxdiv_init_size_t(tiles_x_count),
				.tiles_block_max = fxdiv_init_size_t(tiles_block_max),
				.output_channels = output_channels,
				.output_size = output_size,
				.output_tile = output_tile_size,
			};
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_output_transform,
				&output_transform_context,
				output_channels, tiles_count,
				output_channels_subblock_max, tiles_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
			NNP_OUTPUT_TRANSFORM_END(profile)
			break;
		}
		case nnp_convolution_transform_strategy_precompute:
		{
			const size_t kernel_transform_size = output_channels * input_channels * transform_tile_size;
			if (workspace_buffer == NULL) {
				*workspace_size = kernel_transform_size;
				return nnp_status_success;
			} else {
				if (*workspace_size < kernel_transform_size) {
					return nnp_status_insufficient_buffer;
				}
				memory_block = workspace_buffer;
			}

			for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) {
				const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);

				NNP_KERNEL_TRANSFORM_START(profile)
				struct kernel_transform_context kernel_transform_context = {
					.transform_function = kernel_transform_function,
					.kernel = kernel + input_channels_block_start * kernel_size.height * kernel_size.width,
					.kernel_transform = (void*) workspace_buffer + input_channels_block_start * output_channels * transform_tile_size,
					.tuple_size = tuple_size,
					.input_channels = input_channels,
					.input_channels_block_size = input_channels_block_size,
					.output_channels = output_channels,
					.kernel_size = kernel_size,
				};
				pthreadpool_parallelize_2d_tile_2d(threadpool,
					(pthreadpool_task_2d_tile_2d_t) compute_kernel_transform,
					&kernel_transform_context,
					output_channels, input_channels_block_size,
					output_channels_subblock_max, 1,
					PTHREADPOOL_FLAG_DISABLE_DENORMALS);
				NNP_KERNEL_TRANSFORM_END(profile)
			}
			break;
		}
		default:
			return nnp_status_invalid_transform_strategy;
	}

	if (memory_block != workspace_buffer) {
		release_memory(memory_block, memory_size);
	}
	return nnp_status_success;
}

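/*
 * Driver for the implicit-GEMM algorithm: iterates over reduction blocks
 * (input channels x kernel elements) and output-image blocks, packing the
 * kernel and input ahead of each block multiplication, then applies bias and
 * the requested activation in a final pass over the output.
 */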
static enum nnp_status compute_gemm_convolution_inference(
	const enum nnp_convolution_transform_strategy transform_strategy,
	const size_t input_channels,
	const size_t output_channels,
	const struct nnp_size input_size,
	const struct nnp_padding input_padding,
	const struct nnp_size kernel_size,
	const struct nnp_size output_size,
	const struct nnp_size output_subsampling,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	enum nnp_status status = nnp_status_success;
	void* memory_block = NULL;
	size_t memory_size = 0;
	const size_t simd_width = nnp_hwinfo.simd_width;

	/* Calculate cache blocking parameters */
	const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / sizeof(float);
	const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / sizeof(float);
	const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / sizeof(float);

	const size_t output_channels_subblock_max = nnp_hwinfo.sgemm.mr;
	const size_t output_image_subblock_max = nnp_hwinfo.sgemm.nr;

	const size_t reduction_size = input_channels * kernel_size.height * kernel_size.width;
	const size_t output_image_size = output_size.height * output_size.width;
	const size_t reduction_block_max =
		round_down(cache_elements_l1 / (output_channels_subblock_max + output_image_subblock_max), 2);
	const size_t output_channels_block_max =
		round_down(cache_elements_l2 / reduction_block_max, output_channels_subblock_max);
	const size_t output_image_block_max =
		round_down(cache_elements_l3 / reduction_block_max, output_image_subblock_max);

	switch (transform_strategy) {
		case nnp_convolution_transform_strategy_compute:
		case nnp_convolution_transform_strategy_reuse:
		{
			const size_t packed_kernel_size = output_channels *
				min(reduction_block_max, reduction_size) * sizeof(float);
			const size_t packed_input_size = min(output_image_block_max, round_up(output_image_size, simd_width)) *
				min(reduction_block_max, reduction_size) * sizeof(float);
			memory_size = packed_kernel_size + packed_input_size;
			if (workspace_buffer == NULL) {
				if (workspace_size == NULL) {
					memory_block = allocate_memory(memory_size);
					if (memory_block == NULL) {
						return nnp_status_out_of_memory;
					}
				} else {
					*workspace_size = memory_size;
					return nnp_status_success;
				}
			} else {
				if (*workspace_size < memory_size) {
					return nnp_status_insufficient_buffer;
				}
				memory_block = workspace_buffer;
			}

			float* packed_input = memory_block;
			float* packed_kernel = memory_block + packed_input_size;

			for (size_t reduction_block_start = 0; reduction_block_start < reduction_size; reduction_block_start += reduction_block_max) {
				const size_t reduction_block_size = min(reduction_size - reduction_block_start, reduction_block_max);

				if (transform_strategy == nnp_convolution_transform_strategy_compute) {
					/* Pack kernel into memory block */
					NNP_KERNEL_TRANSFORM_START(profile)
					struct kernel_packing_context kernel_packing_context = {
						.kernel = kernel + reduction_block_start,
						.packed_kernel = packed_kernel,
						.reduction_size = reduction_size,
						.reduction_block_start = reduction_block_start,
						.reduction_block_size = reduction_block_size,
					};
					pthreadpool_parallelize_2d_tile_2d(threadpool,
						(pthreadpool_task_2d_tile_2d_t) compute_kernel_packing,
						&kernel_packing_context,
						output_channels, reduction_block_size,
						output_channels_subblock_max, 1,
						PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					NNP_KERNEL_TRANSFORM_END(profile)
				} else {
					packed_kernel = (void*) kernel + output_channels * reduction_block_start * sizeof(float);
				}

				const struct fxdiv_divisor_size_t kernel_elements_divisor = fxdiv_init_size_t(kernel_size.height * kernel_size.width);
				const struct fxdiv_divisor_size_t kernel_width_divisor = fxdiv_init_size_t(kernel_size.width);
				const struct fxdiv_divisor_size_t output_width_divisor = fxdiv_init_size_t(output_size.width);
				for (size_t output_image_block_start = 0; output_image_block_start < output_image_size; output_image_block_start += output_image_block_max) {
					const size_t output_image_block_size = min(output_image_size - output_image_block_start, output_image_block_max);

					/* Pack image into L3 block */
					NNP_INPUT_TRANSFORM_START(profile)
					struct input_packing_context input_packing_context = {
						.input = input,
						.packed_input = packed_input,
						.simd_width = simd_width,
						.reduction_block_start = reduction_block_start,
						.reduction_block_size = reduction_block_size,
						.output_image_block_start = output_image_block_start,
						.input_size = input_size,
						.input_padding_top = input_padding.top,
						.input_padding_left = input_padding.left,
						.kernel_elements = kernel_elements_divisor,
						.kernel_width = kernel_width_divisor,
						.output_width = output_width_divisor,
						.output_subsampling = output_subsampling,
					};
					pthreadpool_parallelize_2d_tile_2d(threadpool,
						(pthreadpool_task_2d_tile_2d_t) compute_input_packing,
						&input_packing_context,
						reduction_block_size, output_image_block_size,
						1, output_image_subblock_max,
						PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					NNP_INPUT_TRANSFORM_END(profile)

					NNP_BLOCK_MULTIPLICATION_START(profile)
					struct matrix_multiplication_context matrix_multiplication_context = {
						.packed_kernel = packed_kernel,
						.packed_input = packed_input,
						.output = output,
						.reduction_block_start = reduction_block_start,
						.reduction_block_size = reduction_block_size,
						.output_image_size = output_image_size,
						.output_image_block_start = output_image_block_start,
						.output_image_subblock_max = output_image_subblock_max,
						.output_channels_subblock_max = output_channels_subblock_max,
					};
					pthreadpool_parallelize_2d_tile_2d(threadpool,
						(pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication,
						&matrix_multiplication_context,
						output_channels, output_image_block_size,
						output_channels_block_max, output_image_subblock_max,
						PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					NNP_BLOCK_MULTIPLICATION_END(profile)
				}
			}
			/* Add bias */
			NNP_OUTPUT_TRANSFORM_START(profile)
			switch (activation) {
				case nnp_activation_identity:
					for (size_t output_channel = 0; output_channel < output_channels; output_channel += 1) {
						const float bias_value = bias[output_channel];
						for (size_t index = 0; index < output_image_size; index += 1) {
							output[output_channel * output_image_size + index] += bias_value;
						}
					}
					break;
				case nnp_activation_relu:
					for (size_t output_channel = 0; output_channel < output_channels; output_channel += 1) {
						const float bias_value = bias[output_channel];
						for (size_t index = 0; index < output_image_size; index += 1) {
							output[output_channel * output_image_size + index] =
								relu(output[output_channel * output_image_size + index] + bias_value, 0.0f);
						}
					}
					break;
				default:
					NNP_UNREACHABLE;
			}
			NNP_OUTPUT_TRANSFORM_END(profile)
			break;
		}
		case nnp_convolution_transform_strategy_precompute:
		{
			const size_t packed_kernel_size = output_channels * reduction_size * sizeof(float);
			if (workspace_buffer == NULL) {
				*workspace_size = packed_kernel_size;
				return nnp_status_success;
			} else {
				if (*workspace_size < packed_kernel_size) {
					return nnp_status_insufficient_buffer;
				}
				memory_block = workspace_buffer;
			}

			for (size_t reduction_block_start = 0; reduction_block_start < reduction_size; reduction_block_start += reduction_block_max) {
				const size_t reduction_block_size = min(reduction_size - reduction_block_start, reduction_block_max);

				/* Pack kernel into memory block */
				NNP_KERNEL_TRANSFORM_START(profile)
				struct kernel_packing_context kernel_packing_context = {
					.kernel = kernel + reduction_block_start,
					.packed_kernel = (void*) workspace_buffer + output_channels * reduction_block_start * sizeof(float),
					.reduction_size = reduction_size,
					.reduction_block_start = reduction_block_start,
					.reduction_block_size = reduction_block_size,
				};
				pthreadpool_parallelize_2d_tile_2d(threadpool,
					(pthreadpool_task_2d_tile_2d_t) compute_kernel_packing,
					&kernel_packing_context,
					output_channels, reduction_block_size,
					output_channels_subblock_max, 1,
					PTHREADPOOL_FLAG_DISABLE_DENORMALS);
				NNP_KERNEL_TRANSFORM_END(profile)
			}
			break;
		}
		default:
			return nnp_status_invalid_transform_strategy;
	}

	if (memory_block != workspace_buffer) {
		release_memory(memory_block, memory_size);
	}
	return status;
}

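/*
 * Driver for the direct 1x1 algorithm. It needs no scratch memory, so a
 * workspace query reports zero.
 */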
static enum nnp_status compute_direct_convolution_inference(
	const size_t input_channels,
	const size_t output_channels,
	const struct nnp_size image_size,
	const struct nnp_size kernel_size,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	const size_t image_elements = image_size.height * image_size.width;

	if (workspace_buffer == NULL && workspace_size != NULL) {
		*workspace_size = 0;
		return nnp_status_success;
	}

	NNP_BLOCK_MULTIPLICATION_START(profile)
	struct direct_convolution_context direct_convolution_context = {
		.input = input,
		.kernel = kernel,
		.output = output,
		.image_elements = image_elements,
		.input_channels = input_channels,
		.input_channels_block_max = nnp_hwinfo.conv1x1.mr,
		.output_channels_block_max = nnp_hwinfo.conv1x1.nr,
		.fast_conv = nnp_hwinfo.conv1x1.only_mr_x_nr,
		.full_conv = nnp_hwinfo.conv1x1.upto_mr_x_nr,
	};
	pthreadpool_parallelize_1d_tile_1d(threadpool,
		(pthreadpool_task_1d_tile_1d_t) compute_direct_convolution,
		&direct_convolution_context,
		output_channels, nnp_hwinfo.conv1x1.nr,
		PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	NNP_BLOCK_MULTIPLICATION_END(profile)

	/* Add bias */
	NNP_OUTPUT_TRANSFORM_START(profile)
	switch (activation) {
		case nnp_activation_identity:
			for (size_t output_channel = 0; output_channel < output_channels; output_channel += 1) {
				const float bias_value = bias[output_channel];
				for (size_t index = 0; index < image_elements; index += 1) {
					output[output_channel * image_elements + index] += bias_value;
				}
			}
			break;
		case nnp_activation_relu:
			for (size_t output_channel = 0; output_channel < output_channels; output_channel += 1) {
				const float bias_value = bias[output_channel];
				for (size_t index = 0; index < image_elements; index += 1) {
					output[output_channel * image_elements + index] =
						relu(output[output_channel * image_elements + index] + bias_value, 0.0f);
				}
			}
			break;
		default:
			NNP_UNREACHABLE;
	}
	NNP_OUTPUT_TRANSFORM_END(profile)

	return nnp_status_success;
}

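/*
 * Algorithm heuristic used for nnp_convolution_algorithm_auto. In the FFT
 * tile-size decision, the 4x factor presumably reflects that a 16x16 tile
 * covers four times the elements of an 8x8 tile, so 8x8 tiles win unless
 * 16x16 reduces the tile count by more than 4x.
 */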
static inline enum nnp_convolution_algorithm select_algorithm(
	struct nnp_size kernel_size,
	struct nnp_size output_subsampling,
	struct nnp_size output_size)
{
	if (max(output_subsampling.height, output_subsampling.width) == 1) {
		/* Stride-1 convolution: consider fast convolution algorithm and direct 1x1 */
		if (max(kernel_size.height, kernel_size.width) == 1) {
			return nnp_convolution_algorithm_direct;
		} else if (kernel_size.height == 3 && kernel_size.width == 3) {
			return nnp_convolution_algorithm_wt8x8;
		} else if (min(kernel_size.height, kernel_size.width) >= 2) {
			/* Consider FFT-based fast convolution */
			if (max(kernel_size.height, kernel_size.width) <= 8) {
				/* Decide between FFT 8x8 and FFT 16x16 */
				const size_t tile_count_8x8 =
					divide_round_up(output_size.height, 8 - kernel_size.height + 1) *
					divide_round_up(output_size.width, 8 - kernel_size.width + 1);
				const size_t tile_count_16x16 =
					divide_round_up(output_size.height, 16 - kernel_size.height + 1) *
					divide_round_up(output_size.width, 16 - kernel_size.width + 1);
				if (tile_count_8x8 <= 4 * tile_count_16x16) {
					/* 8x8 tiles are more efficient */
					return nnp_convolution_algorithm_ft8x8;
				} else {
					return nnp_convolution_algorithm_ft16x16;
				}
			} else if (max(kernel_size.height, kernel_size.width) <= 16) {
				return nnp_convolution_algorithm_ft16x16;
			}
		}
	}

	/* Fall-back algorithm */
	return nnp_convolution_algorithm_implicit_gemm;
}

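/*
 * Public entry point for single-image (inference-mode) convolution. A minimal
 * two-call usage sketch; the shapes, and passing NULL data pointers in the
 * size-query call, are this comment's assumptions rather than documented API:
 *
 *   nnp_initialize();
 *   size_t scratch_size = 0;
 *   nnp_convolution_inference(nnp_convolution_algorithm_auto,
 *       nnp_convolution_transform_strategy_compute,
 *       input_channels, output_channels, input_size, input_padding,
 *       kernel_size, output_subsampling,
 *       NULL, NULL, NULL, NULL, NULL, &scratch_size,
 *       nnp_activation_identity, NULL, threadpool, NULL);
 *   void* scratch = malloc(scratch_size);
 *   nnp_convolution_inference(nnp_convolution_algorithm_auto,
 *       nnp_convolution_transform_strategy_compute,
 *       input_channels, output_channels, input_size, input_padding,
 *       kernel_size, output_subsampling,
 *       input, kernel, bias, output, scratch, &scratch_size,
 *       nnp_activation_identity, NULL, threadpool, NULL);
 */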
enum nnp_status nnp_convolution_inference(
	enum nnp_convolution_algorithm algorithm,
	enum nnp_convolution_transform_strategy transform_strategy,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	struct nnp_size output_subsampling,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	const void* activation_parameters,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	NNP_TOTAL_START(profile)

	/* Basic validation of parameters. This check detects invalid, but not unsupported, parameters. */
	enum nnp_status status = validate_convolution_arguments(
		1, input_channels, output_channels,
		input_size, input_padding, kernel_size, output_subsampling,
		activation, activation_parameters);
	if (status != nnp_status_success) {
		goto cleanup;
	}

	if (activation_parameters != NULL) {
		status = nnp_status_unsupported_activation_parameters;
		goto cleanup;
	}

	const struct nnp_size output_size = {
		.width = (input_padding.left + input_size.width + input_padding.right - kernel_size.width) / output_subsampling.width + 1,
		.height = (input_padding.top + input_size.height + input_padding.bottom - kernel_size.height) / output_subsampling.height + 1
	};

	if (algorithm == nnp_convolution_algorithm_auto) {
		algorithm = select_algorithm(kernel_size, output_subsampling, output_size);
	}

	struct nnp_size tile_size;
	size_t transform_element_size;
	bool fourier_transform;
	nnp_transform_2d_with_offset input_transform_function = NULL;
	nnp_transform_2d_with_offset kernel_transform_function = NULL;
	nnp_transform_2d_with_bias output_transform_function = NULL;
	switch (algorithm) {
		case nnp_convolution_algorithm_wt8x8_fp16:
			#if NNP_BACKEND_ARM
			if (kernel_size.height != 3 || kernel_size.width != 3) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			if (max(output_subsampling.height, output_subsampling.width) > 1) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			transform_element_size = sizeof(uint16_t);
			fourier_transform = false;

			input_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_fp16_with_offset;
			kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3x3_fp16;
			switch (activation) {
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_fp16_with_bias;
					break;
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_fp16_with_bias_with_relu;
					break;
				default:
					NNP_UNREACHABLE;
			}
			if (input_transform_function != NULL && kernel_transform_function != NULL && output_transform_function != NULL) {
				break;
			}
			#endif
			/*
			 * Fall through otherwise. The rationale is that only some backends implement fp16 storage
			 * natively (currently ARM NEON + VFP_FP16), while network configurations are fairly
			 * platform-independent, so silently falling back to the baseline fp32 Winograd
			 * implementation is reasonable.
			 */
		case nnp_convolution_algorithm_wt8x8:
			if (kernel_size.height != 3 || kernel_size.width != 3) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			transform_element_size = sizeof(float);
			fourier_transform = false;

			input_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3x3;
			switch (activation) {
				case nnp_activation_identity:
					if (output_subsampling.height == 1 && output_subsampling.width == 1) {
						output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias;
					} else if (output_subsampling.height == 2 && output_subsampling.width == 2) {
						output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3s2_with_bias;
					}
					break;
				case nnp_activation_relu:
					if (output_subsampling.height == 1 && output_subsampling.width == 1) {
						output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias_with_relu;
					} else if (output_subsampling.height == 2 && output_subsampling.width == 2) {
						output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3s2_with_bias_with_relu;
					}
					break;
				default:
					NNP_UNREACHABLE;
			}
			break;
		case nnp_convolution_algorithm_ft8x8:
			if (max(kernel_size.height, kernel_size.width) > 8) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			if (max(output_subsampling.height, output_subsampling.width) > 1) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			transform_element_size = sizeof(float);
			fourier_transform = true;

			input_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			switch (activation) {
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias;
					break;
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias_with_relu;
					break;
				default:
					NNP_UNREACHABLE;
			}
			break;
		case nnp_convolution_algorithm_ft16x16:
			if (max(kernel_size.height, kernel_size.width) > 16) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			if (max(output_subsampling.height, output_subsampling.width) > 1) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			tile_size = (struct nnp_size) { .height = 16, .width = 16 };
			transform_element_size = sizeof(float);
			fourier_transform = true;

			input_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			switch (activation) {
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias;
					break;
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias_with_relu;
					break;
				default:
					NNP_UNREACHABLE;
			}
			break;
		case nnp_convolution_algorithm_implicit_gemm:
			break;
		case nnp_convolution_algorithm_direct:
			if (max(kernel_size.height, kernel_size.width) > 1) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			if (max(output_subsampling.height, output_subsampling.width) > 1) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			break;
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
		default:
			status = nnp_status_invalid_algorithm;
			goto cleanup;
	}

	switch (algorithm) {
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
		case nnp_convolution_algorithm_ft8x8:
		case nnp_convolution_algorithm_ft16x16:
			if (input_transform_function == NULL || kernel_transform_function == NULL || output_transform_function == NULL) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			status = compute_fast_convolution_inference(
				fourier_transform, transform_strategy, transform_element_size,
				input_channels, output_channels,
				tile_size, input_size, input_padding, kernel_size, output_size, output_subsampling,
				input, kernel, bias, output, workspace_buffer, workspace_size,
				input_transform_function, kernel_transform_function, output_transform_function,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_implicit_gemm:
			status = compute_gemm_convolution_inference(
				transform_strategy,
				input_channels, output_channels,
				input_size, input_padding, kernel_size, output_size, output_subsampling,
				input, kernel, bias, output, workspace_buffer, workspace_size,
				activation,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_direct:
			if (transform_strategy != nnp_convolution_transform_strategy_compute) {
				status = nnp_status_unsupported_transform_strategy;
				goto cleanup;
			}
			status = compute_direct_convolution_inference(
				input_channels, output_channels, input_size, kernel_size,
				input, kernel, bias, output, workspace_buffer, workspace_size,
				activation,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
	}

cleanup:
	NNP_TOTAL_END(profile)
	return status;
}