1 | #include <stdbool.h> |
2 | #include <stdint.h> |
3 | #include <stddef.h> |
4 | |
5 | #include <nnpack.h> |
6 | #include <nnpack/macros.h> |
7 | #include <nnpack/utils.h> |
8 | #include <nnpack/system.h> |
9 | |
10 | #include <nnpack/hwinfo.h> |
11 | #include <nnpack/validation.h> |
12 | |
13 | |
/*
 * Arguments for compute_kernel_transform(), shared (read-only) by all threads
 * of a pthreadpool_parallelize_2d_tile_2d() dispatch over
 * (input channel, output-channel subblock) pairs.
 */
struct NNP_CACHE_ALIGN kernel_transform_context {
	nnp_transform_2d_with_offset transform_function; /* 2D kernel transform (FFT or Winograd) */
	const float* kernel;     /* source weights, [output_channels][input_channels][kernel h*w] */
	float* kernel_transform; /* destination: blocked buffer of transformed kernels */

	size_t tuple_elements;   /* floats per transform tuple (SIMD width; x2 for complex FFT) */
	size_t output_channels;
	size_t input_channels;
	size_t input_channels_block_max; /* cache-blocking factor along input channels */
	struct nnp_size kernel_size;
};
25 | |
26 | static void compute_kernel_transform( |
27 | const struct kernel_transform_context context[restrict static 1], |
28 | size_t input_channel, size_t output_channels_subblock_start, |
29 | size_t input_channel_range, size_t output_channels_subblock_size) |
30 | { |
31 | const size_t tuple_elements = context->tuple_elements; |
32 | const size_t output_channels = context->output_channels; |
33 | const size_t input_channels = context->input_channels; |
34 | const size_t input_channels_block_max = context->input_channels_block_max; |
35 | const struct nnp_size kernel_size = context->kernel_size; |
36 | |
37 | const float (*kernel)[input_channels][kernel_size.width * kernel_size.height] = |
38 | (const float(*)[input_channels][kernel_size.width * kernel_size.height]) context->kernel; |
39 | float* kernel_transform = context->kernel_transform; |
40 | nnp_transform_2d_with_offset transform_function = context->transform_function; |
41 | |
42 | const size_t input_channels_block_start = round_down(input_channel, input_channels_block_max); |
43 | const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max); |
44 | const size_t input_channels_block_offset = input_channel - input_channels_block_start; |
45 | |
46 | for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) { |
47 | const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset; |
48 | transform_function( |
49 | kernel[output_channel][input_channel], |
50 | kernel_transform + |
51 | (input_channels_block_start * output_channels + output_channels_subblock_start * input_channels_block_size + input_channels_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements, |
52 | kernel_size.width, |
53 | output_channels * input_channels * tuple_elements * sizeof(float), |
54 | kernel_size.height, kernel_size.width, 0, 0); |
55 | } |
56 | } |
57 | |
/*
 * Arguments for compute_input_transform(), shared (read-only) by all threads
 * of a pthreadpool_parallelize_2d_tile_2d() dispatch over
 * (input channel, batch subblock) pairs.
 */
struct NNP_CACHE_ALIGN input_transform_context {
	nnp_transform_2d_with_offset transform_function; /* 2D input transform (FFT or Winograd) */
	const float* input;     /* points at the current tile origin inside the input tensor */
	float* input_transform; /* destination: blocked buffer of transformed input tiles */

	size_t tuple_elements;  /* floats per transform tuple */
	size_t batch_size;
	size_t input_channels;
	size_t input_channels_block_max; /* cache-blocking factor along input channels */
	struct nnp_size input_size;
	size_t row_offset;    /* first valid row inside the tile (rows above are implicit zero padding) */
	size_t row_count;     /* number of valid input rows in the tile */
	size_t column_offset; /* first valid column inside the tile */
	size_t column_count;  /* number of valid input columns in the tile */
};
73 | |
74 | static void compute_input_transform( |
75 | const struct input_transform_context context[restrict static 1], |
76 | size_t input_channel, size_t batch_subblock_start, |
77 | size_t input_channel_range, size_t batch_subblock_size) |
78 | { |
79 | const size_t tuple_elements = context->tuple_elements; |
80 | const size_t batch_size = context->batch_size; |
81 | const size_t input_channels = context->input_channels; |
82 | const size_t input_channels_block_max = context->input_channels_block_max; |
83 | const struct nnp_size input_size = context->input_size; |
84 | const size_t row_offset = context->row_offset; |
85 | const size_t row_count = context->row_count; |
86 | const size_t column_offset = context->column_offset; |
87 | const size_t column_count = context->column_count; |
88 | |
89 | const float (*input)[input_channels][input_size.width * input_size.height] = |
90 | (const float(*)[input_channels][input_size.width * input_size.height]) context->input; |
91 | float* input_transform = context->input_transform; |
92 | nnp_transform_2d_with_offset transform_function = context->transform_function; |
93 | |
94 | const size_t input_channels_block_start = round_down(input_channel, input_channels_block_max); |
95 | const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max); |
96 | const size_t input_channels_block_offset = input_channel - input_channels_block_start; |
97 | |
98 | for (size_t batch_subblock_offset = 0; batch_subblock_offset < batch_subblock_size; batch_subblock_offset += 1) { |
99 | const size_t sample = batch_subblock_start + batch_subblock_offset; |
100 | transform_function( |
101 | input[sample][input_channel], |
102 | input_transform + |
103 | (input_channels_block_start * batch_size + batch_subblock_start * input_channels_block_size + input_channels_block_offset * batch_subblock_size + batch_subblock_offset) * tuple_elements, |
104 | input_size.width, |
105 | batch_size * input_channels * tuple_elements * sizeof(float), |
106 | row_count, column_count, row_offset, column_offset); |
107 | } |
108 | } |
109 | |
/*
 * Arguments for compute_output_transform(), shared (read-only) by all threads
 * of a pthreadpool_parallelize_2d_tile_2d() dispatch over
 * (sample, output-channel subblock) pairs.
 */
struct NNP_CACHE_ALIGN output_transform_context {
	nnp_transform_2d_with_bias transform_function; /* inverse transform with fused bias (and activation) */
	float* output;                 /* points at the current tile origin inside the output tensor */
	const float* output_transform; /* source: blocked buffer of accumulated tuple products */
	const float* bias;             /* one bias value per output channel */

	size_t tuple_elements;  /* floats per transform tuple */
	size_t output_channels;
	size_t batch_size;
	size_t batch_block_max; /* cache-blocking factor along the batch dimension */
	struct nnp_size output_size;
	size_t row_offset;    /* NOTE(review): zero-initialized by the only caller in this file and never consumed */
	size_t row_count;     /* number of output rows produced for this tile */
	size_t column_offset; /* NOTE(review): zero-initialized by the only caller in this file and never consumed */
	size_t column_count;  /* number of output columns produced for this tile */
};
126 | |
127 | static void compute_output_transform( |
128 | const struct output_transform_context context[restrict static 1], |
129 | size_t sample, size_t output_channels_subblock_start, |
130 | size_t sample_range, size_t output_channels_subblock_size) |
131 | { |
132 | const size_t tuple_elements = context->tuple_elements; |
133 | const size_t batch_size = context->batch_size; |
134 | const size_t output_channels = context->output_channels; |
135 | const size_t batch_block_max = context->batch_block_max; |
136 | const struct nnp_size output_size = context->output_size; |
137 | const size_t row_offset = context->row_offset; |
138 | const size_t row_count = context->row_count; |
139 | const size_t column_offset = context->column_offset; |
140 | const size_t column_count = context->column_count; |
141 | |
142 | float (*output)[output_channels][output_size.width * output_size.height] = |
143 | (float(*)[output_channels][output_size.width * output_size.height]) context->output; |
144 | const float* output_transform = context->output_transform; |
145 | const float* bias = context->bias; |
146 | nnp_transform_2d_with_bias transform_function = context->transform_function; |
147 | |
148 | const size_t batch_block_start = round_down(sample, batch_block_max); |
149 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
150 | const size_t batch_block_offset = sample - batch_block_start; |
151 | |
152 | for (size_t output_channels_subblock_offset = 0; output_channels_subblock_offset < output_channels_subblock_size; output_channels_subblock_offset += 1) { |
153 | const size_t output_channel = output_channels_subblock_start + output_channels_subblock_offset; |
154 | transform_function( |
155 | output_transform + |
156 | (batch_block_start * output_channels + output_channels_subblock_start * batch_block_size + batch_block_offset * output_channels_subblock_size + output_channels_subblock_offset) * tuple_elements, |
157 | output[sample][output_channel], |
158 | &bias[output_channel], |
159 | batch_size * output_channels * tuple_elements * sizeof(float), |
160 | output_size.width, |
161 | row_count, column_count); |
162 | } |
163 | } |
164 | |
/*
 * Arguments for compute_matrix_multiplication(), shared (read-only) by all
 * threads of a pthreadpool_parallelize_2d_tile_2d() dispatch over
 * (output-channel block, batch subblock) pairs, for one tuple index and one
 * fixed (input-channel block, batch block) pair.
 */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	size_t tuple_elements;               /* floats per transform tuple */
	size_t batch_block_size;
	size_t input_channels_block_start;
	size_t input_channels_block_size;    /* reduction length of the GEMM */
	size_t batch_subblock_max;           /* microkernel MR */
	size_t output_channels_subblock_max; /* microkernel NR */

	const float* input_transform;  /* base of the current (tuple, channel block, batch block) slice */
	const float* kernel_transform; /* base of the current (tuple, channel block) slice */
	float* output_transform;       /* base of the current (tuple, batch block) slice */

	nnp_fast_tuple_gemm_function fast_gemm; /* fixed MRxNR microkernel */
	nnp_full_tuple_gemm_function full_gemm; /* variable up-to-MRxNR microkernel */
};
180 | |
/*
 * pthreadpool task: accumulates one (output-channel block) x (batch subblock)
 * piece of the tuple-wise product
 *     output_transform += input_transform * kernel_transform
 * for a fixed input-channel block, walking output-channel subblocks with the
 * fastest applicable GEMM microkernel.
 */
static void compute_matrix_multiplication(
	const struct matrix_multiplication_context context[restrict static 1],
	size_t output_channels_block_start, size_t batch_subblock_start,
	size_t output_channels_block_size, size_t batch_subblock_size)
{
	const size_t tuple_elements = context->tuple_elements;
	const size_t batch_block_size = context->batch_block_size;
	const size_t input_channels_block_start = context->input_channels_block_start;
	const size_t input_channels_block_size = context->input_channels_block_size;
	const size_t batch_subblock_max = context->batch_subblock_max;
	const size_t output_channels_subblock_max = context->output_channels_subblock_max;

	/* Base pointers for this task's slice; the packed subblock layouts make
	   the per-subblock advance strides below constant. */
	const float* input_transform = context->input_transform +
		(batch_subblock_start * input_channels_block_size * tuple_elements);
	const float* kernel_transform = context->kernel_transform +
		(output_channels_block_start * input_channels_block_size * tuple_elements);
	float* output_transform = context->output_transform +
		(output_channels_block_start * batch_block_size * tuple_elements);

	if (batch_subblock_size == batch_subblock_max) {
		/* Full batch subblock: use the fixed-size MRxNR microkernel for every
		   complete output-channel subblock; the remainder (if any) falls
		   through to the full_gemm loop below. */
		const nnp_fast_tuple_gemm_function fast_gemm = context->fast_gemm;
		while (output_channels_block_size >= output_channels_subblock_max) {
			output_channels_block_size -= output_channels_subblock_max;

			fast_gemm(
				input_channels_block_size, input_channels_block_start,
				input_transform,
				kernel_transform,
				output_transform + (batch_subblock_start * output_channels_subblock_max * tuple_elements),
				output_channels_subblock_max * tuple_elements);

			kernel_transform += input_channels_block_size * output_channels_subblock_max * tuple_elements;
			output_transform += batch_block_size * output_channels_subblock_max * tuple_elements;
		}
	}

	/* Partial batch subblock, or the trailing partial output-channel
	   subblock: use the variable-size (up to MRxNR) microkernel. */
	const nnp_full_tuple_gemm_function full_gemm = context->full_gemm;
	while (output_channels_block_size != 0) {
		const size_t output_channels_subblock_size = min(output_channels_block_size, output_channels_subblock_max);
		output_channels_block_size -= output_channels_subblock_size;

		full_gemm(
			batch_subblock_size, output_channels_subblock_size,
			input_channels_block_size, input_channels_block_start,
			input_transform,
			kernel_transform,
			output_transform + (batch_subblock_start * output_channels_subblock_size * tuple_elements),
			output_channels_subblock_size * tuple_elements);

		/* Only the final subblock can be smaller than the max, so advancing
		   by the max-sized stride is safe (the loop exits after it). */
		kernel_transform += input_channels_block_size * output_channels_subblock_max * tuple_elements;
		output_transform += batch_block_size * output_channels_subblock_max * tuple_elements;
	}
}
234 | |
235 | static enum nnp_status compute_fast_convolution_output( |
236 | bool fourier_transform, |
237 | size_t batch_size, |
238 | size_t input_channels, |
239 | size_t output_channels, |
240 | struct nnp_size tile_size, |
241 | struct nnp_size input_size, |
242 | struct nnp_padding input_padding, |
243 | struct nnp_size kernel_size, |
244 | struct nnp_size output_size, |
245 | const float* input, |
246 | const float* kernel, |
247 | const float* bias, |
248 | float* output, |
249 | void* workspace_buffer, |
250 | size_t* workspace_size, |
251 | const nnp_transform_2d_with_offset input_transform_function, |
252 | const nnp_transform_2d_with_offset kernel_transform_function, |
253 | const nnp_transform_2d_with_bias output_transform_function, |
254 | pthreadpool_t threadpool, |
255 | struct nnp_profile* profile) |
256 | { |
257 | void* memory_block = NULL; |
258 | const size_t simd_width = nnp_hwinfo.simd_width; |
259 | const size_t tuple_elements = (fourier_transform ? simd_width * 2 : simd_width); |
260 | const size_t tile_elements = tile_size.height * tile_size.width; |
261 | const size_t tuple_count = tile_elements / tuple_elements; |
262 | |
263 | const struct nnp_size output_tile_size = { |
264 | .height = tile_size.height - kernel_size.height + 1, |
265 | .width = tile_size.width - kernel_size.width + 1 |
266 | }; |
267 | |
268 | /* Calculate cache blocking parameters */ |
269 | const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / (tuple_elements * sizeof(float)); |
270 | const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / (tuple_elements * sizeof(float)); |
271 | const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / (tuple_elements * sizeof(float)); |
272 | |
273 | const size_t batch_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.mr : nnp_hwinfo.sxgemm.mr); |
274 | const size_t output_channels_subblock_max = (fourier_transform ? nnp_hwinfo.cxgemm.nr : nnp_hwinfo.sxgemm.nr); |
275 | |
276 | const size_t input_channels_block_max = |
277 | round_down(cache_elements_l1 / (batch_subblock_max + output_channels_subblock_max), 2); |
278 | const size_t batch_block_max = |
279 | round_down(cache_elements_l3 / input_channels_block_max, batch_subblock_max); |
280 | const size_t output_channels_block_max = |
281 | round_down(cache_elements_l2 / input_channels_block_max, output_channels_subblock_max); |
282 | |
283 | /* Calculate memory footprint and allocate memory */ |
284 | const size_t kernel_transform_size = output_channels * input_channels * tile_elements * sizeof(float); |
285 | const size_t input_transform_size = batch_size * input_channels * tile_elements * sizeof(float); |
286 | const size_t output_transform_size = batch_size * output_channels * tile_elements * sizeof(float); |
287 | const size_t memory_size = kernel_transform_size + input_transform_size + output_transform_size; |
288 | |
289 | if (workspace_buffer == NULL) { |
290 | if (workspace_size == NULL) { |
291 | memory_block = allocate_memory(memory_size); |
292 | if (memory_block == NULL) { |
293 | return nnp_status_out_of_memory; |
294 | } |
295 | } else { |
296 | *workspace_size = memory_size; |
297 | return nnp_status_success; |
298 | } |
299 | } else { |
300 | if (*workspace_size < memory_size) { |
301 | return nnp_status_insufficient_buffer; |
302 | } |
303 | memory_block = workspace_buffer; |
304 | } |
305 | |
306 | float* input_transform = memory_block; |
307 | float* output_transform = memory_block + input_transform_size; |
308 | float* kernel_transform = memory_block + input_transform_size + output_transform_size; |
309 | |
310 | NNP_KERNEL_TRANSFORM_START(profile) |
311 | struct kernel_transform_context kernel_transform_context = { |
312 | .transform_function = kernel_transform_function, |
313 | .kernel = kernel, |
314 | .kernel_transform = kernel_transform, |
315 | .tuple_elements = tuple_elements, |
316 | .output_channels = output_channels, |
317 | .input_channels = input_channels, |
318 | .input_channels_block_max = input_channels_block_max, |
319 | .kernel_size = kernel_size, |
320 | }; |
321 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
322 | (pthreadpool_task_2d_tile_2d_t) compute_kernel_transform, |
323 | &kernel_transform_context, |
324 | input_channels, output_channels, |
325 | 1, output_channels_subblock_max, |
326 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
327 | NNP_KERNEL_TRANSFORM_END(profile) |
328 | |
329 | for (size_t y = 0; y < output_size.height; y += output_tile_size.height) { |
330 | const size_t input_y = min(doz(y, input_padding.top), input_size.height); |
331 | for (size_t x = 0; x < output_size.width; x += output_tile_size.width) { |
332 | const size_t input_x = min(doz(x, input_padding.left), input_size.width); |
333 | |
334 | NNP_INPUT_TRANSFORM_START(profile) |
335 | struct input_transform_context input_transform_context = { |
336 | .transform_function = input_transform_function, |
337 | .input = input + input_y * input_size.width + input_x, |
338 | .input_transform = input_transform, |
339 | .tuple_elements = tuple_elements, |
340 | .batch_size = batch_size, |
341 | .input_channels = input_channels, |
342 | .input_channels_block_max = input_channels_block_max, |
343 | .input_size = input_size, |
344 | .row_offset = doz(input_padding.top, y), |
345 | .row_count = min(input_size.height - input_y, |
346 | tile_size.height - input_transform_context.row_offset), |
347 | .column_offset = doz(input_padding.left, x), |
348 | .column_count = min(input_size.width - input_x, |
349 | tile_size.width - input_transform_context.column_offset), |
350 | }; |
351 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
352 | (pthreadpool_task_2d_tile_2d_t) compute_input_transform, |
353 | &input_transform_context, |
354 | input_channels, batch_size, |
355 | 1, batch_subblock_max, |
356 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
357 | NNP_INPUT_TRANSFORM_END(profile) |
358 | |
359 | NNP_BLOCK_MULTIPLICATION_START(profile) |
360 | for (size_t tuple_index = 0; tuple_index < tuple_count; tuple_index += 1) { |
361 | for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) { |
362 | const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max); |
363 | for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) { |
364 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
365 | struct matrix_multiplication_context matrix_multiplication_context = { |
366 | .tuple_elements = tuple_elements, |
367 | .batch_block_size = batch_block_size, |
368 | .input_channels_block_start = input_channels_block_start, |
369 | .input_channels_block_size = input_channels_block_size, |
370 | .batch_subblock_max = batch_subblock_max, |
371 | .output_channels_subblock_max = output_channels_subblock_max, |
372 | .input_transform = input_transform + |
373 | tuple_index * tuple_elements * batch_size * input_channels + |
374 | input_channels_block_start * batch_size * tuple_elements + |
375 | batch_block_start * input_channels_block_size * tuple_elements, |
376 | .kernel_transform = kernel_transform + |
377 | tuple_index * tuple_elements * output_channels * input_channels + |
378 | input_channels_block_start * output_channels * tuple_elements, |
379 | .output_transform = output_transform + tuple_index * tuple_elements * batch_size * output_channels + |
380 | batch_block_start * output_channels * tuple_elements, |
381 | }; |
382 | if (fourier_transform) { |
383 | if (tuple_index < NNP_COMPLEX_TUPLE_INDEX) { |
384 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_only_mr_x_nr; |
385 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.s4cX_conjb_upto_mr_x_nr; |
386 | } else { |
387 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.cxgemm.cX_conjb_only_mr_x_nr; |
388 | matrix_multiplication_context.full_gemm = nnp_hwinfo.cxgemm.cX_conjb_upto_mr_x_nr; |
389 | } |
390 | } else { |
391 | matrix_multiplication_context.fast_gemm = nnp_hwinfo.sxgemm.only_mr_x_nr; |
392 | matrix_multiplication_context.full_gemm = nnp_hwinfo.sxgemm.upto_mr_x_nr; |
393 | } |
394 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
395 | (pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication, |
396 | &matrix_multiplication_context, |
397 | output_channels, batch_block_size, |
398 | output_channels_block_max, batch_subblock_max, |
399 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
400 | } |
401 | } |
402 | } |
403 | NNP_BLOCK_MULTIPLICATION_END(profile) |
404 | |
405 | NNP_OUTPUT_TRANSFORM_START(profile) |
406 | struct output_transform_context output_transform_context = { |
407 | .transform_function = output_transform_function, |
408 | .output = output + y * output_size.width + x, |
409 | .output_transform = output_transform, |
410 | .bias = bias, |
411 | .tuple_elements = tuple_elements, |
412 | .output_channels = output_channels, |
413 | .batch_size = batch_size, |
414 | .batch_block_max = batch_block_max, |
415 | .output_size = output_size, |
416 | .row_count = min(output_tile_size.height, output_size.height - y), |
417 | .column_count = min(output_tile_size.width, output_size.width - x), |
418 | }; |
419 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
420 | (pthreadpool_task_2d_tile_2d_t) compute_output_transform, |
421 | &output_transform_context, |
422 | batch_size, output_channels, |
423 | 1, output_channels_subblock_max, |
424 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
425 | NNP_OUTPUT_TRANSFORM_END(profile) |
426 | } |
427 | } |
428 | |
429 | if (memory_block != workspace_buffer) { |
430 | release_memory(memory_block, memory_size); |
431 | } |
432 | return nnp_status_success; |
433 | } |
434 | |
435 | enum nnp_status nnp_convolution_output( |
436 | enum nnp_convolution_algorithm algorithm, |
437 | size_t batch_size, |
438 | size_t input_channels, |
439 | size_t output_channels, |
440 | struct nnp_size input_size, |
441 | struct nnp_padding input_padding, |
442 | struct nnp_size kernel_size, |
443 | const float* input, |
444 | const float* kernel, |
445 | const float* bias, |
446 | float* output, |
447 | void* workspace_buffer, |
448 | size_t* workspace_size, |
449 | enum nnp_activation activation, |
450 | const void* activation_parameters, |
451 | pthreadpool_t threadpool, |
452 | struct nnp_profile* profile) |
453 | { |
454 | NNP_TOTAL_START(profile) |
455 | |
456 | /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */ |
457 | enum nnp_status status = validate_convolution_arguments( |
458 | batch_size, input_channels, output_channels, |
459 | input_size, input_padding, kernel_size, (struct nnp_size) { 1, 1 }, |
460 | activation, activation_parameters); |
461 | if (status != nnp_status_success) { |
462 | goto cleanup; |
463 | } |
464 | |
465 | if (activation_parameters != NULL) { |
466 | status = nnp_status_unsupported_activation_parameters; |
467 | goto cleanup; |
468 | } |
469 | |
470 | const struct nnp_size output_size = { |
471 | .width = input_padding.left + input_size.width + input_padding.right - kernel_size.width + 1, |
472 | .height = input_padding.top + input_size.height + input_padding.bottom - kernel_size.height + 1 |
473 | }; |
474 | |
475 | /* If requested, choose optimal convolution algorithm */ |
476 | if (algorithm == nnp_convolution_algorithm_auto) { |
477 | if (max(kernel_size.width, kernel_size.height) > 8) { |
478 | algorithm = nnp_convolution_algorithm_ft16x16; |
479 | } else { |
480 | const size_t tile_count_8x8 = |
481 | divide_round_up(output_size.height, 8 - kernel_size.height + 1) * |
482 | divide_round_up(output_size.width, 8 - kernel_size.width + 1); |
483 | const size_t tile_count_16x16 = |
484 | divide_round_up(output_size.height, 16 - kernel_size.height + 1) * |
485 | divide_round_up(output_size.width, 16 - kernel_size.width + 1); |
486 | if (tile_count_8x8 <= 4 * tile_count_16x16) { |
487 | /* 8x8 tiles are more efficient */ |
488 | if ((kernel_size.height == 3) && (kernel_size.width == 3)) { |
489 | algorithm = nnp_convolution_algorithm_wt8x8; |
490 | } else { |
491 | algorithm = nnp_convolution_algorithm_ft8x8; |
492 | } |
493 | } else { |
494 | algorithm = nnp_convolution_algorithm_ft16x16; |
495 | } |
496 | } |
497 | } |
498 | |
499 | /* Choose tiling parameters and transform functions depending on convolution algorithm */ |
500 | struct nnp_size tile_size; |
501 | bool fourier_transform; |
502 | nnp_transform_2d_with_offset input_transform_function; |
503 | nnp_transform_2d_with_offset kernel_transform_function; |
504 | nnp_transform_2d_with_bias output_transform_function; |
505 | switch (algorithm) { |
506 | case nnp_convolution_algorithm_ft8x8: |
507 | input_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream; |
508 | kernel_transform_function = nnp_hwinfo.transforms.fft8x8_with_offset_and_stream; |
509 | switch (activation) { |
510 | case nnp_activation_relu: |
511 | output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias_with_relu; |
512 | break; |
513 | case nnp_activation_identity: |
514 | output_transform_function = nnp_hwinfo.transforms.ifft8x8_with_bias; |
515 | break; |
516 | default: |
517 | NNP_UNREACHABLE; |
518 | } |
519 | tile_size = (struct nnp_size) { .height = 8, .width = 8 }; |
520 | fourier_transform = true; |
521 | break; |
522 | case nnp_convolution_algorithm_ft16x16: |
523 | input_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream; |
524 | kernel_transform_function = nnp_hwinfo.transforms.fft16x16_with_offset_and_stream; |
525 | switch (activation) { |
526 | case nnp_activation_relu: |
527 | output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias_with_relu; |
528 | break; |
529 | case nnp_activation_identity: |
530 | output_transform_function = nnp_hwinfo.transforms.ifft16x16_with_bias; |
531 | break; |
532 | default: |
533 | NNP_UNREACHABLE; |
534 | } |
535 | tile_size = (struct nnp_size) { .height = 16, .width = 16 }; |
536 | fourier_transform = true; |
537 | break; |
538 | case nnp_convolution_algorithm_wt8x8: |
539 | case nnp_convolution_algorithm_wt8x8_fp16: |
540 | if ((kernel_size.height != 3) || (kernel_size.width != 3)) { |
541 | status = nnp_status_unsupported_algorithm; |
542 | goto cleanup; |
543 | } |
544 | input_transform_function = nnp_hwinfo.transforms.iwt_f6x6_3x3_with_offset_and_stream; |
545 | kernel_transform_function = nnp_hwinfo.transforms.kwt_f6x6_3x3; |
546 | output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias; |
547 | switch (activation) { |
548 | case nnp_activation_relu: |
549 | output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias_with_relu; |
550 | break; |
551 | case nnp_activation_identity: |
552 | output_transform_function = nnp_hwinfo.transforms.owt_f6x6_3x3_with_bias; |
553 | break; |
554 | default: |
555 | NNP_UNREACHABLE; |
556 | } |
557 | tile_size = (struct nnp_size) { .height = 8, .width = 8 }; |
558 | fourier_transform = false; |
559 | break; |
560 | case nnp_convolution_algorithm_implicit_gemm: |
561 | case nnp_convolution_algorithm_direct: |
562 | status = nnp_status_unsupported_algorithm; |
563 | goto cleanup; |
564 | case nnp_convolution_algorithm_auto: |
565 | NNP_UNREACHABLE; |
566 | default: |
567 | status = nnp_status_invalid_algorithm; |
568 | goto cleanup; |
569 | } |
570 | |
571 | switch (algorithm) { |
572 | case nnp_convolution_algorithm_wt8x8: |
573 | case nnp_convolution_algorithm_wt8x8_fp16: |
574 | case nnp_convolution_algorithm_ft8x8: |
575 | case nnp_convolution_algorithm_ft16x16: |
576 | if (kernel_size.height > tile_size.height || kernel_size.width > tile_size.width) { |
577 | status = nnp_status_unsupported_algorithm; |
578 | goto cleanup; |
579 | } |
580 | status = compute_fast_convolution_output( |
581 | fourier_transform, |
582 | batch_size, input_channels, output_channels, |
583 | tile_size, input_size, input_padding, kernel_size, output_size, |
584 | input, kernel, bias, output, workspace_buffer, workspace_size, |
585 | input_transform_function, kernel_transform_function, output_transform_function, |
586 | threadpool, profile); |
587 | break; |
588 | case nnp_convolution_algorithm_implicit_gemm: |
589 | case nnp_convolution_algorithm_direct: |
590 | case nnp_convolution_algorithm_auto: |
591 | NNP_UNREACHABLE; |
592 | } |
593 | |
594 | cleanup: |
595 | NNP_TOTAL_END(profile) |
596 | return status; |
597 | } |
598 | |