1 | #include <stdbool.h> |
2 | #include <stdint.h> |
3 | #include <stddef.h> |
4 | |
5 | #include <nnpack.h> |
6 | #include <nnpack/macros.h> |
7 | #include <nnpack/utils.h> |
8 | #include <nnpack/system.h> |
9 | |
10 | #include <nnpack/hwinfo.h> |
11 | #include <nnpack/validation.h> |
12 | |
13 | |
/*
 * Shared, read-only parameters for the per-tile input (FFT) transform tasks.
 * One instance is built per (tile position, batch block) and handed to
 * pthreadpool; each task transforms a sub-block of input channels.
 */
struct NNP_CACHE_ALIGN input_transform_context {
	size_t tuple_elements;           /* floats per SIMD tuple (2 * simd_width) */
	size_t input_elements;           /* input_size.height * input_size.width */
	size_t batch_block_size;         /* images in the current batch block */
	size_t input_channels;           /* total input channels */
	size_t input_stride;             /* row stride of the input image, in elements */
	uint32_t row_offset;             /* top padding rows inside the tile */
	uint32_t column_offset;          /* left padding columns inside the tile */
	uint32_t row_count;              /* valid input rows covered by the tile */
	uint32_t column_count;           /* valid input columns covered by the tile */
	const float* input;              /* input pixels, pre-offset to this tile/batch block */
	float* input_transform;          /* destination buffer for transformed tuples */
	nnp_transform_2d_with_offset transform_function; /* forward 2D transform to apply */
};
28 | |
29 | static void compute_input_transform( |
30 | const struct input_transform_context context[restrict static 1], |
31 | size_t batch_block_offset, size_t input_channels_subblock_start, |
32 | size_t batch_block_offset_range, size_t input_channels_subblock_size) |
33 | { |
34 | const size_t tuple_elements = context->tuple_elements; |
35 | const size_t input_elements = context->input_elements; |
36 | const size_t batch_block_size = context->batch_block_size; |
37 | const size_t input_channels = context->input_channels; |
38 | const size_t input_stride = context->input_stride; |
39 | const uint32_t row_count = context->row_count; |
40 | const uint32_t column_count = context->column_count; |
41 | const uint32_t row_offset = context->row_offset; |
42 | const uint32_t column_offset = context->column_offset; |
43 | const float* input = context->input; |
44 | float* input_transform = context->input_transform; |
45 | const nnp_transform_2d_with_offset transform = context->transform_function; |
46 | |
47 | for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) { |
48 | const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset; |
49 | transform( |
50 | input + |
51 | (batch_block_offset * input_channels + input_channel) * input_elements, |
52 | input_transform + |
53 | (input_channels_subblock_start * batch_block_size + batch_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements, |
54 | input_stride, |
55 | batch_block_size * input_channels * tuple_elements * sizeof(float), |
56 | row_count, column_count, row_offset, column_offset); |
57 | } |
58 | } |
59 | |
/*
 * Shared, read-only parameters for the per-tile output-gradient (FFT)
 * transform tasks. Mirrors input_transform_context, but gradient tiles are
 * never padded, so no row/column offsets are needed.
 */
struct NNP_CACHE_ALIGN grad_output_transform_context {
	size_t tuple_elements;           /* floats per SIMD tuple (2 * simd_width) */
	size_t output_elements;          /* output_size.height * output_size.width */
	size_t batch_block_size;         /* images in the current batch block */
	size_t output_channels;          /* total output channels */
	size_t grad_output_stride;       /* row stride of grad_output, in elements */
	uint32_t row_count;              /* valid gradient rows covered by the tile */
	uint32_t column_count;           /* valid gradient columns covered by the tile */
	const float* grad_output;        /* output gradient, pre-offset to this tile/batch block */
	float* grad_output_transform;    /* destination buffer for transformed tuples */
	nnp_transform_2d_with_offset transform_function; /* forward 2D transform to apply */
};
72 | |
73 | static void compute_grad_output_transform( |
74 | const struct grad_output_transform_context context[restrict static 1], |
75 | size_t batch_block_offset, size_t output_channels_subblock_start, |
76 | size_t batch_block_offset_range, size_t output_channels_subblock_size) |
77 | { |
78 | const size_t tuple_elements = context->tuple_elements; |
79 | const size_t output_elements = context->output_elements; |
80 | const size_t batch_block_size = context->batch_block_size; |
81 | const size_t output_channels = context->output_channels; |
82 | const size_t grad_output_stride = context->grad_output_stride; |
83 | const uint32_t row_count = context->row_count; |
84 | const uint32_t column_count = context->column_count; |
85 | const float* grad_output = context->grad_output; |
86 | float* grad_output_transform = context->grad_output_transform; |
87 | const nnp_transform_2d_with_offset transform = context->transform_function; |
88 | |
89 | for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) { |
90 | const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset; |
91 | transform( |
92 | grad_output + |
93 | (batch_block_offset * output_channels + output_channel) * output_elements, |
94 | grad_output_transform + |
95 | (output_channels_subblock_start * batch_block_size + batch_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements, |
96 | grad_output_stride, |
97 | batch_block_size * output_channels * tuple_elements * sizeof(float), |
98 | row_count, column_count, 0, 0); |
99 | } |
100 | } |
101 | |
/*
 * Shared, read-only parameters for the inverse-transform tasks that convert
 * accumulated kernel-gradient tuples back into spatial-domain gradients.
 */
struct NNP_CACHE_ALIGN grad_kernel_transform_context {
	size_t tuple_elements;           /* floats per SIMD tuple (2 * simd_width) */
	size_t input_channels;           /* total input channels */
	size_t output_channels;          /* total output channels */
	size_t output_channels_block_max; /* output-channel cache-block size used during GEMM */
	struct nnp_size kernel_size;     /* spatial kernel dimensions */
	const float* grad_kernel_transform; /* accumulated transform-domain gradients */
	float* grad_kernel;              /* destination spatial-domain kernel gradient */
	nnp_transform_2d_with_offset transform_function; /* inverse 2D transform to apply */
};
112 | |
113 | static void compute_grad_kernel_transform( |
114 | const struct grad_kernel_transform_context context[restrict static 1], |
115 | size_t output_channel, size_t input_channels_subblock_start, |
116 | size_t output_channel_range, size_t input_channels_subblock_size) |
117 | { |
118 | const size_t tuple_elements = context->tuple_elements; |
119 | const size_t input_channels = context->input_channels; |
120 | const size_t output_channels = context->output_channels; |
121 | const size_t output_channels_block_max = context->output_channels_block_max; |
122 | const struct nnp_size kernel_size = context->kernel_size; |
123 | const float* grad_kernel_transform = context->grad_kernel_transform; |
124 | float* grad_kernel = context->grad_kernel; |
125 | const nnp_transform_2d_with_offset transform = context->transform_function; |
126 | |
127 | const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max); |
128 | const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max); |
129 | const size_t output_channels_block_offset = output_channel - output_channels_block_start; |
130 | const size_t kernel_elements = kernel_size.height * kernel_size.width; |
131 | |
132 | for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) { |
133 | const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset; |
134 | transform( |
135 | grad_kernel_transform + |
136 | (output_channels_block_start * input_channels + input_channels_subblock_start * output_channels_block_size + output_channels_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements, |
137 | grad_kernel + |
138 | (output_channel * input_channels + input_channel) * kernel_elements, |
139 | output_channels * input_channels * tuple_elements * sizeof(float), |
140 | kernel_size.width, |
141 | kernel_size.height, kernel_size.width, 0, 0); |
142 | } |
143 | } |
144 | |
/*
 * Shared, read-only parameters for the transform-domain GEMM tasks that
 * accumulate input-transform x grad-output-transform products into the
 * kernel-gradient transform buffer.
 */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;            /* floats per SIMD tuple (2 * simd_width) */
	size_t batch_size;                /* total images in the batch */
	size_t batch_block_size;          /* images in the current batch block (reduction length) */
	size_t batch_block_update;        /* nonzero => accumulate into C; zero => overwrite (first tile & block) */
	size_t input_channels;            /* total input channels */
	size_t input_channels_block_start; /* first input channel of the current L3 block */
	size_t input_channels_block_size; /* input channels in the current L3 block */
	size_t input_channels_subblock_max; /* GEMM mr register tile */
	size_t output_channels;           /* total output channels */
	size_t output_channels_subblock_max; /* GEMM nr register tile */
	const float* grad_output_transform; /* B matrix: transformed output gradients */
	const float* input_transform;       /* A matrix: transformed inputs */
	float* grad_kernel_transform;       /* C matrix: accumulated kernel-gradient tuples */

	nnp_fast_tuple_gemm_function fast_gemm; /* kernel for full mr x nr tiles */
	nnp_full_tuple_gemm_function full_gemm; /* kernel for partial edge tiles */
};
163 | |
/*
 * pthreadpool task: multiply-accumulate one (output-channel block,
 * input-channel sub-block) tile of the transform-domain GEMM.
 *
 * The reduction runs over the batch block; results accumulate into
 * grad_kernel_transform when batch_block_update is nonzero. The fast
 * (register-tile-exact) micro-kernel handles as many full nr-wide column
 * strips as possible; the remainder falls through to the edge-case kernel.
 *
 * NOTE: output_channels_block_size is consumed (decremented) by the loops
 * below, and the B/C pointers advance in lock-step with it.
 */
static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t output_channels_block_start, size_t input_channels_subblock_start,
	size_t output_channels_block_size, size_t input_channels_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t batch_block_size = context->batch_block_size;
	const size_t batch_block_update = context->batch_block_update;
	const size_t input_channels = context->input_channels;
	const size_t input_channels_block_start = context->input_channels_block_start;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const size_t input_channels_subblock_max = context->input_channels_subblock_max;
	const size_t output_channels = context->output_channels;
	const size_t output_channels_subblock_max = context->output_channels_subblock_max;

	/* B: output-gradient tuples for this output-channel block. */
	const float* grad_output_transform = context->grad_output_transform +
		output_channels_block_start * batch_block_size * tuple_elements;
	/* A: input tuples for this input-channel sub-block. */
	const float* input_transform = context->input_transform +
		(input_channels_block_start + input_channels_subblock_start) * batch_block_size * tuple_elements;
	/* C: kernel-gradient tuples, in blocked/interleaved layout. */
	float* grad_kernel_transform = context->grad_kernel_transform +
		(output_channels_block_start * input_channels + (input_channels_block_start + input_channels_subblock_start) * output_channels_block_size) * tuple_elements;

	if (input_channels_subblock_size == input_channels_subblock_max) {
		/* Full mr rows: use the fixed-size micro-kernel for every full nr strip. */
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (output_channels_block_size >= output_channels_subblock_max) {
			output_channels_block_size -= output_channels_subblock_max;

			fast_gemm(
				batch_block_size, batch_block_update,
				input_transform,
				grad_output_transform,
				grad_kernel_transform,
				input_channels_subblock_size * tuple_elements);

			grad_output_transform += output_channels_subblock_max * batch_block_size * tuple_elements;
			grad_kernel_transform += output_channels_subblock_max * input_channels_subblock_size * tuple_elements;
		}
	}

	/* Remaining (or all, when mr is partial) strips via the variable-size kernel. */
	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (output_channels_block_size != 0) {
		const size_t output_channels_subblock_size = min(output_channels_block_size, output_channels_subblock_max);
		output_channels_block_size -= output_channels_subblock_size;

		full_gemm(
			input_channels_subblock_size, output_channels_subblock_size,
			batch_block_size, batch_block_update,
			input_transform,
			grad_output_transform,
			grad_kernel_transform,
			input_channels_subblock_size * tuple_elements);

		grad_output_transform += output_channels_subblock_max * batch_block_size * tuple_elements;
		grad_kernel_transform += output_channels_subblock_max * input_channels_subblock_size * tuple_elements;
	}
}
221 | |
222 | static enum nnp_status compute_fast_convolution_kernel_gradient( |
223 | size_t batch_size, |
224 | size_t input_channels, |
225 | size_t output_channels, |
226 | struct nnp_size tile_size, |
227 | struct nnp_size input_size, |
228 | struct nnp_padding input_padding, |
229 | struct nnp_size kernel_size, |
230 | struct nnp_size output_size, |
231 | const float* input, |
232 | const float* grad_output, |
233 | float* grad_kernel, |
234 | void* workspace_buffer, |
235 | size_t* workspace_size, |
236 | nnp_transform_2d_with_offset input_transform_function, |
237 | nnp_transform_2d_with_offset grad_output_transform_function, |
238 | nnp_transform_2d_with_offset grad_kernel_transform_function, |
239 | pthreadpool_t threadpool, |
240 | struct nnp_profile* profile) |
241 | { |
242 | void* memory_block = NULL; |
243 | const size_t simd_width = nnp_hwinfo.simd_width; |
244 | const size_t tuple_elements = simd_width * 2; |
245 | const size_t tile_elements = tile_size.height * tile_size.width; |
246 | const size_t tuple_count = tile_elements / tuple_elements; |
247 | |
248 | const struct nnp_size output_tile = { |
249 | .height = tile_size.height - kernel_size.height + 1, |
250 | .width = tile_size.width - kernel_size.width + 1 |
251 | }; |
252 | |
253 | /* Calculate cache blocking parameters */ |
254 | const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float)); |
255 | const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float)); |
256 | const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float)); |
257 | |
258 | const size_t input_channels_subblock_max = nnp_hwinfo.cxgemm.mr; |
259 | const size_t output_channels_subblock_max = nnp_hwinfo.cxgemm.nr; |
260 | |
261 | const size_t batch_block_max = |
262 | round_down(cache_elements_l1 / (input_channels_subblock_max + output_channels_subblock_max), 2); |
263 | const size_t input_channels_block_max = |
264 | round_down(cache_elements_l3 / batch_block_max, input_channels_subblock_max); |
265 | const size_t output_channels_block_max = |
266 | round_down(cache_elements_l2 / batch_block_max, output_channels_subblock_max); |
267 | |
268 | /* Calculate memory footprint and allocate memory */ |
269 | const size_t input_transform_size = min(batch_size, batch_block_max) * input_channels * tile_elements * sizeof(float); |
270 | const size_t grad_output_transform_size = min(batch_size, batch_block_max) * output_channels * tile_elements * sizeof(float); |
271 | const size_t grad_kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float); |
272 | const size_t memory_size = input_transform_size + grad_output_transform_size + grad_kernel_transform_size; |
273 | |
274 | if (workspace_buffer == NULL) { |
275 | if (workspace_size == NULL) { |
276 | memory_block = allocate_memory(memory_size); |
277 | if (memory_block == NULL) { |
278 | return nnp_status_out_of_memory; |
279 | } |
280 | } else { |
281 | *workspace_size = memory_size; |
282 | return nnp_status_success; |
283 | } |
284 | } else { |
285 | if (*workspace_size < memory_size) { |
286 | return nnp_status_insufficient_buffer; |
287 | } |
288 | memory_block = workspace_buffer; |
289 | } |
290 | |
291 | float* input_transform = memory_block; |
292 | float* grad_output_transform = memory_block + input_transform_size; |
293 | float* grad_kernel_transform = memory_block + input_transform_size + grad_output_transform_size; |
294 | |
295 | for (size_t y = 0; y < output_size.height; y += output_tile.height) { |
296 | const size_t input_y = min(doz(y, input_padding.top), input_size.height); |
297 | for (size_t x = 0; x < output_size.width; x += output_tile.width) { |
298 | const size_t input_x = min(doz(x, input_padding.left), input_size.width); |
299 | |
300 | for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) { |
301 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
302 | |
303 | /* Input transform */ |
304 | NNP_INPUT_TRANSFORM_START(profile) |
305 | struct input_transform_context input_transform_context = { |
306 | .tuple_elements = tuple_elements, |
307 | .input_elements = input_size.height * input_size.width, |
308 | .batch_block_size = batch_block_size, |
309 | .input_channels = input_channels, |
310 | .input_stride = input_size.width, |
311 | .row_offset = doz(input_padding.top, y), |
312 | .column_offset = doz(input_padding.left, x), |
313 | .row_count = min(input_size.height - input_y, |
314 | tile_size.height - input_transform_context.row_offset), |
315 | .column_count = min(input_size.width - input_x, |
316 | tile_size.width - input_transform_context.column_offset), |
317 | .input = input + (batch_block_start * input_channels * input_size.height + input_y) * input_size.width + input_x, |
318 | .input_transform = input_transform, |
319 | .transform_function = input_transform_function, |
320 | }; |
321 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
322 | (pthreadpool_task_2d_tile_2d_t) compute_input_transform, |
323 | &input_transform_context, |
324 | batch_block_size, input_channels, |
325 | 1, input_channels_subblock_max, |
326 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
327 | NNP_INPUT_TRANSFORM_END(profile) |
328 | |
329 | /* Grad output transform */ |
330 | NNP_OUTPUT_TRANSFORM_START(profile) |
331 | struct grad_output_transform_context grad_output_transform_context = { |
332 | .tuple_elements = tuple_elements, |
333 | .output_elements = output_size.height * output_size.width, |
334 | .batch_block_size = batch_block_size, |
335 | .output_channels = output_channels, |
336 | .grad_output_stride = output_size.width, |
337 | .row_count = min(output_tile.height, output_size.height - y), |
338 | .column_count = min(output_tile.width, output_size.width - x), |
339 | .grad_output = grad_output + (batch_block_start * output_channels * output_size.height + y) * output_size.width + x, |
340 | .grad_output_transform = grad_output_transform, |
341 | .transform_function = grad_output_transform_function, |
342 | }; |
343 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
344 | (pthreadpool_task_2d_tile_2d_t) compute_grad_output_transform, |
345 | &grad_output_transform_context, |
346 | batch_block_size, output_channels, |
347 | 1, output_channels_subblock_max, |
348 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
349 | NNP_OUTPUT_TRANSFORM_END(profile) |
350 | |
351 | NNP_BLOCK_MULTIPLICATION_START(profile) |
352 | for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) { |
353 | for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) { |
354 | const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max); |
355 | |
356 | struct matrix_multiplication_context matrix_multiplication_context = { |
357 | .tuple_elements = tuple_elements, |
358 | .batch_size = batch_size, |
359 | .batch_block_size = batch_block_size, |
360 | .batch_block_update = batch_block_start | x | y, |
361 | .input_channels = input_channels, |
362 | .input_channels_block_start = input_channels_block_start, |
363 | .input_channels_block_size = input_channels_block_size, |
364 | .input_channels_subblock_max = input_channels_subblock_max, |
365 | .output_channels = output_channels, |
366 | .output_channels_subblock_max = output_channels_subblock_max, |
367 | .grad_output_transform = grad_output_transform + |
368 | tuple_index * tuple_elements * batch_block_size * output_channels, |
369 | .input_transform = input_transform + |
370 | tuple_index * tuple_elements * batch_block_size * input_channels, |
371 | .grad_kernel_transform = grad_kernel_transform + |
372 | tuple_index * tuple_elements * output_channels * input_channels, |
373 | }; |
374 | if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) { |
375 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_transc_only_mr_x_nr; |
376 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_transc_upto_mr_x_nr; |
377 | } else { |
378 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_conjb_transc_only_mr_x_nr; |
379 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_conjb_transc_upto_mr_x_nr; |
380 | } |
381 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
382 | (pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication, |
383 | &matrix_multiplication_context, |
384 | output_channels, input_channels_block_size, |
385 | output_channels_block_max, input_channels_subblock_max, |
386 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
387 | } |
388 | } |
389 | NNP_BLOCK_MULTIPLICATION_END(profile) |
390 | |
391 | |
392 | } |
393 | } |
394 | } |
395 | /* Grad kernel transform */ |
396 | NNP_KERNEL_TRANSFORM_START(profile) |
397 | struct grad_kernel_transform_context grad_kernel_transform_context = { |
398 | .tuple_elements = tuple_elements, |
399 | .input_channels = input_channels, |
400 | .output_channels = output_channels, |
401 | .output_channels_block_max = output_channels_block_max, |
402 | .kernel_size = kernel_size, |
403 | .grad_kernel = grad_kernel, |
404 | .grad_kernel_transform = grad_kernel_transform, |
405 | .transform_function = grad_kernel_transform_function, |
406 | }; |
407 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
408 | (pthreadpool_task_2d_tile_2d_t) compute_grad_kernel_transform, |
409 | &grad_kernel_transform_context, |
410 | output_channels, input_channels, |
411 | 1, input_channels_subblock_max, |
412 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
413 | NNP_KERNEL_TRANSFORM_END(profile) |
414 | |
415 | if (memory_block != workspace_buffer) { |
416 | release_memory(memory_block, memory_size); |
417 | } |
418 | return nnp_status_success; |
419 | } |
420 | |
/*
 * Public entry point: compute the gradient of the loss w.r.t. the convolution
 * kernel, given the input and the gradient w.r.t. the output.
 *
 * Validates arguments, selects an FFT tile size (8x8 or 16x16) when
 * algorithm == auto, and dispatches to the tiled implementation. Only the
 * identity activation (with no parameters) is supported for this operation.
 * See compute_fast_convolution_kernel_gradient for the workspace protocol.
 */
enum nnp_status nnp_convolution_kernel_gradient(
	enum nnp_convolution_algorithm algorithm,
	size_t batch_size,
	size_t input_channels,
	size_t output_channels,
	struct nnp_size input_size,
	struct nnp_padding input_padding,
	struct nnp_size kernel_size,
	const float* input,
	const float* grad_output,
	float* grad_kernel,
	void* workspace_buffer,
	size_t* workspace_size,
	enum nnp_activation activation,
	const void* activation_parameters,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	NNP_TOTAL_START(profile)

	/* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
	enum nnp_status status = validate_convolution_arguments(
		batch_size, input_channels, output_channels,
		input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 },
		activation, activation_parameters);
	if (status != nnp_status_success) {
		goto cleanup;
	}

	/* Gradient computation ignores activations; only identity is meaningful here. */
	if (activation != nnp_activation_identity) {
		status = nnp_status_unsupported_activation;
		goto cleanup;
	}

	if (activation_parameters != NULL) {
		status = nnp_status_unsupported_activation_parameters;
		goto cleanup;
	}

	/* "Valid" convolution output size over the zero-padded input. */
	const struct nnp_size output_size = {
		.width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1,
		.height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1
	};

	/* If requested, choose optimal convolution algorithm */
	if (algorithm == nnp_convolution_algorithm_auto) {
		if (max(kernel_size.width, kernel_size.height) > 8) {
			/* Kernel does not fit in an 8x8 tile; 16x16 is the only option. */
			algorithm = nnp_convolution_algorithm_ft16x16;
		} else {
			/* Compare tile counts: prefer 8x8 unless 16x16 needs far fewer tiles. */
			const size_t tile_count_8x8 =
				divide_round_up(output_size.height, 8 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 8 - kernel_size.width + 1);
			const size_t tile_count_16x16 =
				divide_round_up(output_size.height, 16 - kernel_size.height + 1) *
				divide_round_up(output_size.width, 16 - kernel_size.width + 1);
			if (tile_count_8x8 <= 4 * tile_count_16x16) {
				/* 8x8 tiles are more efficient */
				algorithm = nnp_convolution_algorithm_ft8x8;
			} else {
				algorithm = nnp_convolution_algorithm_ft16x16;
			}
		}
	}

	/* Choose tiling parameters and transform functions depending on convolution algorithm */
	struct nnp_size tile_size;
	nnp_transform_2d_with_offset input_transform_function;
	nnp_transform_2d_with_offset grad_output_transform_function;
	nnp_transform_2d_with_offset grad_kernel_transform_function;
	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
			input_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			grad_output_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream;
			grad_kernel_transform_function = nnp_hwinfo.transforms.ifft8x8_with_offset;
			tile_size = (struct nnp_size) { .height = 8, .width = 8 };
			break;
		case nnp_convolution_algorithm_ft16x16:
			input_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			grad_output_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream;
			grad_kernel_transform_function = nnp_hwinfo.transforms.ifft16x16_with_offset;
			tile_size = (struct nnp_size) { .height = 16, .width = 16 };
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
			/*
			 * Winograd transform is not supported for this operation:
			 * it needs F(5x5, 4x4) transform and presently we implement only F(3x3, 6x6)
			 */
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
			status = nnp_status_unsupported_algorithm;
			goto cleanup;
		case nnp_convolution_algorithm_auto:
			/* Auto was resolved to a concrete algorithm above. */
			NNP_UNREACHABLE;
		default:
			status = nnp_status_invalid_algorithm;
			goto cleanup;
	}

	switch (algorithm) {
		case nnp_convolution_algorithm_ft8x8:
		case nnp_convolution_algorithm_ft16x16:
			/* The kernel must fit inside the transform tile. */
			if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) {
				status = nnp_status_unsupported_algorithm;
				goto cleanup;
			}
			status = compute_fast_convolution_kernel_gradient(
				batch_size, input_channels, output_channels,
				tile_size, input_size, input_padding, kernel_size, output_size,
				input, grad_output, grad_kernel, workspace_buffer, workspace_size,
				input_transform_function, grad_output_transform_function, grad_kernel_transform_function,
				threadpool, profile);
			break;
		case nnp_convolution_algorithm_wt8x8:
		case nnp_convolution_algorithm_wt8x8_fp16:
		case nnp_convolution_algorithm_implicit_gemm:
		case nnp_convolution_algorithm_direct:
		case nnp_convolution_algorithm_auto:
			/* Unsupported algorithms already exited via the switch above. */
			NNP_UNREACHABLE;
	}

cleanup:
	NNP_TOTAL_END(profile)
	return status;
}
548 | |