#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>

#include <nnpack.h>
#include <nnpack/macros.h>
#include <nnpack/utils.h>
#include <nnpack/system.h>

#include <nnpack/hwinfo.h>
#include <nnpack/validation.h>

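/*
 * Gradient of the convolution input (backward data pass), computed with the
 * transform-based algorithms: grad_output tiles and kernels are converted into
 * the FFT or Winograd domain, multiplied tuple-by-tuple with blocked GEMM
 * micro-kernels, and the accumulated tiles are inverse-transformed into grad_input.
 */
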
struct NNP_CACHE_ALIGN kernel_transform_context {
	nnp_transform_2d_with_offset transform_function;
	const float* kernel;
	float* kernel_transform;

	size_t tuple_elements;
	size_t input_channels;
	size_t output_channels;
	size_t output_channels_block_max;
	struct nnp_size kernel_size;
};

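/*
 * Transforms the kernel slice of every (output channel, input channel) pair in
 * the subblock and stores it in the packed layout consumed by the tuple GEMM:
 * output-channel blocks outermost, input-channel subblocks within them.
 */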
static void compute_kernel_transform(
	const struct kernel_transform_context context[restrict static 1],
	size_t output_channel, size_t input_channels_subblock_start,
	size_t output_channel_range, size_t input_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t input_channels = context->input_channels;
	const size_t output_channels = context->output_channels;
	const size_t output_channels_block_max = context->output_channels_block_max;
	const struct nnp_size kernel_size = context->kernel_size;

	const float (*kernel)[input_channels][kernel_size.width * kernel_size.height] =
		(const float(*)[input_channels][kernel_size.width * kernel_size.height]) context->kernel;
	float* kernel_transform = context->kernel_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max);
	const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max);
	const size_t output_channels_block_offset = output_channel - output_channels_block_start;

	for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) {
		const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset;
		transform_function(
			kernel[output_channel][input_channel],
			kernel_transform +
				(output_channels_block_start * input_channels + input_channels_subblock_start * output_channels_block_size + output_channels_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements,
			kernel_size.width,
			output_channels * input_channels * tuple_elements * sizeof(float),
			kernel_size.height, kernel_size.width, 0, 0);
	}
}

struct NNP_CACHE_ALIGN grad_output_transform_context {
	nnp_transform_2d_with_offset transform_function;
	const float* grad_output;
	float* grad_output_transform;

	size_t tuple_elements;
	size_t batch_size;
	size_t output_channels;
	size_t output_channels_block_max;
	struct nnp_size output_size;
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};

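/*
 * Transforms one grad_output tile per (sample, output channel) pair, applying
 * the row/column offsets that account for implicit padding at the tile edges,
 * and packs the result by output-channel block and batch subblock.
 */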
static void compute_grad_output_transform(
	const struct grad_output_transform_context context[restrict static 1],
	size_t output_channel, size_t batch_subblock_start,
	size_t output_channel_range, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t output_channels = context->output_channels;
	const size_t output_channels_block_max = context->output_channels_block_max;
	const struct nnp_size output_size = context->output_size;
	const size_t row_offset = context->row_offset;
	const size_t row_count = context->row_count;
	const size_t column_offset = context->column_offset;
	const size_t column_count = context->column_count;

	const float (*grad_output)[output_channels][output_size.width * output_size.height] =
		(const float(*)[output_channels][output_size.width * output_size.height]) context->grad_output;
	float* grad_output_transform = context->grad_output_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max);
	const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max);
	const size_t output_channels_block_offset = output_channel - output_channels_block_start;

	for (size_t batch_subblock_offset = 0; batch_subblock_offset < batch_subblock_size; batch_subblock_offset += 1) {
		const size_t sample = batch_subblock_start + batch_subblock_offset;
		transform_function(
			grad_output[sample][output_channel],
			grad_output_transform +
				(output_channels_block_start * batch_size + batch_subblock_start * output_channels_block_size + output_channels_block_offset * batch_subblock_size + batch_subblock_offset) * tuple_elements,
			output_size.width,
			batch_size * output_channels * tuple_elements * sizeof(float),
			row_count, column_count, row_offset, column_offset);
	}
}

struct NNP_CACHE_ALIGN grad_input_transform_context {
	nnp_transform_2d_with_offset transform_function;
	float* grad_input;
	const float* grad_input_transform;

	size_t tuple_elements;
	size_t input_channels;
	size_t batch_size;
	size_t batch_block_max;
	struct nnp_size input_size;
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};

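/*
 * Applies the inverse transform to one accumulated tile per (sample, input
 * channel) pair and writes the valid rows and columns back into grad_input.
 */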
static void compute_grad_input_transform(
	const struct grad_input_transform_context context[restrict static 1],
	size_t sample, size_t input_channels_subblock_start,
	size_t sample_range, size_t input_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t input_channels = context->input_channels;
	const size_t batch_block_max = context->batch_block_max;
	const struct nnp_size input_size = context->input_size;
	const size_t row_offset = context->row_offset;
	const size_t row_count = context->row_count;
	const size_t column_offset = context->column_offset;
	const size_t column_count = context->column_count;

	float (*grad_input)[input_channels][input_size.width * input_size.height] =
		(float(*)[input_channels][input_size.width * input_size.height]) context->grad_input;
	const float* grad_input_transform = context->grad_input_transform;
	nnp_transform_2d_with_offset transform_function = context->transform_function;

	const size_t batch_block_start = round_down(sample, batch_block_max);
	const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);
	const size_t batch_block_offset = sample - batch_block_start;

	for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) {
		const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset;
		transform_function(
			grad_input_transform +
				(batch_block_start * input_channels + input_channels_subblock_start * batch_block_size + batch_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements,
			grad_input[sample][input_channel],
			batch_size * input_channels * tuple_elements * sizeof(float),
			input_size.width,
			row_count, column_count, row_offset, column_offset);
	}
}

struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;
	size_t batch_size;
	size_t input_channels;
	size_t batch_block_start;
	size_t batch_block_size;
	size_t output_channels_block_start;
	size_t output_channels_block_size;
	size_t batch_subblock_max;
	size_t input_channels_subblock_max;

	const float* grad_output_transform;
	const float* kernel_transform;
	float* grad_input_transform;

	nnp_fast_tuple_gemm_function fast_gemm;
	nnp_full_tuple_gemm_function full_gemm;
};

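/*
 * Multiplies the transformed grad_output block by the transformed kernel block
 * for one tuple, accumulating into the transformed grad_input block. Complete
 * mr-by-nr subblocks go through the fast GEMM micro-kernel; the remainder is
 * handled by the general ("upto") variant.
 */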
static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t input_channels_block_start, size_t batch_subblock_start,
	size_t input_channels_block_size, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t input_channels = context->input_channels;
	const size_t batch_block_start = context->batch_block_start;
	const size_t batch_block_size = context->batch_block_size;
	const size_t output_channels_block_start = context->output_channels_block_start;
	const size_t output_channels_block_size = context->output_channels_block_size;
	const size_t batch_subblock_max = context->batch_subblock_max;
	const size_t input_channels_subblock_max = context->input_channels_subblock_max;

	const float* grad_output_transform = context->grad_output_transform +
		(output_channels_block_start * batch_size + (batch_block_start + batch_subblock_start) * output_channels_block_size) * tuple_elements;
	const float* kernel_transform = context->kernel_transform +
		(output_channels_block_start * input_channels + input_channels_block_start * output_channels_block_size) * tuple_elements;
	float* grad_input_transform = context->grad_input_transform +
		(batch_block_start * input_channels + input_channels_block_start * batch_block_size) * tuple_elements;

	if (batch_subblock_size == batch_subblock_max) {
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (input_channels_block_size >= input_channels_subblock_max) {
			input_channels_block_size -= input_channels_subblock_max;

			fast_gemm(
				output_channels_block_size, output_channels_block_start,
				grad_output_transform,
				kernel_transform,
				grad_input_transform + batch_subblock_start * input_channels_subblock_max * tuple_elements,
				input_channels_subblock_max * tuple_elements);

			kernel_transform += input_channels_subblock_max * output_channels_block_size * tuple_elements;
			grad_input_transform += input_channels_subblock_max * batch_block_size * tuple_elements;
		}
	}

	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (input_channels_block_size != 0) {
		const size_t input_channels_subblock_size = min(input_channels_block_size, input_channels_subblock_max);
		input_channels_block_size -= input_channels_subblock_size;

		full_gemm(
			batch_subblock_size, input_channels_subblock_size,
			output_channels_block_size, output_channels_block_start,
			grad_output_transform,
			kernel_transform,
			grad_input_transform + batch_subblock_start * input_channels_subblock_size * tuple_elements,
			input_channels_subblock_size * tuple_elements);

		kernel_transform += input_channels_subblock_max * output_channels_block_size * tuple_elements;
		grad_input_transform += input_channels_subblock_max * batch_block_size * tuple_elements;
	}
}

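/*
 * Transform-domain path shared by the FFT and Winograd algorithms: transform
 * the kernels once, then for every grad_input tile transform grad_output, run
 * the blocked tuple GEMMs, and inverse-transform the result. When called with
 * workspace_buffer == NULL and a non-NULL workspace_size, only the required
 * workspace size is reported.
 */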
static enum nnp_status compute_fast_convolution_input_gradient(
	bool fourier_transform,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size tile_size,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	struct nnp_size output_size,
	const float* grad_output,
	const float* kernel,
	float* grad_input,
	void* workspace_buffer,
	size_t* workspace_size,
	nnp_transform_2d_with_offset grad_output_transform_function,
	nnp_transform_2d_with_offset kernel_transform_function,
	nnp_transform_2d_with_offset grad_input_transform_function,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	void* memory_block = NULL;
	const size_t simd_width = nnp_hwinfo.simd_width;
	const size_t tuple_elements = (fourier_transform ? simd_width * 2 : simd_width);
	const size_t tile_elements = tile_size.height * tile_size.width;
	const size_t tuple_count = tile_elements / tuple_elements;

	const struct nnp_size grad_input_tile_size = {
		.height = tile_size.height - kernel_size.height + 1,
		.width = tile_size.width - kernel_size.width + 1
	};

	/* Calculate cache blocking parameters */
	const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float));
	const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float));
	const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float));

	const size_t batch_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.mr : nnp_hwinfo.sxgemm.mr);
	const size_t input_channels_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.nr : nnp_hwinfo.sxgemm.nr);

	const size_t output_channels_block_max =
		round_down(cache_elements_l1 / (batch_subblock_max + input_channels_subblock_max), 2);
	const size_t batch_block_max =
		round_down(cache_elements_l3 / output_channels_block_max, batch_subblock_max);
	const size_t input_channels_block_max =
		round_down(cache_elements_l2 / output_channels_block_max, input_channels_subblock_max);

	/* Calculate memory footprint and allocate memory */
	const size_t kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float);
	const size_t grad_input_transform_size = batch_size * input_channels * tile_elements * sizeof(float);
	const size_t grad_output_transform_size = batch_size * output_channels * tile_elements * sizeof(float);
	const size_t memory_size = kernel_transform_size + grad_input_transform_size + grad_output_transform_size;

	if (workspace_buffer == NULL) {
		if (workspace_size == NULL) {
			memory_block = allocate_memory(memory_size);
			if (memory_block == NULL) {
				return nnp_status_out_of_memory;
			}
		} else {
			*workspace_size = memory_size;
			return nnp_status_success;
		}
	} else {
		if (*workspace_size < memory_size) {
			return nnp_status_insufficient_buffer;
		}
		memory_block = workspace_buffer;
	}

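	/* Partition the workspace: grad_output transform first, then kernel transform, then grad_input transform */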
	float* grad_output_transform = memory_block;
	float* kernel_transform = memory_block + grad_output_transform_size;
	float* grad_input_transform = memory_block + grad_output_transform_size + kernel_transform_size;

	NNP_KERNEL_TRANSFORM_START(profile)
	struct kernel_transform_context kernel_transform_context = {
		.transform_function = kernel_transform_function,
		.kernel = kernel,
		.kernel_transform = kernel_transform,
		.tuple_elements = tuple_elements,
		.input_channels = input_channels,
		.output_channels = output_channels,
		.output_channels_block_max = output_channels_block_max,
		.kernel_size = kernel_size,
	};
	pthreadpool_parallelize_2d_tile_2d(threadpool,
		(pthreadpool_task_2d_tile_2d_t) compute_kernel_transform,
		&kernel_transform_context,
		output_channels, input_channels,
		1, input_channels_subblock_max,
		PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	NNP_KERNEL_TRANSFORM_END(profile)

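	/* Process grad_input tile-by-tile: each iteration produces one grad_input_tile_size block per image and channel */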
	for (size_t y = 0; y < input_size.height; y += grad_input_tile_size.height) {
		const size_t grad_output_y = min(doz(y + input_padding.top, kernel_size.height - 1), output_size.height);
		for (size_t x = 0; x < input_size.width; x += grad_input_tile_size.width) {
			const size_t grad_output_x = min(doz(x + input_padding.left, kernel_size.width - 1), output_size.width);

			NNP_OUTPUT_TRANSFORM_START(profile)
			struct grad_output_transform_context grad_output_transform_context = {
				.transform_function = grad_output_transform_function,
				.grad_output = grad_output + grad_output_y * output_size.width + grad_output_x,
				.grad_output_transform = grad_output_transform,
				.tuple_elements = tuple_elements,
				.batch_size = batch_size,
				.output_channels = output_channels,
				.output_channels_block_max = output_channels_block_max,
				.output_size = output_size,
				.row_offset = doz(kernel_size.height - 1, y + input_padding.top),
				.row_count = min(output_size.height - grad_output_y,
					tile_size.height - grad_output_transform_context.row_offset),
				.column_offset = doz(kernel_size.width - 1, x + input_padding.left),
				.column_count = min(output_size.width - grad_output_x,
					tile_size.width - grad_output_transform_context.column_offset),
			};
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_grad_output_transform,
				&grad_output_transform_context,
				output_channels, batch_size,
				1, batch_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
			NNP_OUTPUT_TRANSFORM_END(profile)

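			/* For each tuple, multiply-accumulate the transformed grad_output with the transformed kernels in cache-sized blocks */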
			NNP_BLOCK_MULTIPLICATION_START(profile)
			for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) {
				for (size_t output_channels_block_start = 0; output_channels_block_start < output_channels; output_channels_block_start += output_channels_block_max) {
					const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max);
					for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) {
						const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);
						struct matrix_multiplication_context matrix_multiplication_context = {
							.tuple_elements = tuple_elements,
							.batch_size = batch_size,
							.input_channels = input_channels,
							.batch_block_start = batch_block_start,
							.batch_block_size = batch_block_size,
							.output_channels_block_start = output_channels_block_start,
							.output_channels_block_size = output_channels_block_size,
							.batch_subblock_max = batch_subblock_max,
							.input_channels_subblock_max = input_channels_subblock_max,
							.grad_output_transform = grad_output_transform + tuple_index * tuple_elements * batch_size * output_channels,
							.kernel_transform = kernel_transform + tuple_index * tuple_elements * output_channels * input_channels,
							.grad_input_transform = grad_input_transform + tuple_index * tuple_elements * batch_size * input_channels,
						};
						if (fourier_transform) {
							if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) {
								matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_only_mr_x_nr;
								matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_upto_mr_x_nr;
							} else {
								matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_only_mr_x_nr;
								matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_upto_mr_x_nr;
							}
						} else {
							matrix_multiplication_context.fast_gemm = nnp_hwinfo.sxgemm.only_mr_x_nr;
							matrix_multiplication_context.full_gemm = nnp_hwinfo.sxgemm.upto_mr_x_nr;
						}
						pthreadpool_parallelize_2d_tile_2d(threadpool,
							(pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication,
							&matrix_multiplication_context,
							input_channels, batch_block_size,
							input_channels_block_max, batch_subblock_max,
							PTHREADPOOL_FLAG_DISABLE_DENORMALS);
					}
				}
			}
			NNP_BLOCK_MULTIPLICATION_END(profile)

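			/* Inverse-transform the accumulated tiles into this block of grad_input */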
			NNP_INPUT_TRANSFORM_START(profile)
			struct grad_input_transform_context grad_input_transform_context = {
				.transform_function = grad_input_transform_function,
				.grad_input = grad_input + y * input_size.width + x,
				.grad_input_transform = grad_input_transform,
				.tuple_elements = tuple_elements,
				.input_channels = input_channels,
				.batch_size = batch_size,
				.batch_block_max = batch_block_max,
				.input_size = input_size,
				.row_offset = fourier_transform ? kernel_size.height - 1 : 0,
				.row_count = min(input_size.height - y, grad_input_tile_size.height),
				.column_offset = fourier_transform ? kernel_size.width - 1 : 0,
				.column_count = min(input_size.width - x, grad_input_tile_size.width),
			};
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_grad_input_transform,
				&grad_input_transform_context,
				batch_size, input_channels,
				1, input_channels_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
			NNP_INPUT_TRANSFORM_END(profile)
		}
	}

	if (memory_block != workspace_buffer) {
		release_memory(memory_block, memory_size);
	}
	return nnp_status_success;
}

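/*
 * Public entry point: validates the arguments, picks an algorithm when
 * nnp_convolution_algorithm_auto is requested, and dispatches to the
 * transform-based implementation above.
 */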
enum nnp_status nnp_convolution_input_gradient(
	enum nnp_convolution_algorithm algorithm,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	const float* grad_output,
	const float* kernel,
	float* grad_input,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	const void* activation_parameters,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	NNP_TOTAL_START(profile)

	/* Basic validation of parameters. This check detects invalid, but not unsupported, parameters. */
	enum nnp_status status = validate_convolution_arguments(
		batch_size, input_channels, output_channels,
		input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 },
		activation, activation_parameters);
	if (status != nnp_status_success) {
		goto cleanup;
	}

	if (activation != nnp_activation_identity) {
		status = nnp_status_unsupported_activation;
		goto cleanup;
	}

	if (activation_parameters != NULL) {
		status = nnp_status_unsupported_activation_parameters;
		goto cleanup;
	}

	/* If requested, choose the optimal convolution algorithm */
	if (algorithm == nnp_convolution_algorithm_auto) {
		if (max(kernel_size.width, kernel_size.height) > 8) {
			algorithm = nnp_convolution_algorithm_ft16x16;
		} else {
			const size_t tile_count_8x8 =
				divide_round_up(input_size.height, 8 - kernel_size.height + 1) *
				divide_round_up(input_size.width, 8 - kernel_size.width + 1);
			const size_t tile_count_16x16 =
				divide_round_up(input_size.height, 16 - kernel_size.height + 1) *
				divide_round_up(input_size.width, 16 - kernel_size.width + 1);
			if (tile_count_8x8 <= 4 * tile_count_16x16) {
				/* 8x8 tiles are more efficient */
				if ((kernel_size.height == 3) && (kernel_size.width == 3)) {
					algorithm = nnp_convolution_algorithm_wt8x8;
				} else {
					algorithm = nnp_convolution_algorithm_ft8x8;
				}
			} else {
				algorithm = nnp_convolution_algorithm_ft16x16;
			}
		}
	}

	/* Choose tiling parameters and transform functions depending on the convolution algorithm */
	struct nnp_size tile_size;
	bool fourier_transform;
	nnp_transform_2d_with_offset grad_output_transform_function;
	nnp_transform_2d_with_offset kernel_transform_function;
	nnp_transform_2d_with_offset grad_input_transform_function;
	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
			grad_output_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			grad_input_transform_function = nnp_hwinfo.transforms.ifft8x8_with_offset;
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			fourier_transform = true;
			break;
		case nnp_convolution_algorithm_ft16x16:
			grad_output_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			grad_input_transform_function = nnp_hwinfo.transforms.ifft16x16_with_offset;
			tile_size = (struct nnp_size) { .height = 16, .width = 16 };
			fourier_transform = true;
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
			if ((kernel_size.height != 3) || (kernel_size.width != 3)) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			grad_output_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_with_offset_and_stream;
			kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3Rx3R;
			grad_input_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3;
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			fourier_transform = false;
			break;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
		default:
			status = nnp_status_invalid_algorithm;
			goto cleanup;
	}

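	/* Output size of the corresponding forward convolution: padded input size minus kernel size plus one */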
	const struct nnp_size output_size = {
		.width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1,
		.height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1
	};

	switch (algorithm) {
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
		case nnp_convolution_algorithm_ft8x8:
		case nnp_convolution_algorithm_ft16x16:
			if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			status = compute_fast_convolution_input_gradient(
				fourier_transform,
				batch_size, input_channels, output_channels,
				tile_size, input_size, input_padding, kernel_size, output_size,
				grad_output, kernel, grad_input, workspace_buffer, workspace_size,
				grad_output_transform_function, kernel_transform_function, grad_input_transform_function,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
		case nnp_convolution_algorithm_auto:
			NNP_UNREACHABLE;
	}

cleanup:
	NNP_TOTAL_END(profile)
	return status;
}