/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

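/*
 * This file implements qnnp_run_operator(): for each operator type it fills
 * a small context struct with precomputed pointers and strides, then fans the
 * work out across a pthreadpool, with one compute_*() trampoline per ukernel
 * family translating thread-pool indices into pointer offsets.
 */
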
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <qnnpack.h>
#include <qnnpack/operator.h>
#include <qnnpack/log.h>
#include <qnnpack/common.h>
#include <qnnpack/math.h>
#include <qnnpack/params.h>

#ifdef _MSC_VER
#include <malloc.h>
#endif

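/*
 * GEMM over pre-packed weights. Each panel of packed_w stores k_stride bytes
 * of weights per output channel plus one int32_t bias, which is why panel
 * offsets below scale by (k_stride * sizeof(uint8_t) + sizeof(int32_t)).
 */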
struct q8gemm_context {
  size_t k;
  size_t k_stride;
  size_t n;
  size_t n_stride;
  const uint8_t* a;
  size_t a_stride;
  const uint8_t* packed_w;
  uint8_t* c;
  size_t c_stride;
  union qnnp_conv_quantization_params quantization_params;
  const q8gemm_ukernel_function ukernel;
};

static void compute_q8gemm(
    const struct q8gemm_context context[RESTRICT_STATIC 1],
    size_t group_index,
    size_t pixel_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t pixel_range,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k = context->k;
  const size_t k_stride = context->k_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t* restrict a = context->a;
  const size_t a_stride = context->a_stride;
  const void* restrict packed_w = context->packed_w;
  uint8_t* restrict c = context->c;
  const size_t c_stride = context->c_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k,
      a + (pixel_index + mr_block_start) * a_stride + group_index * k,
      a_stride,
      (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (pixel_index + mr_block_start) * c_stride + nr_block_start + group_index * n,
      c_stride,
      &context->quantization_params);
}

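/*
 * Per-row input sums for the XZP (cross-zero-point) GEMM. The sums are
 * pre-scaled by the multiplier (set to -kernel_zero_point by the dispatcher
 * below) so the GEMM ukernel can fold the zero-point correction in with a
 * plain addition.
 */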
struct q8sum_rows_context {
  const uint8_t* a;
  size_t groups;
  size_t m;
  size_t k;
  size_t a_stride;
  const int32_t multiplier;
  int32_t* a_sum;
  size_t a_sum_stride;
  const q8sum_rows_ukernel_function ukernel;
};

static void compute_sum_rows(
    const struct q8sum_rows_context context[RESTRICT_STATIC 1],
    size_t group_index,
    size_t batch_index,
    size_t block_start,
    size_t group_range /* always 1 */,
    size_t batch_range /* always 1 */,
    size_t block_size)
{
  const uint8_t* a = context->a;
  const size_t groups = context->groups;
  const size_t m = context->m;
  const size_t k = context->k;
  const size_t a_stride = context->a_stride;
  const int32_t multiplier = context->multiplier;
  int32_t* a_sum = context->a_sum;
  const size_t a_sum_stride = context->a_sum_stride;

  context->ukernel(
      a + batch_index * m * a_stride + group_index * k + block_start * a_stride,
      min(block_size, m - block_start),
      k,
      a_stride,
      multiplier,
      a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride + block_start);
}

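/*
 * XZP GEMM: same packed-panel layout as compute_q8gemm, but the ukernel also
 * consumes the precomputed row sums (a_sum) and applies Q31 requantization
 * instead of the per-convolution quantization parameters.
 */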
struct q8gemm_xzp_context {
  size_t k;
  size_t k_stride;
  size_t n;
  size_t n_stride;
  const uint8_t* a;
  size_t a_stride;
  const void* packed_w;
  uint8_t* c;
  size_t c_stride;
  const int32_t* a_sum;
  size_t groups;
  size_t batch_size;
  size_t a_sum_stride;
  union qnnp_q31_requantization_params requantization_params;
  const q8gemm_xzp_ukernel_function ukernel;
};

static void compute_q8gemm_xzp(
    const struct q8gemm_xzp_context context[RESTRICT_STATIC 1],
    size_t group_index,
    size_t pixel_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t pixel_range,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k = context->k;
  const size_t k_stride = context->k_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t* restrict a = context->a;
  const size_t a_stride = context->a_stride;
  const void* restrict packed_w = context->packed_w;
  uint8_t* restrict c = context->c;
  const size_t c_stride = context->c_stride;
  const int32_t* a_sum = context->a_sum;
  const size_t groups = context->groups;
  const size_t a_sum_stride = context->a_sum_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k,
      a + (pixel_index + mr_block_start) * a_stride + group_index * k,
      a_stride,
      a_sum + pixel_index * groups + group_index * a_sum_stride + mr_block_start,
      (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (pixel_index + mr_block_start) * c_stride + nr_block_start + group_index * n,
      c_stride,
      &context->requantization_params);
}

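/*
 * Indirect convolution: indirect_a holds ks input pointers per output pixel
 * (padded taps point into the operator's zero buffer), so the ukernel walks
 * kernel taps without re-gathering input rows.
 */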
struct q8conv_context {
  size_t bs;
  size_t ks;
  size_t kc;
  size_t kc_stride;
  size_t m;
  size_t m_stride;
  size_t n;
  size_t n_stride;
  const uint8_t** indirect_a;
  const void* packed_w;
  uint8_t* c;
  size_t c_stride;
  union qnnp_conv_quantization_params quantization_params;
  const q8conv_ukernel_function ukernel;
};

static void compute_q8conv(
    const struct q8conv_context context[RESTRICT_STATIC 1],
    size_t group_index,
    size_t image_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t image_range /* always 1 */,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t bs = context->bs;
  const size_t ks = context->ks;
  const size_t kc = context->kc;
  const size_t kc_stride = context->kc_stride;
  const size_t m = context->m;
  const size_t m_stride = context->m_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t** restrict indirect_a = context->indirect_a;
  const void* restrict packed_w = context->packed_w;
  uint8_t* restrict c = context->c;
  const size_t c_stride = context->c_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      kc,
      ks,
      indirect_a + (mr_block_start + (image_index + group_index * bs) * m_stride) * ks,
      (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (mr_block_start + image_index * m) * c_stride + group_index * n + nr_block_start,
      c_stride,
      &context->quantization_params);
}

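/*
 * Depthwise convolution. The unipass ukernel (3x3, kernel_size == 9)
 * accumulates entirely in registers; the multipass ukernel (5x5,
 * kernel_size == 25) spills partial sums into a per-invocation buffer of
 * group_stride int32 accumulators.
 */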
struct q8dwconv_context {
  size_t groups;
  size_t group_stride;
  const uint8_t** indirection_buffer;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  uint8_t* output;
  size_t output_height;
  size_t output_width;
  size_t output_row_stride;
  size_t output_col_increment;
  union qnnp_conv_quantization_params quantization_params;
  union {
    const q8dwconv_up_ukernel_function unipass_ukernel;
    const q8dwconv_mp_ukernel_function multipass_ukernel;
  };
};

static void compute_dwconv_unipass(
    const struct q8dwconv_context context[RESTRICT_STATIC 1],
    size_t image,
    size_t output_y)
{
  const size_t output_height = context->output_height;

  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + (image * output_height + output_y) * context->indirection_buffer_row_stride,
      context->packed_weights,
      context->output + (image * output_height + output_y) * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->quantization_params);
}

static void compute_dwconv_multipass(
    const struct q8dwconv_context context[RESTRICT_STATIC 1],
    size_t image,
    size_t output_y)
{
  const size_t output_height = context->output_height;
#ifdef _MSC_VER
  int32_t* multipass_acc = _malloca(sizeof(int32_t) * context->group_stride);
#else
  QNNP_ALIGN(16) int32_t multipass_acc[context->group_stride];
#endif

  context->multipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + (image * output_height + output_y) * context->indirection_buffer_row_stride,
      context->packed_weights,
      multipass_acc,
      context->output + (image * output_height + output_y) * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->quantization_params);

#ifdef _MSC_VER
  _freea(multipass_acc);
#endif
}

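/*
 * Max pooling over the same indirection-buffer scheme as convolution: each
 * output row gets pooling_size input pointers per output pixel, and the
 * ukernel reduces them channel-wise with clamping.
 */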
struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union qnnp_u8_clamping_params params;
  u8maxpool_ukernel_function ukernel;
};

static void compute_max_pooling(
    const struct max_pooling_context context[RESTRICT_STATIC 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      (const uint8_t**) indirect_input, output,
      context->input_increment, context->output_increment,
      &context->params);
}

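/*
 * Average pooling. As with depthwise convolution, small pooling windows run
 * in a single pass; larger ones (pooling_size > mr) spill partial sums into
 * a per-invocation buffer of packed_channels int32 accumulators.
 */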
struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t packed_channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union qnnp_avgpool_quantization_params quantization_params;
  union {
    q8avgpool_up_ukernel_function unipass_ukernel;
    q8avgpool_mp_ukernel_function multipass_ukernel;
  };
};

static void compute_average_pooling_unipass(
    const struct average_pooling_context context[RESTRICT_STATIC 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      (const uint8_t**) indirect_input, context->zero, output,
      context->input_increment, context->output_increment,
      &context->quantization_params);
}

static void compute_average_pooling_multipass(
    const struct average_pooling_context context[RESTRICT_STATIC 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
#ifdef _MSC_VER
  int32_t* multipass_buffer = _malloca(sizeof(int32_t) * context->packed_channels);
#else
  QNNP_ALIGN(16) int32_t multipass_buffer[context->packed_channels];
#endif

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      (const uint8_t**) indirect_input, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->quantization_params);

#ifdef _MSC_VER
  _freea(multipass_buffer);
#endif
}

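/*
 * Global average pooling reduces all input_elements pixels of a batch entry
 * to one output pixel per channel; the unipass/multipass split mirrors
 * average pooling, keyed on input_width vs. the ukernel's mr.
 */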
struct global_average_pooling_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  size_t packed_channels;
  void* output;
  size_t output_batch_stride;
  union qnnp_avgpool_quantization_params quantization_params;
  union {
    q8gavgpool_up_ukernel_function unipass_ukernel;
    q8gavgpool_mp_ukernel_function multipass_ukernel;
  };
};

static void compute_global_average_pooling_unipass(
    const struct global_average_pooling_context context[RESTRICT_STATIC 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->quantization_params);
}

static void compute_global_average_pooling_multipass(
    const struct global_average_pooling_context context[RESTRICT_STATIC 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
#ifdef _MSC_VER
  int32_t* multipass_buffer = _malloca(sizeof(int32_t) * context->packed_channels);
#else
  QNNP_ALIGN(16) int32_t multipass_buffer[context->packed_channels];
#endif

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->quantization_params);

#ifdef _MSC_VER
  _freea(multipass_buffer);
#endif
}

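/*
 * Elementwise addition with requantization. The strided variant processes
 * one batch row per invocation; the contiguous variant treats the whole
 * tensor as a flat byte range and is used when strides match the channel
 * count (see the dispatcher below).
 */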
struct q8add_strided_context {
  size_t n;
  const uint8_t* a;
  size_t a_stride;
  const uint8_t* b;
  size_t b_stride;
  uint8_t* y;
  size_t y_stride;
  union qnnp_add_quantization_params quantization_params;
  q8vadd_ukernel_function ukernel;
};

static void compute_q8add_strided(
    const struct q8add_strided_context context[RESTRICT_STATIC 1],
    size_t batch_offset,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const size_t n = context->n;
  const size_t a_stride = context->a_stride;
  const size_t b_stride = context->b_stride;
  const size_t y_stride = context->y_stride;
  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_offset);
  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_offset);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_offset);

  context->ukernel(n, a, b, y, &context->quantization_params);
}

struct q8add_contiguous_context {
  const uint8_t* a;
  const uint8_t* b;
  uint8_t* y;
  union qnnp_add_quantization_params quantization_params;
  q8vadd_ukernel_function ukernel;
};

static void compute_q8add_contiguous(
    const struct q8add_contiguous_context context[RESTRICT_STATIC 1],
    size_t offset,
    size_t size)
{
  const void* a = (const void*) ((uintptr_t) context->a + offset);
  const void* b = (const void*) ((uintptr_t) context->b + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, a, b, y, &context->quantization_params);
}

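/*
 * Channel shuffle is implemented as an interleave ("zip"): specialized
 * ukernels handle 2, 3, and 4 groups, and a variable-group ukernel covers
 * the rest.
 */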
struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xzipc_ukernel_function fixed_ukernel;
    xzipv_ukernel_function variable_ukernel;
  };
};

static void compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[RESTRICT_STATIC 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

static void compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[RESTRICT_STATIC 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

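/*
 * 8-bit lookup-table transform (used for elementwise nonlinearities): each
 * output byte is t[x]. The strided and contiguous variants parallel the
 * q8add pair above.
 */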
struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  x8lut_ukernel_function ukernel;
};

static void compute_lut_strided(
    const struct lut_strided_context context[RESTRICT_STATIC 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  x8lut_ukernel_function ukernel;
};

static void compute_lut_contiguous(
    const struct lut_contiguous_context context[RESTRICT_STATIC 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

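/* Clamp: elementwise saturation of u8 values to the operator's clamping range. */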
struct clamp_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  u8clamp_ukernel_function ukernel;
  union qnnp_u8_clamping_params params;
};

static void compute_clamp_strided(
    const struct clamp_strided_context context[RESTRICT_STATIC 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

struct clamp_contiguous_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  u8clamp_ukernel_function ukernel;
  union qnnp_u8_clamping_params params;
};

static void compute_clamp_contiguous(
    const struct clamp_contiguous_context context[RESTRICT_STATIC 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

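/*
 * Softargmax (softmax): first find the row maximum with the rmax ukernel,
 * then shift the 32-bit lookup table by 255 - x_max (written as x_max ^ 255
 * below, which is equivalent for x_max in [0, 255]) so the largest input
 * indexes the top table entry, and normalize through the lut32norm ukernel.
 */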
struct u8softargmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  u8rmax_ukernel_function rmax_ukernel;
  u8lut32norm_ukernel_function lut_norm_ukernel;
};

static void compute_u8softargmax(
    const struct u8softargmax_context context[RESTRICT_STATIC 1],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  const uint8_t x_max = context->rmax_ukernel(n, x);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

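/*
 * Dispatch an already-configured operator to the thread pool. A minimal
 * usage sketch (assuming an operator previously created and set up through
 * the corresponding qnnp_create_ and qnnp_setup_ calls for its type):
 *
 *   pthreadpool_t threadpool = pthreadpool_create(0); // 0 = one thread per core
 *   enum qnnp_status status = qnnp_run_operator(op, threadpool);
 *   assert(status == qnnp_status_success);
 *   pthreadpool_destroy(threadpool);
 *
 * A NULL threadpool is also valid: pthreadpool then runs the work on the
 * calling thread.
 */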
enum qnnp_status qnnp_run_operator(qnnp_operator_t op, pthreadpool_t threadpool)
{
  // For any ukernel type, there is no work to do if the batch size is 0.
  if (op->batch_size == 0) {
    return qnnp_status_success;
  }

  switch (op->ukernel_type) {
    case qnnp_ukernel_type_dwconv:
    {
      const size_t batch_size = op->batch_size;
      const size_t groups = op->groups;
      const size_t kernel_height = op->kernel_height;
      const size_t kernel_width = op->kernel_width;
      const size_t kernel_size = kernel_height * kernel_width;
      const size_t width_step = op->dilation_width == 1 ? op->stride_width : op->kernel_width;
      const size_t output_height = op->output_height;
      const size_t output_width = op->output_width;

      switch (kernel_size) {
        case 9:
        {
          struct q8dwconv_context context = {
            .groups = groups,
            .indirection_buffer = (const uint8_t**) op->indirection_buffer,
            .indirection_buffer_row_stride = kernel_size + (output_width * width_step - 1) * kernel_height,
            .indirection_buffer_col_stride = kernel_height * width_step * sizeof(void*),
            .packed_weights = op->packed_weights,
            .output = op->output,
            .output_height = output_height,
            .output_width = output_width,
            .output_row_stride = output_width * op->output_pixel_stride,
            .output_col_increment = (op->output_pixel_stride - groups) * sizeof(uint8_t),
            .quantization_params = op->conv_quantization_params,
            .unipass_ukernel = qnnp_params.q8dw9.updw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t) compute_dwconv_unipass,
              &context,
              batch_size, output_height);
          break;
        }
        case 25:
        {
          struct q8dwconv_context context = {
            .groups = groups,
            .group_stride = op->group_stride,
            .indirection_buffer = (const uint8_t**) op->indirection_buffer,
            .indirection_buffer_row_stride = kernel_size + (output_width * width_step - 1) * kernel_height,
            .indirection_buffer_col_stride = kernel_height * width_step * sizeof(void*),
            .packed_weights = op->packed_weights,
            .output = op->output,
            .output_height = output_height,
            .output_width = output_width,
            .output_row_stride = output_width * op->output_pixel_stride,
            .output_col_increment = (op->output_pixel_stride - groups) * sizeof(uint8_t),
            .quantization_params = op->conv_quantization_params,
            .multipass_ukernel = qnnp_params.q8dw25.mpdw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t) compute_dwconv_multipass,
              &context,
              batch_size, output_height);
          break;
        }
        default:
          QNNP_UNREACHABLE;
      }
      break;
    }
    case qnnp_ukernel_type_xzp_gemm:
    {
      const size_t batch_size = op->batch_size;
      const size_t groups = op->groups;
      const size_t group_input_channels = op->group_input_channels;
      const size_t group_output_channels = op->group_output_channels;
      const uint32_t mr = qnnp_params.q8conv_xzp.mr;
      const uint32_t nr = qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = qnnp_params.q8conv_xzp.kr;
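      /*
       * Round channel counts up to the ukernel's register-tiling multiples:
       * (x + (r - 1)) & -r rounds x up to a multiple of r, which requires r
       * to be a power of two, as the kr/nr tile sizes here are.
       */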
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      /* compute input row sum */
      const size_t input_size = op->input_height * op->input_width;
      int32_t* a_sum = (int32_t*) op->a_sum;

      struct q8sum_rows_context context = {
        .a = op->input,
        .groups = groups,
        .m = input_size,
        .k = group_input_channels,
        .a_stride = op->input_pixel_stride,
        .multiplier = (int32_t) -op->kernel_zero_point,
        .a_sum = a_sum,
        .a_sum_stride = input_size,
        .ukernel = qnnp_params.q8sum_rows.sum_rows,
      };
      pthreadpool_compute_3d_tiled(
          threadpool,
          (pthreadpool_function_3d_tiled_t) compute_sum_rows,
          &context,
          groups, batch_size, input_size,
          1, 1, qnnp_params.q8sum_rows.m);

      struct q8gemm_xzp_context q8gemm_xzp_context = {
        .k = group_input_channels,
        .k_stride = k_stride,
        .n = group_output_channels,
        .n_stride = n_stride,
        .a = op->input,
        .a_stride = op->input_pixel_stride,
        .packed_w = op->packed_weights,
        .c = op->output,
        .c_stride = op->output_pixel_stride,
        .a_sum = a_sum,
        .groups = op->groups,
        .batch_size = batch_size,
        .a_sum_stride = input_size,
        .requantization_params = op->requantization_params,
        .ukernel = qnnp_params.q8conv_xzp.gemm,
      };
      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t) compute_q8gemm_xzp,
          &q8gemm_xzp_context,
          groups, batch_size * input_size, input_size, group_output_channels,
          1, input_size, mr, nr);
      break;
    }
    case qnnp_ukernel_type_gemm:
    {
      const size_t batch_size = op->batch_size;
      const size_t groups = op->groups;
      const size_t group_input_channels = op->group_input_channels;
      const size_t group_output_channels = op->group_output_channels;
      const uint32_t mr = qnnp_params.q8conv.mr;
      const uint32_t nr = qnnp_params.q8conv.nr;
      const uint32_t kr = qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      const size_t output_size = op->output_height * op->output_width;
      struct q8gemm_context q8gemm_context = {
        .k = group_input_channels,
        .k_stride = k_stride,
        .n = group_output_channels,
        .n_stride = n_stride,
        .a = op->input,
        .a_stride = op->input_pixel_stride,
        .packed_w = op->packed_weights,
        .c = op->output,
        .c_stride = op->output_pixel_stride,
        .quantization_params = op->conv_quantization_params,
        .ukernel = qnnp_params.q8conv.gemm,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t) compute_q8gemm,
          &q8gemm_context,
          groups, batch_size * output_size, output_size, group_output_channels,
          1, output_size, mr, nr);
      break;
    }
    case qnnp_ukernel_type_conv:
    {
      const size_t batch_size = op->batch_size;
      const size_t groups = op->groups;
      const size_t group_input_channels = op->group_input_channels;
      const size_t group_output_channels = op->group_output_channels;
      const uint32_t mr = qnnp_params.q8conv.mr;
      const uint32_t nr = qnnp_params.q8conv.nr;
      const uint32_t kr = qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      const size_t output_size = op->output_height * op->output_width;
      const size_t kernel_size = op->kernel_height * op->kernel_width;
      const size_t m_stride = round_up(output_size, mr);
      struct q8conv_context q8conv_context = {
        .bs = batch_size,
        .ks = kernel_size,
        .kc = group_input_channels,
        .kc_stride = k_stride * kernel_size,
        .m = output_size,
        .m_stride = m_stride,
        .n = group_output_channels,
        .n_stride = n_stride,
        .indirect_a = (const uint8_t**) op->indirection_buffer,
        .packed_w = op->packed_weights,
        .c = op->output,
        .c_stride = op->output_pixel_stride,
        .quantization_params = op->conv_quantization_params,
        .ukernel = qnnp_params.q8conv.conv,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t) compute_q8conv,
          &q8conv_context,
          groups, batch_size, output_size, group_output_channels,
          1, 1, mr, nr);
      break;
    }
    case qnnp_ukernel_type_average_pooling:
    {
      const uint32_t kr = qnnp_params.q8avgpool.kr;
      const uint32_t mr = qnnp_params.q8avgpool.mr;
      const uint32_t qr = qnnp_params.q8avgpool.qr;
      const size_t channels = op->channels;
      const size_t output_width = op->output_width;
      const size_t output_height = op->output_height;
      const size_t pooling_height = op->kernel_height;
      const size_t pooling_width = op->kernel_width;
      const size_t pooling_size = pooling_height * pooling_width;

      const size_t width_step = min(op->stride_width, pooling_width);
      const size_t indirect_input_height_stride = (pooling_size + (output_width * width_step - 1) * pooling_height) * sizeof(void*);
      const size_t output_height_stride = output_width * op->output_pixel_stride;

      size_t multipass_adjustment = 0;
      if (channels >= kr && pooling_size > mr) {
        multipass_adjustment = round_up(pooling_size - mr, qr) + mr - qr;
      }
      struct average_pooling_context context = {
        .indirect_input = op->indirection_buffer,
        .indirect_input_batch_stride = output_height * indirect_input_height_stride,
        .indirect_input_height_stride = indirect_input_height_stride,
        .output = op->output,
        .output_batch_stride = output_height * output_height_stride,
        .output_height_stride = output_height_stride,
        .output_width = output_width,
        .pooling_size = pooling_size,
        .channels = channels,
        .packed_channels = (channels + (kr - 1)) & -kr,
        .zero = op->zero_pointer,
        .input_increment = (pooling_height * width_step - multipass_adjustment) * sizeof(void*),
        .output_increment = (op->output_pixel_stride - channels) * sizeof(uint8_t),
        .quantization_params = op->avgpool_quantization_params,
      };

      pthreadpool_function_2d_t compute_function = NULL;
      if (channels < kr) {
        compute_function = (pthreadpool_function_2d_t) compute_average_pooling_unipass;
        context.unipass_ukernel = qnnp_params.q8avgpool.ltkr;
      } else {
        if (pooling_size <= mr) {
          compute_function = (pthreadpool_function_2d_t) compute_average_pooling_unipass;
          context.unipass_ukernel = qnnp_params.q8avgpool.gekr_lemr;
        } else {
          compute_function = (pthreadpool_function_2d_t) compute_average_pooling_multipass;
          context.multipass_ukernel = qnnp_params.q8avgpool.gekr_gtmr;
        }
      }

      pthreadpool_compute_2d(threadpool, compute_function, &context, op->batch_size, output_height);
      break;
    }
    case qnnp_ukernel_type_max_pooling:
    {
      const uint32_t kr = qnnp_params.u8maxpool.kr;
      const uint32_t mr = qnnp_params.u8maxpool.mr;
      const uint32_t qr = qnnp_params.u8maxpool.qr;
      const size_t channels = op->channels;
      const size_t output_width = op->output_width;
      const size_t output_height = op->output_height;
      const size_t pooling_height = op->kernel_height;
      const size_t pooling_width = op->kernel_width;
      const size_t pooling_size = pooling_height * pooling_width;

      const size_t width_step = op->dilation_width > 1 ? pooling_width : min(op->stride_width, pooling_width);
      const size_t indirect_input_height_stride = (pooling_size + (output_width * width_step - 1) * pooling_height) * sizeof(void*);
      const size_t output_height_stride = output_width * op->output_pixel_stride;

      size_t multipass_adjustment = pooling_size;
      if (channels >= kr) {
        multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
      }
      struct max_pooling_context context = {
        .indirect_input = op->indirection_buffer,
        .indirect_input_batch_stride = output_height * indirect_input_height_stride,
        .indirect_input_height_stride = indirect_input_height_stride,
        .output = op->output,
        .output_batch_stride = output_height * output_height_stride,
        .output_height_stride = output_height_stride,
        .output_width = output_width,
        .pooling_size = pooling_size,
        .channels = channels,
        .input_increment = (pooling_height * width_step - multipass_adjustment) * sizeof(void*),
        .output_increment = (op->output_pixel_stride - channels) * sizeof(uint8_t),
        .params = op->u8_clamping_params,
        .ukernel = channels < kr ? qnnp_params.u8maxpool.ltkr : qnnp_params.u8maxpool.gekr,
      };

      pthreadpool_compute_2d(threadpool,
          (pthreadpool_function_2d_t) compute_max_pooling, &context,
          op->batch_size, output_height);
      break;
    }
    case qnnp_ukernel_type_add:
    {
      const size_t batch_size = op->batch_size;
      const size_t channels = op->channels;
      const size_t a_stride = op->input_pixel_stride;
      const size_t b_stride = op->input2_pixel_stride;
      const size_t y_stride = op->output_pixel_stride;
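      /*
       * If every stride equals the channel count (checked via XOR/OR so one
       * branch covers all three), the rows abut in memory and the whole
       * batch can be processed as one contiguous byte range.
       */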
      if ((((a_stride ^ channels) | (b_stride ^ channels) | (y_stride ^ channels)) == 0) || batch_size == 1) {
        const size_t block_size = 4096;
        struct q8add_contiguous_context add_context = {
          .a = op->input,
          .b = op->input2,
          .y = op->output,
          .quantization_params = op->add_quantization_params,
          .ukernel = qnnp_params.q8vadd,
        };
        pthreadpool_compute_1d_tiled(
            threadpool,
            (pthreadpool_function_1d_tiled_t) compute_q8add_contiguous,
            &add_context,
            batch_size * channels * sizeof(uint8_t), block_size);
      } else {
        struct q8add_strided_context add_context = {
          .a = op->input,
          .a_stride = a_stride * sizeof(uint8_t),
          .b = op->input2,
          .b_stride = b_stride * sizeof(uint8_t),
          .y = op->output,
          .y_stride = y_stride * sizeof(uint8_t),
          .n = channels,
          .quantization_params = op->add_quantization_params,
          .ukernel = qnnp_params.q8vadd,
        };
        pthreadpool_compute_1d_tiled(
            threadpool,
            (pthreadpool_function_1d_tiled_t) compute_q8add_strided,
            &add_context,
            batch_size, 1);
      }
      break;
    }
    case qnnp_ukernel_type_global_average_pooling:
    {
      const uint32_t nr = qnnp_params.q8gavgpool.nr;
      const uint32_t mr = qnnp_params.q8gavgpool.mr;
      const size_t input_pixel_stride = op->input_pixel_stride * sizeof(uint8_t);
      const size_t input_width = op->input_width;
      const size_t channels = op->channels;
      struct global_average_pooling_context context = {
        .input = op->input,
        .zero = op->zero_pointer,
        .input_pixel_stride = input_pixel_stride,
        .input_batch_stride = input_pixel_stride * input_width,
        .input_elements = input_width,
        .channels = channels,
        .packed_channels = (channels + (nr - 1)) & -nr,
        .output = op->output,
        .output_batch_stride = op->output_pixel_stride * sizeof(uint8_t),
        .quantization_params = op->avgpool_quantization_params,
      };
      pthreadpool_function_1d_t compute_function = NULL;
      if (channels < nr) {
        compute_function = (pthreadpool_function_1d_t) compute_global_average_pooling_unipass;
        context.unipass_ukernel = qnnp_params.q8gavgpool.ltnr;
      } else {
        if (input_width <= mr) {
          compute_function = (pthreadpool_function_1d_t) compute_global_average_pooling_unipass;
          context.unipass_ukernel = qnnp_params.q8gavgpool.genr_lemr;
        } else {
          compute_function = (pthreadpool_function_1d_t) compute_global_average_pooling_multipass;
          context.multipass_ukernel = qnnp_params.q8gavgpool.genr_gtmr;
        }
      }

      pthreadpool_compute_1d(threadpool, compute_function, &context, op->batch_size);
      break;
    }
    case qnnp_ukernel_type_lut:
    {
      const size_t batch_size = op->batch_size;
      const size_t channels = op->channels;
      const size_t x_stride = op->input_pixel_stride;
      const size_t y_stride = op->output_pixel_stride;
      if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) || batch_size == 1) {
        const size_t block_size = 1024;
        struct lut_contiguous_context context = {
          .x = op->input,
          .x_stride = x_stride * sizeof(uint8_t),
          .t = op->lookup_table,
          .y = op->output,
          .y_stride = y_stride * sizeof(uint8_t),
          .ukernel = qnnp_params.x8lut,
        };
        pthreadpool_compute_1d_tiled(
            threadpool,
            (pthreadpool_function_1d_tiled_t) compute_lut_contiguous, &context,
            batch_size * channels * sizeof(uint8_t), block_size);
      } else {
        struct lut_strided_context context = {
          .n = channels,
          .x = op->input,
          .x_stride = x_stride * sizeof(uint8_t),
          .t = op->lookup_table,
          .y = op->output,
          .y_stride = y_stride * sizeof(uint8_t),
          .ukernel = qnnp_params.x8lut,
        };
        pthreadpool_compute_1d(
            threadpool,
            (pthreadpool_function_1d_t) compute_lut_strided, &context,
            batch_size);
      }
      break;
    }
    case qnnp_ukernel_type_clamp:
    {
      const size_t batch_size = op->batch_size;
      const size_t channels = op->channels;
      const size_t x_stride = op->input_pixel_stride;
      const size_t y_stride = op->output_pixel_stride;
      if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) || batch_size == 1) {
        const size_t block_size = 4096;
        struct clamp_contiguous_context context = {
          .x = op->input,
          .x_stride = x_stride * sizeof(uint8_t),
          .y = op->output,
          .y_stride = y_stride * sizeof(uint8_t),
          .ukernel = qnnp_params.u8clamp,
          .params = op->u8_clamping_params,
        };
        pthreadpool_compute_1d_tiled(
            threadpool,
            (pthreadpool_function_1d_tiled_t) compute_clamp_contiguous, &context,
            batch_size * channels * sizeof(uint8_t), block_size);
      } else {
        struct clamp_strided_context context = {
          .n = channels,
          .x = op->input,
          .x_stride = x_stride * sizeof(uint8_t),
          .y = op->output,
          .y_stride = y_stride * sizeof(uint8_t),
          .ukernel = qnnp_params.u8clamp,
          .params = op->u8_clamping_params,
        };
        pthreadpool_compute_1d(
            threadpool,
            (pthreadpool_function_1d_t) compute_clamp_strided, &context,
            batch_size);
      }
      break;
    }
    case qnnp_ukernel_type_softargmax:
    {
      struct u8softargmax_context context = {
        .n = op->channels,
        .x = op->input,
        .x_stride = op->input_pixel_stride * sizeof(uint8_t),
        .t = op->lookup_table,
        .y = op->output,
        .y_stride = op->output_pixel_stride * sizeof(uint8_t),
        .rmax_ukernel = qnnp_params.u8rmax,
        .lut_norm_ukernel = qnnp_params.u8lut32norm,
      };
      pthreadpool_compute_1d(
          threadpool,
          (pthreadpool_function_1d_t) compute_u8softargmax, &context,
          op->batch_size);
      break;
    }
    case qnnp_ukernel_type_channel_shuffle:
    {
      const size_t groups = op->groups;
      struct channel_shuffle_context channel_shuffle_context = {
        .x = op->input,
        .x_stride = op->input_pixel_stride * sizeof(uint8_t),
        .y = op->output,
        .y_stride = op->output_pixel_stride * sizeof(uint8_t),
        .n = op->group_channels * sizeof(uint8_t),
        .m = groups,
      };
      pthreadpool_function_1d_t compute_function = NULL;
      switch (groups) {
        case 0:
        case 1:
          QNNP_UNREACHABLE;
        case 2:
          compute_function = (pthreadpool_function_1d_t) compute_channel_shuffle_fixed;
          channel_shuffle_context.fixed_ukernel = qnnp_params.x8zip.x2;
          break;
        case 3:
          compute_function = (pthreadpool_function_1d_t) compute_channel_shuffle_fixed;
          channel_shuffle_context.fixed_ukernel = qnnp_params.x8zip.x3;
          break;
        case 4:
          compute_function = (pthreadpool_function_1d_t) compute_channel_shuffle_fixed;
          channel_shuffle_context.fixed_ukernel = qnnp_params.x8zip.x4;
          break;
        default:
          compute_function = (pthreadpool_function_1d_t) compute_channel_shuffle_variable;
          channel_shuffle_context.variable_ukernel = qnnp_params.x8zip.xm;
          break;
      }
      pthreadpool_compute_1d(
          threadpool,
          compute_function,
          &channel_shuffle_context,
          op->batch_size);
      break;
    }
    default:
      QNNP_UNREACHABLE;
  }
  return qnnp_status_success;
}