1#include <stdbool.h>
2#include <stdint.h>
3#include <stddef.h>
4
5#include <nnpack.h>
6#include <nnpack/macros.h>
7#include <nnpack/utils.h>
8#include <nnpack/system.h>
9
10#include <nnpack/hwinfo.h>
11#include <nnpack/validation.h>
12
13
/*
 * Arguments for compute_input_transform(), shared (read-only) by all
 * pthreadpool workers iterating over (batch image, input channel) pairs.
 */
struct NNP_CACHE_ALIGN input_transform_context {
	size_t tuple_elements;   /* floats per SIMD tuple of the 2D transform */
	size_t input_elements;   /* input_size.height * input_size.width: one input plane */
	size_t batch_block_size; /* images in the batch block being transformed */
	size_t input_channels;
	size_t input_stride;     /* row stride of the input image, in elements */
	/* Intra-tile offsets where valid input data starts (from implicit padding). */
	uint32_t row_offset;
	uint32_t column_offset;
	/* Number of valid input rows/columns covered by the current tile. */
	uint32_t row_count;
	uint32_t column_count;
	const float* input;      /* input pixels, pre-offset to the current tile */
	float* input_transform;  /* destination buffer for transformed tuples */
	nnp_transform_2d_with_offset transform_function;
};
28
29static void compute_input_transform(
30 const struct input_transform_context context[restrict static 1],
31 size_t batch_block_offset, size_t input_channels_subblock_start,
32 size_t batch_block_offset_range, size_t input_channels_subblock_size)
33{
34 const size_t tuple_elements = context->tuple_elements;
35 const size_t input_elements = context->input_elements;
36 const size_t batch_block_size = context->batch_block_size;
37 const size_t input_channels = context->input_channels;
38 const size_t input_stride = context->input_stride;
39 const uint32_t row_count = context->row_count;
40 const uint32_t column_count = context->column_count;
41 const uint32_t row_offset = context->row_offset;
42 const uint32_t column_offset = context->column_offset;
43 const float* input = context->input;
44 float* input_transform = context->input_transform;
45 const nnp_transform_2d_with_offset transform = context->transform_function;
46
47 for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) {
48 const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset;
49 transform(
50 input +
51 (batch_block_offset * input_channels + input_channel) * input_elements,
52 input_transform +
53 (input_channels_subblock_start * batch_block_size + batch_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements,
54 input_stride,
55 batch_block_size * input_channels * tuple_elements * sizeof(float),
56 row_count, column_count, row_offset, column_offset);
57 }
58}
59
/*
 * Arguments for compute_grad_output_transform(), shared (read-only) by all
 * pthreadpool workers iterating over (batch image, output channel) pairs.
 */
struct NNP_CACHE_ALIGN grad_output_transform_context {
	size_t tuple_elements;    /* floats per SIMD tuple of the 2D transform */
	size_t output_elements;   /* output_size.height * output_size.width: one plane */
	size_t batch_block_size;  /* images in the batch block being transformed */
	size_t output_channels;
	size_t grad_output_stride; /* row stride of grad_output, in elements */
	/* Valid rows/columns of the current output tile (edge tiles may be partial). */
	uint32_t row_count;
	uint32_t column_count;
	const float* grad_output;       /* gradient pixels, pre-offset to the current tile */
	float* grad_output_transform;   /* destination buffer for transformed tuples */
	nnp_transform_2d_with_offset transform_function;
};
72
73static void compute_grad_output_transform(
74 const struct grad_output_transform_context context[restrict static 1],
75 size_t batch_block_offset, size_t output_channels_subblock_start,
76 size_t batch_block_offset_range, size_t output_channels_subblock_size)
77{
78 const size_t tuple_elements = context->tuple_elements;
79 const size_t output_elements = context->output_elements;
80 const size_t batch_block_size = context->batch_block_size;
81 const size_t output_channels = context->output_channels;
82 const size_t grad_output_stride = context->grad_output_stride;
83 const uint32_t row_count = context->row_count;
84 const uint32_t column_count = context->column_count;
85 const float* grad_output = context->grad_output;
86 float* grad_output_transform = context->grad_output_transform;
87 const nnp_transform_2d_with_offset transform = context->transform_function;
88
89 for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) {
90 const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset;
91 transform(
92 grad_output +
93 (batch_block_offset * output_channels + output_channel) * output_elements,
94 grad_output_transform +
95 (output_channels_subblock_start * batch_block_size + batch_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements,
96 grad_output_stride,
97 batch_block_size * output_channels * tuple_elements * sizeof(float),
98 row_count, column_count, 0, 0);
99 }
100}
101
/*
 * Arguments for compute_grad_kernel_transform(): inverse-transforms the
 * accumulated frequency-domain gradients back into spatial kernel gradients.
 */
struct NNP_CACHE_ALIGN grad_kernel_transform_context {
	size_t tuple_elements;  /* floats per SIMD tuple of the 2D transform */
	size_t input_channels;
	size_t output_channels;
	/* Output-channel blocking factor used when the transform buffer was laid out;
	 * needed to recover each channel's position inside its cache block. */
	size_t output_channels_block_max;
	struct nnp_size kernel_size;
	const float* grad_kernel_transform; /* source: accumulated transformed gradients */
	float* grad_kernel;                 /* destination: spatial-domain kernel gradients */
	nnp_transform_2d_with_offset transform_function;
};
112
113static void compute_grad_kernel_transform(
114 const struct grad_kernel_transform_context context[restrict static 1],
115 size_t output_channel, size_t input_channels_subblock_start,
116 size_t output_channel_range, size_t input_channels_subblock_size)
117{
118 const size_t tuple_elements = context->tuple_elements;
119 const size_t input_channels = context->input_channels;
120 const size_t output_channels = context->output_channels;
121 const size_t output_channels_block_max = context->output_channels_block_max;
122 const struct nnp_size kernel_size = context->kernel_size;
123 const float* grad_kernel_transform = context->grad_kernel_transform;
124 float* grad_kernel = context->grad_kernel;
125 const nnp_transform_2d_with_offset transform = context->transform_function;
126
127 const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max);
128 const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max);
129 const size_t output_channels_block_offset = output_channel - output_channels_block_start;
130 const size_t kernel_elements = kernel_size.height * kernel_size.width;
131
132 for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) {
133 const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset;
134 transform(
135 grad_kernel_transform +
136 (output_channels_block_start * input_channels + input_channels_subblock_start * output_channels_block_size + output_channels_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements,
137 grad_kernel +
138 (output_channel * input_channels + input_channel) * kernel_elements,
139 output_channels * input_channels * tuple_elements * sizeof(float),
140 kernel_size.width,
141 kernel_size.height, kernel_size.width, 0, 0);
142 }
143}
144
/*
 * Arguments for compute_matrix_multiplication(): tuple GEMM that accumulates
 * transformed grad-output x transformed input into the transformed
 * kernel-gradient buffer, blocked for cache.
 */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;   /* floats per SIMD tuple */
	size_t batch_size;
	size_t batch_block_size; /* reduction length (images) for this GEMM call */
	/* Zero on the very first batch block/tile position (GEMM initializes C);
	 * non-zero afterwards (GEMM accumulates into C). */
	size_t batch_block_update;
	size_t input_channels;
	size_t input_channels_block_start;
	size_t input_channels_block_size;
	size_t input_channels_subblock_max;  /* GEMM micro-kernel mr */
	size_t output_channels;
	size_t output_channels_subblock_max; /* GEMM micro-kernel nr */
	const float* grad_output_transform;  /* A: transformed gradients for current tuple */
	const float* input_transform;        /* B: transformed inputs for current tuple */
	float* grad_kernel_transform;        /* C: accumulator for current tuple */

	nnp_fast_tuple_gemm_function fast_gemm; /* fixed mr x nr micro-kernel */
	nnp_full_tuple_gemm_function full_gemm; /* variable-size edge micro-kernel */
};
163
164static void compute_matrix_multiplication(
165 const struct matrix_multiplication_context context[restrict static 1],
166 size_t output_channels_block_start, size_t input_channels_subblock_start,
167 size_t output_channels_block_size, size_t input_channels_subblock_size)
168{
169 const size_t tuple_elements = context->tuple_elements;
170 const size_t batch_size = context->batch_size;
171 const size_t batch_block_size = context->batch_block_size;
172 const size_t batch_block_update = context->batch_block_update;
173 const size_t input_channels = context->input_channels;
174 const size_t input_channels_block_start = context->input_channels_block_start;
175 const size_t input_channels_block_size = context->input_channels_block_size;
176 const size_t input_channels_subblock_max = context->input_channels_subblock_max;
177 const size_t output_channels = context->output_channels;
178 const size_t output_channels_subblock_max = context->output_channels_subblock_max;
179
180 const float* grad_output_transform = context->grad_output_transform +
181 output_channels_block_start * batch_block_size * tuple_elements;
182 const float* input_transform = context->input_transform +
183 (input_channels_block_start + input_channels_subblock_start) * batch_block_size * tuple_elements;
184 float* grad_kernel_transform = context->grad_kernel_transform +
185 (output_channels_block_start * input_channels + (input_channels_block_start + input_channels_subblock_start) * output_channels_block_size) * tuple_elements;
186
187 if (input_channels_subblock_size == input_channels_subblock_max) {
188 const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
189 while (output_channels_block_size >= output_channels_subblock_max) {
190 output_channels_block_size -= output_channels_subblock_max;
191
192 fast_gemm(
193 batch_block_size, batch_block_update,
194 input_transform,
195 grad_output_transform,
196 grad_kernel_transform,
197 input_channels_subblock_size * tuple_elements);
198
199 grad_output_transform += output_channels_subblock_max * batch_block_size * tuple_elements;
200 grad_kernel_transform += output_channels_subblock_max * input_channels_subblock_size * tuple_elements;
201 }
202 }
203
204 const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
205 while (output_channels_block_size != 0) {
206 const size_t output_channels_subblock_size = min(output_channels_block_size, output_channels_subblock_max);
207 output_channels_block_size -= output_channels_subblock_size;
208
209 full_gemm(
210 input_channels_subblock_size, output_channels_subblock_size,
211 batch_block_size, batch_block_update,
212 input_transform,
213 grad_output_transform,
214 grad_kernel_transform,
215 input_channels_subblock_size * tuple_elements);
216
217 grad_output_transform += output_channels_subblock_max * batch_block_size * tuple_elements;
218 grad_kernel_transform += output_channels_subblock_max * input_channels_subblock_size * tuple_elements;
219 }
220}
221
222static enum nnp_status compute_fast_convolution_kernel_gradient(
223 size_t batch_size,
224 size_t input_channels,
225 size_t output_channels,
226 struct nnp_size tile_size,
227 struct nnp_size input_size,
228 struct nnp_padding input_padding,
229 struct nnp_size kernel_size,
230 struct nnp_size output_size,
231 const float* input,
232 const float* grad_output,
233 float* grad_kernel,
234 void* workspace_buffer,
235 size_t* workspace_size,
236 nnp_transform_2d_with_offset input_transform_function,
237 nnp_transform_2d_with_offset grad_output_transform_function,
238 nnp_transform_2d_with_offset grad_kernel_transform_function,
239 pthreadpool_t threadpool,
240 struct nnp_profile* profile)
241{
242 void* memory_block = NULL;
243 const size_t simd_width = nnp_hwinfo.simd_width;
244 const size_t tuple_elements = simd_width * 2;
245 const size_t tile_elements = tile_size.height * tile_size.width;
246 const size_t tuple_count = tile_elements / tuple_elements;
247
248 const struct nnp_size output_tile = {
249 .height = tile_size.height - kernel_size.height + 1,
250 .width = tile_size.width - kernel_size.width + 1
251 };
252
253 /* Calculate cache blocking parameters */
254 const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float));
255 const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float));
256 const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float));
257
258 const size_t input_channels_subblock_max = nnp_hwinfo.cxgemm.mr;
259 const size_t output_channels_subblock_max = nnp_hwinfo.cxgemm.nr;
260
261 const size_t batch_block_max =
262 round_down(cache_elements_l1 / (input_channels_subblock_max + output_channels_subblock_max), 2);
263 const size_t input_channels_block_max =
264 round_down(cache_elements_l3 / batch_block_max, input_channels_subblock_max);
265 const size_t output_channels_block_max =
266 round_down(cache_elements_l2 / batch_block_max, output_channels_subblock_max);
267
268 /* Calculate memory footprint and allocate memory */
269 const size_t input_transform_size = min(batch_size, batch_block_max) * input_channels * tile_elements * sizeof(float);
270 const size_t grad_output_transform_size = min(batch_size, batch_block_max) * output_channels * tile_elements * sizeof(float);
271 const size_t grad_kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float);
272 const size_t memory_size = input_transform_size + grad_output_transform_size + grad_kernel_transform_size;
273
274 if (workspace_buffer == NULL) {
275 if (workspace_size == NULL) {
276 memory_block = allocate_memory(memory_size);
277 if (memory_block == NULL) {
278 return nnp_status_out_of_memory;
279 }
280 } else {
281 *workspace_size = memory_size;
282 return nnp_status_success;
283 }
284 } else {
285 if (*workspace_size < memory_size) {
286 return nnp_status_insufficient_buffer;
287 }
288 memory_block = workspace_buffer;
289 }
290
291 float* input_transform = memory_block;
292 float* grad_output_transform = memory_block + input_transform_size;
293 float* grad_kernel_transform = memory_block + input_transform_size + grad_output_transform_size;
294
295 for (size_t y = 0; y < output_size.height; y += output_tile.height) {
296 const size_t input_y = min(doz(y, input_padding.top), input_size.height);
297 for (size_t x = 0; x < output_size.width; x += output_tile.width) {
298 const size_t input_x = min(doz(x, input_padding.left), input_size.width);
299
300 for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) {
301 const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);
302
303 /* Input transform */
304 NNP_INPUT_TRANSFORM_START(profile)
305 struct input_transform_context input_transform_context = {
306 .tuple_elements = tuple_elements,
307 .input_elements = input_size.height * input_size.width,
308 .batch_block_size = batch_block_size,
309 .input_channels = input_channels,
310 .input_stride = input_size.width,
311 .row_offset = doz(input_padding.top, y),
312 .column_offset = doz(input_padding.left, x),
313 .row_count = min(input_size.height - input_y,
314 tile_size.height - input_transform_context.row_offset),
315 .column_count = min(input_size.width - input_x,
316 tile_size.width - input_transform_context.column_offset),
317 .input = input + (batch_block_start * input_channels * input_size.height + input_y) * input_size.width + input_x,
318 .input_transform = input_transform,
319 .transform_function = input_transform_function,
320 };
321 pthreadpool_parallelize_2d_tile_2d(threadpool,
322 (pthreadpool_task_2d_tile_2d_t) compute_input_transform,
323 &input_transform_context,
324 batch_block_size, input_channels,
325 1, input_channels_subblock_max,
326 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
327 NNP_INPUT_TRANSFORM_END(profile)
328
329 /* Grad output transform */
330 NNP_OUTPUT_TRANSFORM_START(profile)
331 struct grad_output_transform_context grad_output_transform_context = {
332 .tuple_elements = tuple_elements,
333 .output_elements = output_size.height * output_size.width,
334 .batch_block_size = batch_block_size,
335 .output_channels = output_channels,
336 .grad_output_stride = output_size.width,
337 .row_count = min(output_tile.height, output_size.height - y),
338 .column_count = min(output_tile.width, output_size.width - x),
339 .grad_output = grad_output + (batch_block_start * output_channels * output_size.height + y) * output_size.width + x,
340 .grad_output_transform = grad_output_transform,
341 .transform_function = grad_output_transform_function,
342 };
343 pthreadpool_parallelize_2d_tile_2d(threadpool,
344 (pthreadpool_task_2d_tile_2d_t) compute_grad_output_transform,
345 &grad_output_transform_context,
346 batch_block_size, output_channels,
347 1, output_channels_subblock_max,
348 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
349 NNP_OUTPUT_TRANSFORM_END(profile)
350
351 NNP_BLOCK_MULTIPLICATION_START(profile)
352 for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) {
353 for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) {
354 const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);
355
356 struct matrix_multiplication_context matrix_multiplication_context = {
357 .tuple_elements = tuple_elements,
358 .batch_size = batch_size,
359 .batch_block_size = batch_block_size,
360 .batch_block_update = batch_block_start | x | y,
361 .input_channels = input_channels,
362 .input_channels_block_start = input_channels_block_start,
363 .input_channels_block_size = input_channels_block_size,
364 .input_channels_subblock_max = input_channels_subblock_max,
365 .output_channels = output_channels,
366 .output_channels_subblock_max = output_channels_subblock_max,
367 .grad_output_transform = grad_output_transform +
368 tuple_index * tuple_elements * batch_block_size * output_channels,
369 .input_transform = input_transform +
370 tuple_index * tuple_elements * batch_block_size * input_channels,
371 .grad_kernel_transform = grad_kernel_transform +
372 tuple_index * tuple_elements * output_channels * input_channels,
373 };
374 if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) {
375 matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_transc_only_mr_x_nr;
376 matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_transc_upto_mr_x_nr;
377 } else {
378 matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_conjb_transc_only_mr_x_nr;
379 matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_conjb_transc_upto_mr_x_nr;
380 }
381 pthreadpool_parallelize_2d_tile_2d(threadpool,
382 (pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication,
383 &matrix_multiplication_context,
384 output_channels, input_channels_block_size,
385 output_channels_block_max, input_channels_subblock_max,
386 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
387 }
388 }
389 NNP_BLOCK_MULTIPLICATION_END(profile)
390
391
392 }
393 }
394 }
395 /* Grad kernel transform */
396 NNP_KERNEL_TRANSFORM_START(profile)
397 struct grad_kernel_transform_context grad_kernel_transform_context = {
398 .tuple_elements = tuple_elements,
399 .input_channels = input_channels,
400 .output_channels = output_channels,
401 .output_channels_block_max = output_channels_block_max,
402 .kernel_size = kernel_size,
403 .grad_kernel = grad_kernel,
404 .grad_kernel_transform = grad_kernel_transform,
405 .transform_function = grad_kernel_transform_function,
406 };
407 pthreadpool_parallelize_2d_tile_2d(threadpool,
408 (pthreadpool_task_2d_tile_2d_t) compute_grad_kernel_transform,
409 &grad_kernel_transform_context,
410 output_channels, input_channels,
411 1, input_channels_subblock_max,
412 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
413 NNP_KERNEL_TRANSFORM_END(profile)
414
415 if (memory_block != workspace_buffer) {
416 release_memory(memory_block, memory_size);
417 }
418 return nnp_status_success;
419}
420
/*
 * Public entry point: computes the gradient of the convolution kernel given
 * the input and the gradient of the output. Validates arguments, selects an
 * FFT tile algorithm (8x8 or 16x16) when asked for auto-selection, and
 * delegates to compute_fast_convolution_kernel_gradient().
 *
 * Only identity activation without parameters is supported; Winograd,
 * implicit-GEMM and direct algorithms are rejected as unsupported.
 * Returns nnp_status_success or an error status describing the failure.
 */
enum nnp_status nnp_convolution_kernel_gradient(
	enum nnp_convolution_algorithm algorithm,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	const float* input,
	const float* grad_output,
	float* grad_kernel,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	const void* activation_parameters,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	NNP_TOTAL_START(profile)

	/* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
	enum nnp_status status = validate_convolution_arguments(
		batch_size, input_channels, output_channels,
		input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 },
		activation, activation_parameters);
	if (status != nnp_status_success) {
		goto cleanup;
	}

	if (activation != nnp_activation_identity) {
		status = nnp_status_unsupported_activation;
		goto cleanup;
	}

	if (activation_parameters != NULL) {
		status = nnp_status_unsupported_activation_parameters;
		goto cleanup;
	}

	/* Output size for unit-stride convolution with the given padding. */
	const struct nnp_size output_size = {
		.width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1,
		.height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1
	};

	/* If requested, choose optimal convolution algorithm */
	if (algorithm == nnp_convolution_algorithm_auto) {
		/* Kernels larger than 8x8 cannot fit an 8x8 transform tile. */
		if (max(kernel_size.width, kernel_size.height) > 8) {
			algorithm = nnp_convolution_algorithm_ft16x16;
		} else {
			/* Compare tile counts: a 16x16 transform costs roughly 4x an 8x8
			 * one, so prefer 8x8 unless it needs more than 4x as many tiles. */
			const size_t tile_count_8x8 =
				divide_round_up(output_size.height, 8 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 8 - kernel_size.width + 1);
			const size_t tile_count_16x16 =
				divide_round_up(output_size.height, 16 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 16 - kernel_size.width + 1);
			if (tile_count_8x8 <= 4 * tile_count_16x16) {
				/* 8x8 tiles are more efficient */
				algorithm = nnp_convolution_algorithm_ft8x8;
			} else {
				algorithm = nnp_convolution_algorithm_ft16x16;
			}
		}
	}

	/* Choose tiling parameters and transform functions depending on convolution algorithm */
	struct nnp_size tile_size;
	nnp_transform_2d_with_offset input_transform_function;
	nnp_transform_2d_with_offset grad_output_transform_function;
	nnp_transform_2d_with_offset grad_kernel_transform_function;
	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
			input_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			grad_output_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			grad_kernel_transform_function = nnp_hwinfo.transforms.ifft8x8_with_offset;
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			break;
		case nnp_convolution_algorithm_ft16x16:
			input_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			grad_output_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			grad_kernel_transform_function = nnp_hwinfo.transforms.ifft16x16_with_offset;
			tile_size = (struct nnp_size) { .height = 16, .width = 16 };
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
			/*
			 * Winograd transform is not supported for this operation:
			 * it needs F(5x5, 4x4) transform and presently we implement only F(3x3, 6x6)
			 */
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_auto:
			/* Auto was replaced by a concrete algorithm above. */
			NNP_UNREACHABLE;
		default:
			status = nnp_status_invalid_algorithm;
			goto cleanup;
	}

	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
		case nnp_convolution_algorithm_ft16x16:
			/* The kernel must fit inside the transform tile. */
			if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			status = compute_fast_convolution_kernel_gradient(
				batch_size, input_channels, output_channels,
				tile_size, input_size, input_padding, kernel_size, output_size,
				input, grad_output, grad_kernel, workspace_buffer, workspace_size,
				input_transform_function, grad_output_transform_function, grad_kernel_transform_function,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
		case nnp_convolution_algorithm_auto:
			/* Unsupported algorithms already bailed out in the previous switch. */
			NNP_UNREACHABLE;
	}

cleanup:
	NNP_TOTAL_END(profile)
	return status;
}
548