#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>

#include <nnpack.h>
#include <nnpack/macros.h>
#include <nnpack/utils.h>
#include <nnpack/system.h>

#include <nnpack/hwinfo.h>
#include <nnpack/validation.h>

struct NNP_CACHE_ALIGN kernel_transform_context {
	nnp_transform_2d_with_offset transform_function;
	const float* kernel;
	float* kernel_transform;

	size_t tuple_elements;
	size_t output_channels;
	size_t input_channels;
	size_t input_channels_block_max;
	struct nnp_size kernel_size;
};

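/*
 * Transforms the kernels for one sub-block of output channels at a fixed input channel,
 * scattering them into the blocked kernel_transform layout (grouped by input-channel
 * block, then output-channel sub-block). The tuples of a single transform are strided
 * output_channels * input_channels * tuple_elements floats apart.
 */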
static void compute_kernel_transform(
	const struct kernel_transform_context context[restrict static 1],
	size_t input_channel, size_t output_channels_subblock_start,
	size_t input_channel_range, size_t output_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t output_channels = context->output_channels;
	const size_t input_channels = context->input_channels;
	const size_t input_channels_block_max = context->input_channels_block_max;
	const struct nnp_size kernel_size = context->kernel_size;

	const float (*kernel)[input_channels][kernel_size.width * kernel_size.height] =
		(const float(*)[input_channels][kernel_size.width * kernel_size.height]) context->kernel;
	float* kernel_transform = context->kernel_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t input_channels_block_start = round_down(input_channel, input_channels_block_max);
	const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);
	const size_t input_channels_block_offset = input_channel - input_channels_block_start;

	for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
		const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset;
		transform_function(
			kernel[output_channel][input_channel],
			kernel_transform +
				(input_channels_block_start * output_channels + output_channels_subblock_start * input_channels_block_size + input_channels_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements,
			kernel_size.width,
			output_channels * input_channels * tuple_elements * sizeof(float),
			kernel_size.height, kernel_size.width, 0, 0);
	}
}

struct NNP_CACHE_ALIGN input_transform_context {
	nnp_transform_2d_with_offset transform_function;
	const float* input;
	float* input_transform;

	size_t tuple_elements;
	size_t batch_size;
	size_t input_channels;
	size_t input_channels_block_max;
	struct nnp_size input_size;
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};

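/*
 * Transforms the input tile for one sub-block of samples at a fixed input channel.
 * row/column offset and count select the part of the tile that overlaps the input
 * image at the current tile position; the remainder corresponds to padding.
 */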
static void compute_input_transform(
	const struct input_transform_context context[restrict static 1],
	size_t input_channel, size_t batch_subblock_start,
	size_t input_channel_range, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t input_channels = context->input_channels;
	const size_t input_channels_block_max = context->input_channels_block_max;
	const struct nnp_size input_size = context->input_size;
	const size_t row_offset = context->row_offset;
	const size_t row_count = context->row_count;
	const size_t column_offset = context->column_offset;
	const size_t column_count = context->column_count;

	const float (*input)[input_channels][input_size.width * input_size.height] =
		(const float(*)[input_channels][input_size.width * input_size.height]) context->input;
	float* input_transform = context->input_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t input_channels_block_start = round_down(input_channel, input_channels_block_max);
	const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);
	const size_t input_channels_block_offset = input_channel - input_channels_block_start;

	for (size_t batch_subblock_offset = 0; batch_subblock_offset < batch_subblock_size; batch_subblock_offset += 1) {
		const size_t sample = batch_subblock_start + batch_subblock_offset;
		transform_function(
			input[sample][input_channel],
			input_transform +
				(input_channels_block_start * batch_size + batch_subblock_start * input_channels_block_size + input_channels_block_offset * batch_subblock_size + batch_subblock_offset) * tuple_elements,
			input_size.width,
			batch_size * input_channels * tuple_elements * sizeof(float),
			row_count, column_count, row_offset, column_offset);
	}
}

struct NNP_CACHE_ALIGN output_transform_context {
	nnp_transform_2d_with_bias transform_function;
	float* output;
	const float* output_transform;
	const float* bias;

	size_t tuple_elements;
	size_t output_channels;
	size_t batch_size;
	size_t batch_block_max;
	struct nnp_size output_size;
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};

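/*
 * Applies the inverse transform with bias to one sub-block of output channels of a
 * single sample and stores the row_count x column_count valid pixels of the output
 * tile at the current position.
 */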
static void compute_output_transform(
	const struct output_transform_context context[restrict static 1],
	size_t sample, size_t output_channels_subblock_start,
	size_t sample_range, size_t output_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t output_channels = context->output_channels;
	const size_t batch_block_max = context->batch_block_max;
	const struct nnp_size output_size = context->output_size;
	const size_t row_offset = context->row_offset;
	const size_t row_count = context->row_count;
	const size_t column_offset = context->column_offset;
	const size_t column_count = context->column_count;

	float (*output)[output_channels][output_size.width * output_size.height] =
		(float(*)[output_channels][output_size.width * output_size.height]) context->output;
	const float* output_transform = context->output_transform;
	const float* bias = context->bias;
	nnp_transform_2d_with_bias transform_function = context->transform_function;

	const size_t batch_block_start = round_down(sample, batch_block_max);
	const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);
	const size_t batch_block_offset = sample - batch_block_start;

	for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
		const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset;
		transform_function(
			output_transform +
				(batch_block_start * output_channels + output_channels_subblock_start * batch_block_size + batch_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements,
			output[sample][output_channel],
			&bias[output_channel],
			batch_size * output_channels * tuple_elements * sizeof(float),
			output_size.width,
			row_count, column_count);
	}
}

struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;
	size_t batch_block_size;
	size_t input_channels_block_start;
	size_t input_channels_block_size;
	size_t batch_subblock_max;
	size_t output_channels_subblock_max;

	const float* input_transform;
	const float* kernel_transform;
	float* output_transform;

	nnp_fast_tuple_gemm_function fast_gemm;
	nnp_full_tuple_gemm_function full_gemm;
};

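/*
 * Multiplies one batch sub-block of transformed inputs by an output-channel block of
 * transformed kernels; results for successive input-channel blocks accumulate in the
 * transformed outputs. Full-size sub-blocks use the fixed mr-by-nr micro-kernel, the
 * remainder falls through to the variable-size micro-kernel.
 */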
static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t output_channels_block_start, size_t batch_subblock_start,
	size_t output_channels_block_size, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_block_size = context->batch_block_size;
	const size_t input_channels_block_start = context->input_channels_block_start;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const size_t batch_subblock_max = context->batch_subblock_max;
	const size_t output_channels_subblock_max = context->output_channels_subblock_max;

	const float* input_transform = context->input_transform +
		(batch_subblock_start * input_channels_block_size * tuple_elements);
	const float* kernel_transform = context->kernel_transform +
		(output_channels_block_start * input_channels_block_size * tuple_elements);
	float* output_transform = context->output_transform +
		(output_channels_block_start * batch_block_size * tuple_elements);

	if (batch_subblock_size == batch_subblock_max) {
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (output_channels_block_size >= output_channels_subblock_max) {
			output_channels_block_size -= output_channels_subblock_max;

			fast_gemm(
				input_channels_block_size, input_channels_block_start,
				input_transform,
				kernel_transform,
				output_transform + (batch_subblock_start * output_channels_subblock_max * tuple_elements),
				output_channels_subblock_max * tuple_elements);

			kernel_transform += input_channels_block_size * output_channels_subblock_max * tuple_elements;
			output_transform += batch_block_size * output_channels_subblock_max * tuple_elements;
		}
	}

	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (output_channels_block_size != 0) {
		const size_t output_channels_subblock_size = min(output_channels_block_size, output_channels_subblock_max);
		output_channels_block_size -= output_channels_subblock_size;

		full_gemm(
			batch_subblock_size, output_channels_subblock_size,
			input_channels_block_size, input_channels_block_start,
			input_transform,
			kernel_transform,
			output_transform + (batch_subblock_start * output_channels_subblock_size * tuple_elements),
			output_channels_subblock_size * tuple_elements);

		kernel_transform += input_channels_block_size * output_channels_subblock_max * tuple_elements;
		output_transform += batch_block_size * output_channels_subblock_max * tuple_elements;
	}
}

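/*
 * Shared driver for the transform-based (Winograd/FFT) convolution path: transforms
 * the kernels once, then for each output tile position transforms the inputs, performs
 * per-tuple block matrix multiplications, and inverse-transforms the results with bias
 * (the activation is fused into the output transform).
 */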
static enum nnp_status compute_fast_convolution_output(
	bool fourier_transform,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size tile_size,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	struct nnp_size output_size,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	const nnp_transform_2d_with_offset input_transform_function,
	const nnp_transform_2d_with_offset kernel_transform_function,
	const nnp_transform_2d_with_bias output_transform_function,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	void* memory_block = NULL;
	const size_t simd_width = nnp_hwinfo.simd_width;
	const size_t tuple_elements = (fourier_transform ? simd_width * 2 : simd_width);
	const size_t tile_elements = tile_size.height * tile_size.width;
	const size_t tuple_count = tile_elements / tuple_elements;

	const struct nnp_size output_tile_size = {
		.height = tile_size.height - kernel_size.height + 1,
		.width = tile_size.width - kernel_size.width + 1
	};

	/* Calculate cache blocking parameters */
	const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float));
	const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float));
	const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float));

	const size_t batch_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.mr : nnp_hwinfo.sxgemm.mr);
	const size_t output_channels_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.nr : nnp_hwinfo.sxgemm.nr);

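	/*
	 * Per tuple: the input-channel block is sized so that mr + nr panels fit in L1 cache,
	 * the batch block so that its input panels fit in L3, and the output-channels block
	 * so that its kernel panels fit in L2.
	 */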
	const size_t input_channels_block_max =
		round_down(cache_elements_l1 / (batch_subblock_max + output_channels_subblock_max), 2);
	const size_t batch_block_max =
		round_down(cache_elements_l3 / input_channels_block_max, batch_subblock_max);
	const size_t output_channels_block_max =
		round_down(cache_elements_l2 / input_channels_block_max, output_channels_subblock_max);

	/* Calculate memory footprint and allocate memory */
	const size_t kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float);
	const size_t input_transform_size = batch_size * input_channels * tile_elements * sizeof(float);
	const size_t output_transform_size = batch_size * output_channels * tile_elements * sizeof(float);
	const size_t memory_size = kernel_transform_size + input_transform_size + output_transform_size;

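	/*
	 * Workspace protocol: with no caller-provided buffer and a non-NULL workspace_size,
	 * only report the required size and return; with a caller-provided buffer, check that
	 * it is large enough; otherwise allocate a temporary buffer and release it at the end.
	 */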
	if (workspace_buffer == NULL) {
		if (workspace_size == NULL) {
			memory_block = allocate_memory(memory_size);
			if (memory_block == NULL) {
				return nnp_status_out_of_memory;
			}
		} else {
			*workspace_size = memory_size;
			return nnp_status_success;
		}
	} else {
		if (*workspace_size < memory_size) {
			return nnp_status_insufficient_buffer;
		}
		memory_block = workspace_buffer;
	}

	float* input_transform = memory_block;
	float* output_transform = memory_block + input_transform_size;
	float* kernel_transform = memory_block + input_transform_size + output_transform_size;

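	/* Transform all kernels once; the transformed kernels are reused for every output tile. */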
	NNP_KERNEL_TRANSFORM_START(profile)
	struct kernel_transform_context kernel_transform_context = {
		.transform_function = kernel_transform_function,
		.kernel = kernel,
		.kernel_transform = kernel_transform,
		.tuple_elements = tuple_elements,
		.output_channels = output_channels,
		.input_channels = input_channels,
		.input_channels_block_max = input_channels_block_max,
		.kernel_size = kernel_size,
	};
	pthreadpool_parallelize_2d_tile_2d(threadpool,
		(pthreadpool_task_2d_tile_2d_t) compute_kernel_transform,
		&kernel_transform_context,
		input_channels, output_channels,
		1, output_channels_subblock_max,
		PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	NNP_KERNEL_TRANSFORM_END(profile)

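	/* Process the output one tile position at a time: transform inputs, multiply, inverse-transform. */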
	for (size_t y = 0; y < output_size.height; y += output_tile_size.height) {
		const size_t input_y = min(doz(y, input_padding.top), input_size.height);
		for (size_t x = 0; x < output_size.width; x += output_tile_size.width) {
			const size_t input_x = min(doz(x, input_padding.left), input_size.width);
			const size_t row_offset = doz(input_padding.top, y);
			const size_t column_offset = doz(input_padding.left, x);

			NNP_INPUT_TRANSFORM_START(profile)
			struct input_transform_context input_transform_context = {
				.transform_function = input_transform_function,
				.input = input + input_y * input_size.width + input_x,
				.input_transform = input_transform,
				.tuple_elements = tuple_elements,
				.batch_size = batch_size,
				.input_channels = input_channels,
				.input_channels_block_max = input_channels_block_max,
				.input_size = input_size,
				.row_offset = row_offset,
				.row_count = min(input_size.height - input_y, tile_size.height - row_offset),
				.column_offset = column_offset,
				.column_count = min(input_size.width - input_x, tile_size.width - column_offset),
			};
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_input_transform,
				&input_transform_context,
				input_channels, batch_size,
				1, batch_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
			NNP_INPUT_TRANSFORM_END(profile)

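			/*
			 * For each tuple index, multiply the transformed inputs by the transformed
			 * kernels, blocking over input channels (the accumulation dimension) and batch.
			 */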
			NNP_BLOCK_MULTIPLICATION_START(profile)
			for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) {
				for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) {
					const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);
					for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) {
						const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);
						struct matrix_multiplication_context matrix_multiplication_context = {
							.tuple_elements = tuple_elements,
							.batch_block_size = batch_block_size,
							.input_channels_block_start = input_channels_block_start,
							.input_channels_block_size = input_channels_block_size,
							.batch_subblock_max = batch_subblock_max,
							.output_channels_subblock_max = output_channels_subblock_max,
							.input_transform = input_transform +
								tuple_index * tuple_elements * batch_size * input_channels +
								input_channels_block_start * batch_size * tuple_elements +
								batch_block_start * input_channels_block_size * tuple_elements,
							.kernel_transform = kernel_transform +
								tuple_index * tuple_elements * output_channels * input_channels +
								input_channels_block_start * output_channels * tuple_elements,
							.output_transform = output_transform +
								tuple_index * tuple_elements * batch_size * output_channels +
								batch_block_start * output_channels * tuple_elements,
						};
						if (fourier_transform) {
							if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) {
								matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_only_mr_x_nr;
								matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_upto_mr_x_nr;
							} else {
								matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_conjb_only_mr_x_nr;
								matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_conjb_upto_mr_x_nr;
							}
						} else {
							matrix_multiplication_context.fast_gemm = nnp_hwinfo.sxgemm.only_mr_x_nr;
							matrix_multiplication_context.full_gemm = nnp_hwinfo.sxgemm.upto_mr_x_nr;
						}
						pthreadpool_parallelize_2d_tile_2d(threadpool,
							(pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication,
							&matrix_multiplication_context,
							output_channels, batch_block_size,
							output_channels_block_max, batch_subblock_max,
							PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					}
				}
			}
			NNP_BLOCK_MULTIPLICATION_END(profile)

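			/* Inverse-transform the accumulated tile, add the bias, and store the valid output pixels. */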
			NNP_OUTPUT_TRANSFORM_START(profile)
			struct output_transform_context output_transform_context = {
				.transform_function = output_transform_function,
				.output = output + y * output_size.width + x,
				.output_transform = output_transform,
				.bias = bias,
				.tuple_elements = tuple_elements,
				.output_channels = output_channels,
				.batch_size = batch_size,
				.batch_block_max = batch_block_max,
				.output_size = output_size,
				.row_count = min(output_tile_size.height, output_size.height - y),
				.column_count = min(output_tile_size.width, output_size.width - x),
			};
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_output_transform,
				&output_transform_context,
				batch_size, output_channels,
				1, output_channels_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
			NNP_OUTPUT_TRANSFORM_END(profile)
		}
	}

	if (memory_block != workspace_buffer) {
		release_memory(memory_block, memory_size);
	}
	return nnp_status_success;
}

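/*
 * Public entry point: validates the arguments, resolves nnp_convolution_algorithm_auto
 * with a tile-count heuristic, selects the tile size and transform functions for the
 * chosen algorithm and activation, and dispatches to compute_fast_convolution_output.
 */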
enum nnp_status nnp_convolution_output(
	enum nnp_convolution_algorithm algorithm,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	const float* input,
	const float* kernel,
	const float* bias,
	float* output,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	const void* activation_parameters,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	NNP_TOTAL_START(profile)

	/* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
	enum nnp_status status = validate_convolution_arguments(
		batch_size, input_channels, output_channels,
		input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 },
		activation, activation_parameters);
	if (status != nnp_status_success) {
		goto cleanup;
	}

	if (activation_parameters != NULL) {
		status = nnp_status_unsupported_activation_parameters;
		goto cleanup;
	}

	const struct nnp_size output_size = {
		.width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1,
		.height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1
	};

	/* If requested, choose optimal convolution algorithm */
	if (algorithm == nnp_convolution_algorithm_auto) {
		if (max(kernel_size.width, kernel_size.height) > 8) {
			algorithm = nnp_convolution_algorithm_ft16x16;
		} else {
			const size_t tile_count_8x8 =
				divide_round_up(output_size.height, 8 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 8 - kernel_size.width + 1);
			const size_t tile_count_16x16 =
				divide_round_up(output_size.height, 16 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 16 - kernel_size.width + 1);
			if (tile_count_8x8 <= 4 * tile_count_16x16) {
				/* 8x8 tiles are more efficient */
				if ((kernel_size.height == 3) && (kernel_size.width == 3)) {
					algorithm = nnp_convolution_algorithm_wt8x8;
				} else {
					algorithm = nnp_convolution_algorithm_ft8x8;
				}
			} else {
				algorithm = nnp_convolution_algorithm_ft16x16;
			}
		}
	}

	/* Choose tiling parameters and transform functions depending on convolution algorithm */
	struct nnp_size tile_size;
	bool fourier_transform;
	nnp_transform_2d_with_offset input_transform_function;
	nnp_transform_2d_with_offset kernel_transform_function;
	nnp_transform_2d_with_bias output_transform_function;
	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
			input_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			switch (activation) {
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias_with_relu;
					break;
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias;
					break;
				default:
					NNP_UNREACHABLE;
			}
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			fourier_transform = true;
			break;
		case nnp_convolution_algorithm_ft16x16:
			input_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			switch (activation) {
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias_with_relu;
					break;
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias;
					break;
				default:
					NNP_UNREACHABLE;
			}
			tile_size = (struct nnp_size) { .height = 16, .width = 16 };
			fourier_transform = true;
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
			if ((kernel_size.height != 3) || (kernel_size.width != 3)) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			input_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3x3;
			switch (activation) {
				case nnp_activation_relu:
					output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias_with_relu;
					break;
				case nnp_activation_identity:
					output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias;
					break;
				default:
					NNP_UNREACHABLE;
			}
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			fourier_transform = false;
			break;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
		default:
			status = nnp_status_invalid_algorithm;
			goto cleanup;
	}

	switch (algorithm) {
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
		case nnp_convolution_algorithm_ft8x8:
		case nnp_convolution_algorithm_ft16x16:
			if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			status = compute_fast_convolution_output(
				fourier_transform,
				batch_size, input_channels, output_channels,
				tile_size, input_size, input_padding, kernel_size, output_size,
				input, kernel, bias, output, workspace_buffer, workspace_size,
				input_transform_function, kernel_transform_function, output_transform_function,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
	}

cleanup:
	NNP_TOTAL_END(profile)
	return status;
}