1 | #include <stdbool.h> |
2 | #include <stdint.h> |
3 | #include <stddef.h> |
4 | |
5 | #include <nnpack.h> |
6 | #include <nnpack/macros.h> |
7 | #include <nnpack/utils.h> |
8 | #include <nnpack/system.h> |
9 | |
10 | #include <nnpack/hwinfo.h> |
11 | #include <nnpack/validation.h> |
12 | |
13 | |
/* Arguments for compute_kernel_transform(), shared by all threadpool workers. */
struct NNP_CACHE_ALIGN kernel_transform_context {
	nnp_transform_2d_with_offset transform_function; /* 2D tile transform applied to each kernel plane */
	const float* kernel;                             /* source weights, [output_channels][input_channels][kernel h*w] */
	float* kernel_transform;                         /* destination buffer of transformed kernel tuples */

	size_t tuple_elements;                           /* elements per GEMM tuple (SIMD width; doubled for complex FFT tuples) */
	size_t input_channels;
	size_t output_channels;
	size_t output_channels_block_max;                /* cache-blocking factor along the output-channel dimension */
	struct nnp_size kernel_size;
};
25 | |
26 | static void compute_kernel_transform( |
27 | const struct kernel_transform_context context[restrict static 1], |
28 | size_t output_channel, size_t input_channels_subblock_start, |
29 | size_t output_channel_range, size_t input_channels_subblock_size) |
30 | { |
31 | const size_t tuple_elements = context->tuple_elements; |
32 | const size_t input_channels = context->input_channels; |
33 | const size_t output_channels = context->output_channels; |
34 | const size_t output_channels_block_max = context->output_channels_block_max; |
35 | const struct nnp_size kernel_size = context->kernel_size; |
36 | |
37 | const float (*kernel)[input_channels][kernel_size.width * kernel_size.height] = |
38 | (const float(*)[input_channels][kernel_size.width * kernel_size.height]) context->kernel; |
39 | float* kernel_transform = context->kernel_transform; |
40 | nnp_transform_2d_with_offset transform_function = context->transform_function; |
41 | |
42 | const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max); |
43 | const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max); |
44 | const size_t output_channels_block_offset = output_channel - output_channels_block_start; |
45 | |
46 | for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) { |
47 | const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset; |
48 | transform_function( |
49 | kernel[output_channel][input_channel], |
50 | kernel_transform + |
51 | (output_channels_block_start * input_channels + input_channels_subblock_start * output_channels_block_size + output_channels_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements, |
52 | kernel_size.width, |
53 | output_channels * input_channels * tuple_elements * sizeof(float), |
54 | kernel_size.height, kernel_size.width, 0, 0); |
55 | } |
56 | } |
57 | |
/* Arguments for compute_grad_output_transform(), shared by all threadpool workers. */
struct NNP_CACHE_ALIGN grad_output_transform_context {
	nnp_transform_2d_with_offset transform_function; /* 2D tile transform applied to each grad_output tile */
	const float* grad_output;                        /* source tile origin inside [batch][output_channels][output h*w] */
	float* grad_output_transform;                    /* destination buffer of transformed grad_output tuples */

	size_t tuple_elements;                           /* elements per GEMM tuple (SIMD width; doubled for complex FFT tuples) */
	size_t batch_size;
	size_t output_channels;
	size_t output_channels_block_max;                /* cache-blocking factor along the output-channel dimension */
	struct nnp_size output_size;
	/* Clipping of the current tile against padding and image borders
	 * (passed through to the transform function). */
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};
73 | |
74 | static void compute_grad_output_transform( |
75 | const struct grad_output_transform_context context[restrict static 1], |
76 | size_t output_channel, size_t batch_subblock_start, |
77 | size_t output_channel_range, size_t batch_subblock_size) |
78 | { |
79 | const size_t tuple_elements = context->tuple_elements; |
80 | const size_t batch_size = context->batch_size; |
81 | const size_t output_channels = context->output_channels; |
82 | const size_t output_channels_block_max = context->output_channels_block_max; |
83 | const struct nnp_size output_size = context->output_size; |
84 | const size_t row_offset = context->row_offset; |
85 | const size_t row_count = context->row_count; |
86 | const size_t column_offset = context->column_offset; |
87 | const size_t column_count = context->column_count; |
88 | |
89 | const float (*grad_output)[output_channels][output_size.width * output_size.height] = |
90 | (const float(*)[output_channels][output_size.width * output_size.height]) context->grad_output; |
91 | float* grad_output_transform = context->grad_output_transform; |
92 | nnp_transform_2d_with_offset transform_function = context->transform_function; |
93 | |
94 | const size_t output_channels_block_start = round_down(output_channel, output_channels_block_max); |
95 | const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max); |
96 | const size_t output_channels_block_offset = output_channel - output_channels_block_start; |
97 | |
98 | for (size_t batch_subblock_offset = 0; batch_subblock_offset < batch_subblock_size; batch_subblock_offset += 1) { |
99 | const size_t sample = batch_subblock_start + batch_subblock_offset; |
100 | transform_function( |
101 | grad_output[sample][output_channel], |
102 | grad_output_transform + |
103 | (output_channels_block_start * batch_size + batch_subblock_start * output_channels_block_size + output_channels_block_offset * batch_subblock_size + batch_subblock_offset) * tuple_elements, |
104 | output_size.width, |
105 | batch_size * output_channels * tuple_elements * sizeof(float), |
106 | row_count, column_count, row_offset, column_offset); |
107 | } |
108 | } |
109 | |
/* Arguments for compute_grad_input_transform(), shared by all threadpool workers. */
struct NNP_CACHE_ALIGN grad_input_transform_context {
	nnp_transform_2d_with_offset transform_function; /* inverse 2D tile transform producing grad_input pixels */
	float* grad_input;                               /* destination tile origin inside [batch][input_channels][input h*w] */
	const float* grad_input_transform;               /* source buffer of transformed grad_input tuples */

	size_t tuple_elements;                           /* elements per GEMM tuple (SIMD width; doubled for complex FFT tuples) */
	size_t input_channels;
	size_t batch_size;
	size_t batch_block_max;                          /* cache-blocking factor along the batch dimension */
	struct nnp_size input_size;
	/* Clipping of the current tile against image borders
	 * (passed through to the transform function). */
	size_t row_offset;
	size_t row_count;
	size_t column_offset;
	size_t column_count;
};
125 | |
126 | static void compute_grad_input_transform( |
127 | const struct grad_input_transform_context context[restrict static 1], |
128 | size_t sample, size_t input_channels_subblock_start, |
129 | size_t sample_range, size_t input_channels_subblock_size) |
130 | { |
131 | const size_t tuple_elements = context->tuple_elements; |
132 | const size_t batch_size = context->batch_size; |
133 | const size_t input_channels = context->input_channels; |
134 | const size_t batch_block_max = context->batch_block_max; |
135 | const struct nnp_size input_size = context->input_size; |
136 | const size_t row_offset = context->row_offset; |
137 | const size_t row_count = context->row_count; |
138 | const size_t column_offset = context->column_offset; |
139 | const size_t column_count = context->column_count; |
140 | |
141 | float (*grad_input)[input_channels][input_size.width * input_size.height] = |
142 | (float(*)[input_channels][input_size.width * input_size.height]) context->grad_input; |
143 | const float* grad_input_transform = context->grad_input_transform; |
144 | nnp_transform_2d_with_offset transform_function = context->transform_function; |
145 | |
146 | const size_t batch_block_start = round_down(sample, batch_block_max); |
147 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
148 | const size_t batch_block_offset = sample - batch_block_start; |
149 | |
150 | for (size_t input_channels_subblock_offset = 0; input_channels_subblock_offset < input_channels_subblock_size; input_channels_subblock_offset += 1) { |
151 | const size_t input_channel = input_channels_subblock_start + input_channels_subblock_offset; |
152 | transform_function( |
153 | grad_input_transform + |
154 | (batch_block_start * input_channels + input_channels_subblock_start * batch_block_size + batch_block_offset * input_channels_subblock_size + input_channels_subblock_offset) * tuple_elements, |
155 | grad_input[sample][input_channel], |
156 | batch_size * input_channels * tuple_elements * sizeof(float), |
157 | input_size.width, |
158 | row_count, column_count, row_offset, column_offset); |
159 | } |
160 | } |
161 | |
/* Arguments for compute_matrix_multiplication(), shared by all threadpool workers. */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;              /* elements per GEMM tuple (SIMD width; doubled for complex FFT tuples) */
	size_t batch_size;
	size_t input_channels;
	size_t batch_block_start;           /* first sample of the current (L3-sized) batch block */
	size_t batch_block_size;
	size_t output_channels_block_start; /* first output channel of the current (L1-sized) block — the reduction dimension */
	size_t output_channels_block_size;
	size_t batch_subblock_max;          /* micro-kernel mr dimension */
	size_t input_channels_subblock_max; /* micro-kernel nr dimension */

	const float* grad_output_transform; /* transformed grad_output tuples (left GEMM operand) */
	const float* kernel_transform;      /* transformed kernel tuples (right GEMM operand) */
	float* grad_input_transform;        /* transformed grad_input tuples (GEMM result) */

	nnp_fast_tuple_gemm_function fast_gemm; /* fixed-size mr x nr micro-kernel */
	nnp_full_tuple_gemm_function full_gemm; /* variable-size (up to mr x nr) micro-kernel */
};
180 | |
/*
 * Multiplies transformed grad_output tuples by transformed kernel tuples to
 * produce transformed grad_input tuples for one (input-channel block, batch
 * sub-block) pair, dispatching to the fixed-size or variable-size tuple-GEMM
 * micro-kernel as appropriate.
 *
 * Invoked by pthreadpool with (input_channels_block_start, batch_subblock_start)
 * as tile coordinates and the corresponding tile sizes.
 *
 * NOTE(review): whether the micro-kernels overwrite or accumulate into
 * grad_input_transform is determined by the gemm function contract — confirm
 * against the sxgemm/cxgemm micro-kernel implementations.
 */
static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t input_channels_block_start, size_t batch_subblock_start,
	size_t input_channels_block_size, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_size = context->batch_size;
	const size_t input_channels = context->input_channels;
	const size_t batch_block_start = context->batch_block_start;
	const size_t batch_block_size = context->batch_block_size;
	const size_t output_channels_block_start = context->output_channels_block_start;
	const size_t output_channels_block_size = context->output_channels_block_size;
	const size_t batch_subblock_max = context->batch_subblock_max;
	const size_t input_channels_subblock_max = context->input_channels_subblock_max;

	/* Base pointers for this tile within the three blocked transform buffers. */
	const float* grad_output_transform = context->grad_output_transform +
		(output_channels_block_start * batch_size + (batch_block_start + batch_subblock_start) * output_channels_block_size) * tuple_elements;
	const float* kernel_transform = context->kernel_transform +
		(output_channels_block_start * input_channels + input_channels_block_start * output_channels_block_size) * tuple_elements;
	float* grad_input_transform = context->grad_input_transform +
		(batch_block_start * input_channels + input_channels_block_start * batch_block_size) * tuple_elements;

	/* Fast path: a full mr-sized batch sub-block lets the fixed-size micro-kernel
	 * consume whole nr-sized input-channel sub-blocks. The loop advances the
	 * kernel and grad_input pointers in lock-step and consumes
	 * input_channels_block_size, leaving any remainder to the loop below. */
	if (batch_subblock_size == batch_subblock_max) {
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (input_channels_block_size >= input_channels_subblock_max) {
			input_channels_block_size -= input_channels_subblock_max;

			fast_gemm(
				output_channels_block_size, output_channels_block_start,
				grad_output_transform,
				kernel_transform,
				grad_input_transform + batch_subblock_start * input_channels_subblock_max * tuple_elements,
				input_channels_subblock_max * tuple_elements);

			kernel_transform += input_channels_subblock_max * output_channels_block_size * tuple_elements;
			grad_input_transform += input_channels_subblock_max * batch_block_size * tuple_elements;
		}
	}

	/* General path: handles partial batch sub-blocks and the input-channel
	 * remainder left by the fast path (zero iterations if nothing remains). */
	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (input_channels_block_size != 0) {
		const size_t input_channels_subblock_size = min(input_channels_block_size, input_channels_subblock_max);
		input_channels_block_size -= input_channels_subblock_size;

		full_gemm(
			batch_subblock_size, input_channels_subblock_size,
			output_channels_block_size, output_channels_block_start,
			grad_output_transform,
			kernel_transform,
			grad_input_transform + batch_subblock_start * input_channels_subblock_size * tuple_elements,
			input_channels_subblock_size * tuple_elements);

		kernel_transform += input_channels_subblock_max * output_channels_block_size * tuple_elements;
		grad_input_transform += input_channels_subblock_max * batch_block_size * tuple_elements;
	}
}
237 | |
238 | static enum nnp_status compute_fast_convolution_input_gradient( |
239 | bool fourier_transform, |
240 | size_t batch_size, |
241 | size_t input_channels, |
242 | size_t output_channels, |
243 | struct nnp_size tile_size, |
244 | struct nnp_size input_size, |
245 | struct nnp_padding input_padding, |
246 | struct nnp_size kernel_size, |
247 | struct nnp_size output_size, |
248 | const float* grad_output, |
249 | const float* kernel, |
250 | float* grad_input, |
251 | void* workspace_buffer, |
252 | size_t* workspace_size, |
253 | nnp_transform_2d_with_offset grad_output_transform_function, |
254 | nnp_transform_2d_with_offset kernel_transform_function, |
255 | nnp_transform_2d_with_offset grad_input_transform_function, |
256 | pthreadpool_t threadpool, |
257 | struct nnp_profile* profile) |
258 | { |
259 | void* memory_block = NULL; |
260 | const size_t simd_width = nnp_hwinfo.simd_width; |
261 | const size_t tuple_elements = (fourier_transform ? simd_width * 2 : simd_width); |
262 | const size_t tile_elements = tile_size.height * tile_size.width; |
263 | const size_t tuple_count = tile_elements / tuple_elements; |
264 | |
265 | const struct nnp_size grad_input_tile_size = { |
266 | .height = tile_size.height - kernel_size.height + 1, |
267 | .width = tile_size.width - kernel_size.width + 1 |
268 | }; |
269 | |
270 | /* Calculate cache blocking parameters */ |
271 | const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float)); |
272 | const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float)); |
273 | const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float)); |
274 | |
275 | const size_t batch_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.mr : nnp_hwinfo.sxgemm.mr); |
276 | const size_t input_channels_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.nr : nnp_hwinfo.sxgemm.nr); |
277 | |
278 | const size_t output_channels_block_max = |
279 | round_down(cache_elements_l1 / (batch_subblock_max + input_channels_subblock_max), 2); |
280 | const size_t batch_block_max = |
281 | round_down(cache_elements_l3 / output_channels_block_max, batch_subblock_max); |
282 | const size_t input_channels_block_max = |
283 | round_down(cache_elements_l2 / output_channels_block_max, input_channels_subblock_max); |
284 | |
285 | /* Calculate memory footprint and allocate memory */ |
286 | const size_t kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float); |
287 | const size_t grad_input_transform_size = batch_size * input_channels * tile_elements * sizeof(float); |
288 | const size_t grad_output_transform_size = batch_size * output_channels * tile_elements * sizeof(float); |
289 | const size_t memory_size = kernel_transform_size + grad_input_transform_size + grad_output_transform_size; |
290 | |
291 | if (workspace_buffer == NULL) { |
292 | if (workspace_size == NULL) { |
293 | memory_block = allocate_memory(memory_size); |
294 | if (memory_block == NULL) { |
295 | return nnp_status_out_of_memory; |
296 | } |
297 | } else { |
298 | *workspace_size = memory_size; |
299 | return nnp_status_success; |
300 | } |
301 | } else { |
302 | if (*workspace_size < memory_size) { |
303 | return nnp_status_insufficient_buffer; |
304 | } |
305 | memory_block = workspace_buffer; |
306 | } |
307 | |
308 | float* grad_output_transform = memory_block; |
309 | float* kernel_transform = memory_block + grad_output_transform_size; |
310 | float* grad_input_transform = memory_block + grad_output_transform_size + kernel_transform_size; |
311 | |
312 | NNP_KERNEL_TRANSFORM_START(profile) |
313 | struct kernel_transform_context kernel_transform_context = { |
314 | .transform_function = kernel_transform_function, |
315 | .kernel = kernel, |
316 | .kernel_transform = kernel_transform, |
317 | .tuple_elements = tuple_elements, |
318 | .input_channels = input_channels, |
319 | .output_channels = output_channels, |
320 | .output_channels_block_max = output_channels_block_max, |
321 | .kernel_size = kernel_size, |
322 | }; |
323 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
324 | (pthreadpool_task_2d_tile_2d_t) compute_kernel_transform, |
325 | &kernel_transform_context, |
326 | output_channels, input_channels, |
327 | 1, input_channels_subblock_max, |
328 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
329 | NNP_KERNEL_TRANSFORM_END(profile) |
330 | |
331 | for (size_t y = 0; y < input_size.height; y += grad_input_tile_size.height) { |
332 | const size_t grad_output_y = min(doz(y + input_padding.top, kernel_size.height - 1), output_size.height); |
333 | for (size_t x = 0; x < input_size.width; x += grad_input_tile_size.width) { |
334 | const size_t grad_output_x = min(doz(x + input_padding.left, kernel_size.width - 1), output_size.width); |
335 | |
336 | NNP_OUTPUT_TRANSFORM_START(profile) |
337 | struct grad_output_transform_context grad_output_transform_context = { |
338 | .transform_function = grad_output_transform_function, |
339 | .grad_output = grad_output + grad_output_y * output_size.width + grad_output_x, |
340 | .grad_output_transform = grad_output_transform, |
341 | .tuple_elements = tuple_elements, |
342 | .batch_size = batch_size, |
343 | .output_channels = output_channels, |
344 | .output_channels_block_max = output_channels_block_max, |
345 | .output_size = output_size, |
346 | .row_offset = doz(kernel_size.height - 1, y + input_padding.top), |
347 | .row_count = min(output_size.height - grad_output_y, |
348 | tile_size.height - grad_output_transform_context.row_offset), |
349 | .column_offset = doz(kernel_size.width - 1, x + input_padding.left), |
350 | .column_count = min(output_size.width - grad_output_x, |
351 | tile_size.width - grad_output_transform_context.column_offset), |
352 | }; |
353 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
354 | (pthreadpool_task_2d_tile_2d_t) compute_grad_output_transform, |
355 | &grad_output_transform_context, |
356 | output_channels, batch_size, |
357 | 1, batch_subblock_max, |
358 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
359 | NNP_OUTPUT_TRANSFORM_END(profile) |
360 | |
361 | NNP_BLOCK_MULTIPLICATION_START(profile) |
362 | for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) { |
363 | for (size_t output_channels_block_start = 0; output_channels_block_start < output_channels; output_channels_block_start += output_channels_block_max) { |
364 | const size_t output_channels_block_size = min(output_channels - output_channels_block_start, output_channels_block_max); |
365 | for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) { |
366 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
367 | struct matrix_multiplication_context matrix_multiplication_context = { |
368 | .tuple_elements = tuple_elements, |
369 | .batch_size = batch_size, |
370 | .input_channels = input_channels, |
371 | .batch_block_start = batch_block_start, |
372 | .batch_block_size = batch_block_size, |
373 | .output_channels_block_start = output_channels_block_start, |
374 | .output_channels_block_size = output_channels_block_size, |
375 | .batch_subblock_max = batch_subblock_max, |
376 | .input_channels_subblock_max = input_channels_subblock_max, |
377 | .grad_output_transform = grad_output_transform + tuple_index * tuple_elements * batch_size * output_channels, |
378 | .kernel_transform = kernel_transform + tuple_index * tuple_elements * output_channels * input_channels, |
379 | .grad_input_transform = grad_input_transform + tuple_index * tuple_elements * batch_size * input_channels, |
380 | }; |
381 | if (fourier_transform) { |
382 | if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) { |
383 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_only_mr_x_nr; |
384 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_upto_mr_x_nr; |
385 | } else { |
386 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_only_mr_x_nr; |
387 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_upto_mr_x_nr; |
388 | } |
389 | } else { |
390 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.sxgemm.only_mr_x_nr; |
391 | matrix_multiplication_context.full_gemm = nnp_hwinfo.sxgemm.upto_mr_x_nr; |
392 | } |
393 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
394 | (pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication, |
395 | &matrix_multiplication_context, |
396 | input_channels, batch_block_size, |
397 | input_channels_block_max, batch_subblock_max, |
398 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
399 | } |
400 | } |
401 | } |
402 | NNP_BLOCK_MULTIPLICATION_END(profile) |
403 | |
404 | NNP_INPUT_TRANSFORM_START(profile) |
405 | struct grad_input_transform_context grad_input_transform_context = { |
406 | .transform_function = grad_input_transform_function, |
407 | .grad_input = grad_input + y * input_size.width + x, |
408 | .grad_input_transform = grad_input_transform, |
409 | .tuple_elements = tuple_elements, |
410 | .input_channels = input_channels, |
411 | .batch_size = batch_size, |
412 | .batch_block_max = batch_block_max, |
413 | .input_size = input_size, |
414 | .row_offset = fourier_transform ? kernel_size.height - 1 : 0, |
415 | .row_count = min(input_size.height - y, grad_input_tile_size.height), |
416 | .column_offset = fourier_transform ? kernel_size.width - 1 : 0, |
417 | .column_count = min(input_size.width - x, grad_input_tile_size.width), |
418 | }; |
419 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
420 | (pthreadpool_task_2d_tile_2d_t) compute_grad_input_transform, |
421 | &grad_input_transform_context, |
422 | batch_size, input_channels, |
423 | 1, input_channels_subblock_max, |
424 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
425 | NNP_INPUT_TRANSFORM_END(profile) |
426 | } |
427 | } |
428 | |
429 | if (memory_block != workspace_buffer) { |
430 | release_memory(memory_block, memory_size); |
431 | } |
432 | return nnp_status_success; |
433 | } |
434 | |
435 | enum nnp_status nnp_convolution_input_gradient( |
436 | enum nnp_convolution_algorithm algorithm, |
437 | size_t batch_size, |
438 | size_t input_channels, |
439 | size_t output_channels, |
440 | struct nnp_size input_size, |
441 | struct nnp_padding input_padding, |
442 | struct nnp_size kernel_size, |
443 | const float* grad_output, |
444 | const float* kernel, |
445 | float* grad_input, |
446 | void* workspace_buffer, |
447 | size_t* workspace_size, |
448 | enum nnp_activation activation, |
449 | const void* activation_parameters, |
450 | pthreadpool_t threadpool, |
451 | struct nnp_profile* profile) |
452 | { |
453 | NNP_TOTAL_START(profile) |
454 | |
455 | /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */ |
456 | enum nnp_status status = validate_convolution_arguments( |
457 | batch_size, input_channels, output_channels, |
458 | input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 }, |
459 | activation, activation_parameters); |
460 | if (status != nnp_status_success) { |
461 | goto cleanup; |
462 | } |
463 | |
464 | if (activation != nnp_activation_identity) { |
465 | status = nnp_status_unsupported_activation; |
466 | goto cleanup; |
467 | } |
468 | |
469 | if (activation_parameters != NULL) { |
470 | status = nnp_status_unsupported_activation_parameters; |
471 | goto cleanup; |
472 | } |
473 | |
474 | /* If requested, choose optimal convolution algorithm */ |
475 | if (algorithm == nnp_convolution_algorithm_auto) { |
476 | if (max(kernel_size.width, kernel_size.height) > 8) { |
477 | algorithm = nnp_convolution_algorithm_ft16x16; |
478 | } else { |
479 | const size_t tile_count_8x8 = |
480 | divide_round_up(input_size.height, 8 - kernel_size.height + 1) * |
481 | divide_round_up(input_size.width, 8 - kernel_size.width + 1); |
482 | const size_t tile_count_16x16 = |
483 | divide_round_up(input_size.height, 16 - kernel_size.height + 1) * |
484 | divide_round_up(input_size.width, 16 - kernel_size.width + 1); |
485 | if (tile_count_8x8 <= 4 * tile_count_16x16) { |
486 | /* 8x8 tiles are more efficient */ |
487 | if ((kernel_size.height == 3) && (kernel_size.width == 3)) { |
488 | algorithm = nnp_convolution_algorithm_wt8x8; |
489 | } else { |
490 | algorithm = nnp_convolution_algorithm_ft8x8; |
491 | } |
492 | } else { |
493 | algorithm = nnp_convolution_algorithm_ft16x16; |
494 | } |
495 | } |
496 | } |
497 | |
498 | /* Choose tiling parameters and transform functions depending on convolution algorithm */ |
499 | struct nnp_size tile_size; |
500 | bool fourier_transform; |
501 | nnp_transform_2d_with_offset grad_output_transform_function; |
502 | nnp_transform_2d_with_offset kernel_transform_function; |
503 | nnp_transform_2d_with_offset grad_input_transform_function; |
504 | switch (algorithm) { |
505 | case nnp_convolution_algorithm_ft8x8: |
506 | grad_output_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream; |
507 | kernel_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream; |
508 | grad_input_transform_function = nnp_hwinfo.transforms.ifft8x8_with_offset; |
509 | tile_size = (struct nnp_size) { .height = 8, .width = 8 }; |
510 | fourier_transform = true; |
511 | break; |
512 | case nnp_convolution_algorithm_ft16x16: |
513 | grad_output_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream; |
514 | kernel_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream; |
515 | grad_input_transform_function = nnp_hwinfo.transforms.ifft16x16_with_offset; |
516 | tile_size = (struct nnp_size) { .height = 16, .width = 16 }; |
517 | fourier_transform = true; |
518 | break; |
519 | case nnp_convolution_algorithm_wt8x8: |
520 | case nnp_convolution_algorithm_wt8x8_fp16: |
521 | if ((kernel_size.height != 3) || (kernel_size.width != 3)) { |
522 | status = nnp_status_unsupported_algorithm; |
523 | goto cleanup; |
524 | } |
525 | grad_output_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_with_offset_and_stream; |
526 | kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3Rx3R; |
527 | grad_input_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3; |
528 | tile_size = (struct nnp_size) { .height = 8, .width = 8 }; |
529 | fourier_transform = false; |
530 | break; |
531 | case nnp_convolution_algorithm_implicit_gemm: |
532 | case nnp_convolution_algorithm_direct: |
533 | status = nnp_status_unsupported_algorithm; |
534 | goto cleanup; |
535 | case nnp_convolution_algorithm_auto: |
536 | NNP_UNREACHABLE; |
537 | default: |
538 | status = nnp_status_invalid_algorithm; |
539 | goto cleanup; |
540 | } |
541 | |
542 | const struct nnp_size output_size = { |
543 | .width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1, |
544 | .height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1 |
545 | }; |
546 | |
547 | switch (algorithm) { |
548 | case nnp_convolution_algorithm_wt8x8: |
549 | case nnp_convolution_algorithm_wt8x8_fp16: |
550 | case nnp_convolution_algorithm_ft8x8: |
551 | case nnp_convolution_algorithm_ft16x16: |
552 | if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) { |
553 | status = nnp_status_unsupported_algorithm; |
554 | goto cleanup; |
555 | } |
556 | status = compute_fast_convolution_input_gradient( |
557 | fourier_transform, |
558 | batch_size, input_channels, output_channels, |
559 | tile_size, input_size, input_padding, kernel_size, output_size, |
560 | grad_output, kernel, grad_input, workspace_buffer, workspace_size, |
561 | grad_output_transform_function, kernel_transform_function, grad_input_transform_function, |
562 | threadpool, profile); |
563 | break; |
564 | case nnp_convolution_algorithm_implicit_gemm: |
565 | case nnp_convolution_algorithm_direct: |
566 | case nnp_convolution_algorithm_auto: |
567 | NNP_UNREACHABLE; |
568 | } |
569 | |
570 | cleanup: |
571 | NNP_TOTAL_END(profile) |
572 | return status; |
573 | } |
574 | |