// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/compute.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/microkernel-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
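
// Tiled transpose with a fixed (compile-time) element size.
//
// The pthreadpool runner walks the innermost two dimensions in tiles; each
// call below hands the microkernel one tile (tile_i x tile_j for the 2D case,
// tile_j x tile_k for 3D, and so on), located via the per-dimension strides
// captured in the context. As an illustrative example (numbers not taken from
// this file): transposing a 10x7 matrix with 4x4 tiles issues
// ceil(10/4) * ceil(7/4) = 6 tile invocations.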
void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t ld_input = context->input_stride[1];
  const size_t ld_output = context->output_stride[0];
  context->const_size_ukernel(
      (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1]),
      (void*) ((uintptr_t) context->y + j * context->output_stride[1] + i * context->output_stride[0]),
      ld_input,
      ld_output,
      tile_i,
      tile_j,
      &context->params);
}

void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x +
      i * context->input_stride[0] + j * context->input_stride[1] + k * context->input_stride[2]);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
      k * context->output_stride[2]);

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_j,
      tile_k,
      &context->params);
}

void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3]);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
      k * context->output_stride[2] + l * context->output_stride[3]);

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_k,
      tile_l,
      &context->params);
}

void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3] + m * context->input_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
      k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4]);

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_l,
      tile_m,
      &context->params);
}

void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3] +
      m * context->input_stride[4] + n * context->input_stride[5]);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
      k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] +
      n * context->output_stride[5]);

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_m,
      tile_n,
      &context->params);
}
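
// Variable-element-size transpose: same tiling as the transposec variants
// above, but the element size is only known at run time (here recovered from
// the innermost output stride), so one generic microkernel serves all widths.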
void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t element_size = context->output_stride[1];
  const size_t ld_input = context->input_stride[1];
  const size_t ld_output = context->output_stride[0];
  const void* x = (const void*) ((uintptr_t) context->x +
      i * context->input_stride[0] + j * context->input_stride[1]);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[0],
      context->output_stride[1],
      element_size,
      tile_i,
      tile_j);
}

void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t element_size = context->output_stride[2];
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2]);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
      k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[1],
      context->output_stride[2],
      element_size,
      tile_j,
      tile_k);
}

void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t element_size = context->output_stride[3];
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3]);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[3] * l + i * context->output_stride[0] +
      j * context->output_stride[1] + k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[2],
      context->output_stride[3],
      element_size,
      tile_k,
      tile_l);
}

void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t element_size = context->output_stride[4];
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3] + m * context->input_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[4] * m + i * context->output_stride[0] +
      j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[3],
      context->output_stride[4],
      element_size,
      tile_l,
      tile_m);
}

void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t element_size = context->output_stride[5];
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
      k * context->input_stride[2] + l * context->input_stride[3] +
      m * context->input_stride[4] + n * context->input_stride[5]);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[5] * n + i * context->output_stride[0] +
      j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] +
      m * context->output_stride[4]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[4],
      context->output_stride[5],
      element_size,
      tile_m,
      tile_n);
}
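
// GEMM compute functions: each invocation multiplies one mr_block_size x
// nr_block_size block of the output matrix. The 2D tile decomposition itself
// is performed by pthreadpool; a sketch of the equivalent serial loop (with
// hypothetical M/N/MR/NR values, not taken from this file) is:
//
//   for (size_t mr = 0; mr < M; mr += MR) {
//     for (size_t nr = 0; nr < N; nr += NR) {
//       xnn_compute_gemm(&context, mr, nr, min(MR, M - mr), min(NR, N - nr));
//     }
//   }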
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}
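
// Sparse matrix-dense matrix multiplication (SpMM): the weights arrive in a
// compressed representation prepared at setup time (nonzero values,
// per-output-channel nonzero counts, and input increments that walk the
// input rows matching each nonzero).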
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}
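
// Indirect GEMM (IGEMM) compute functions: instead of materializing an
// im2col buffer, convolution inputs are reached through an indirection
// buffer of pointers (context->indirect_a, ks pointers per output pixel).
// a_offset rebases those pointers onto the current batch/group, and
// context->zero stands in for padding pixels outside the input.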
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
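
// Subconvolution compute functions (used to implement deconvolution): the
// output is partitioned into subkernels, each with its own slice geometry.
// Since the parallelized tile grid is sized for the largest subkernel,
// invocations whose slice coordinates fall outside the current subkernel's
// extent return early.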
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->groups, context->output_width,
      indirect_input, context->packed_weights, output,
      context->indirect_input_width_stride, context->output_increment,
      input_offset, context->zero,
      &context->params);
}

void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  context->chw_ukernel(
      context->input_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      context->zero,
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_padding_top,
      &context->params);
}
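
// Pooling compute functions: "unipass" microkernels consume the whole
// pooling window in a single sweep, while "multipass" microkernels
// accumulate through a stack-allocated scratch buffer (XNN_SIMD_ALLOCA)
// when the window is larger than the microkernel's primary tile.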
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
      output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
      batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output, index,
      context->input_increment, context->output_increment);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
      output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
      batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
      output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
      (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
      (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
      (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
      channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}
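
// Bilinear resize: every output pixel is interpolated from 4 input pixels
// reached through the indirection buffer (hence the pixel_start * 4 offset)
// and blended with packed per-pixel weights; log2_wsize is the log2 of the
// packed weight size per output pixel.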
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output = (void*) ((uintptr_t) context->output +
      pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  void* output = (void*) ((uintptr_t) context->output +
      channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
      context->output_pixels,
      channel_range,
      context->indirect_input,
      input_offset,
      context->packed_weights,
      output,
      context->input_channel_stride);
}

void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
      i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];
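
  // The indices and paddings are unsigned, so i - i_padding wraps around to
  // a huge value whenever i falls in the leading padding region; a single
  // comparison against i_size therefore detects both the pre- and
  // post-padding sides.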
  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
                l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
        1 /* rows */,
        context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
        input, 0 /* input stride */, output, 0 /* output stride */,
        context->padding_value);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
  }
}

void xnn_compute_slice_1d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{
  const void* input = (const void*) ((uintptr_t) context->input + i * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output + i * context->output_stride[0]);

  context->ukernel(context->contiguous_size, input, output, NULL);
}

void xnn_compute_slice_2d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{
  const void* input =
      (const void*) ((uintptr_t) context->input +
                     i * context->input_stride[1] +
                     j * context->input_stride[0]);
  void* output =
      (void*) ((uintptr_t) context->output + i * context->output_stride[1] + j * context->output_stride[0]);

  context->ukernel(context->contiguous_size, input, output, NULL);
}

void xnn_compute_slice_3d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{
  const void* input =
      (const void*) ((uintptr_t) context->input +
                     i * context->input_stride[2] +
                     j * context->input_stride[1] +
                     k * context->input_stride[0]);
  void* output =
      (void*) ((uintptr_t) context->output + i * context->output_stride[2] +
               j * context->output_stride[1] + k * context->output_stride[0]);

  context->ukernel(context->contiguous_size, input, output, NULL);
}

void xnn_compute_slice_4d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{
  const void* input =
      (const void*) ((uintptr_t) context->input +
                     i * context->input_stride[3] +
                     j * context->input_stride[2] +
                     k * context->input_stride[1] +
                     l * context->input_stride[0]);
  void* output =
      (void*) ((uintptr_t) context->output + i * context->output_stride[3] +
               j * context->output_stride[2] + k * context->output_stride[1] + l * context->output_stride[0]);

  context->ukernel(context->contiguous_size, input, output, NULL);
}

void xnn_compute_slice_5d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input =
      (const void*) ((uintptr_t) context->input +
                     i * context->input_stride[4] +
                     j * context->input_stride[3] +
                     k * context->input_stride[2] +
                     l * context->input_stride[1] +
                     m * context->input_stride[0]);
  void* output =
      (void*) ((uintptr_t) context->output + i * context->output_stride[4] +
               j * context->output_stride[3] + k * context->output_stride[2] +
               l * context->output_stride[1] + m * context->output_stride[0]);

  context->ukernel(context->contiguous_size, input, output, NULL);
}
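
// Elementwise binary compute functions: strides live in fixed 5-element
// arrays with index 4 as the innermost (contiguous) dimension, and
// broadcasting along an axis is expressed as a zero stride for that axis,
// so the same offset arithmetic covers all broadcast patterns.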
void xnn_compute_elementwise_binary_1d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_2d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{
  const void* a = (const void*) ((uintptr_t) context->a +
      i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
      i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
      i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_4d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{
  const void* a = (const void*) ((uintptr_t) context->a +
      i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
      i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
      i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* a = (const void*) ((uintptr_t) context->a +
      i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
      i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
      i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, y, context->t);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, y, context->t);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
  do {
    context->ukernel(context->n, x, y, &context->params);
    x = (const void*) ((uintptr_t) x + x_stride);
    y = (void*) ((uintptr_t) y + y_stride);
  } while (--batch_range != 0);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const uint32_t log2_xsize = context->log2_xsize;
  const uint32_t log2_ysize = context->log2_ysize;
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + ((offset >> log2_xsize) << log2_ysize));
  context->ukernel(size, x, y, &context->params);
}
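
// Quantized softmax: the reduce-max pass finds x_max, and x_max ^ 255
// (equal to 255 - x_max for uint8 inputs) shifts the base of the
// exponentiation lookup table so that the largest input maps to the top
// table entry and the accumulated sum cannot overflow.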
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  union {
    float as_float;
    uint16_t as_half;
  } x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  union {
    float as_float;
    uint16_t as_half;
  } y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params);

  // Third pass: scale y
  union {
    float as_float;
    uint16_t as_half;
  } y_scale;
  context->compute_reciprocal(&y_sum, &y_scale);
  context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params);
}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}
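
// Heterogeneous-microarchitecture (HMP) variants: identical to the
// corresponding functions above, except that pthreadpool supplies the
// microarchitecture index of the calling core, which selects the matching
// specialization from the microkernel function table.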
#if XNN_MAX_UARCH_TYPES > 1
  void xnn_compute_hmp_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t k_scaled = context->k_scaled;
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

  void xnn_compute_hmp_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->fused_params);
  }

  void xnn_compute_hmp_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_hmp_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_batch_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset,
        context->zero,
        &context->params);
  }
#endif  // XNN_MAX_UARCH_TYPES > 1
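
// Minimal usage sketch for the operator runner (hypothetical driver code;
// `op` stands for any operator already created and set up through the
// public XNNPACK API):
//
//   pthreadpool_t threadpool = pthreadpool_create(0 /* one thread per core */);
//   enum xnn_status status = xnn_run_operator(op, threadpool);
//   assert(status == xnn_status_success);
//   pthreadpool_destroy(threadpool);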
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  return xnn_run_operator_with_index(op, 0, 0, threadpool);
}

enum xnn_status xnn_run_operator_with_index(
    xnn_operator_t op,
    size_t opdata_index,
    size_t operator_object_index,
    pthreadpool_t threadpool)
{
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      xnn_log_debug("running operator %zu:%zu (%s %s)", opdata_index,
                    operator_object_index,
                    xnn_operator_type_to_string(op->type),
                    xnn_microkernel_type_to_string(op->ukernel.type));
      break;
    case xnn_run_state_skip:
      xnn_log_debug("skip running operator %zu:%zu (%s %s)", opdata_index,
                    operator_object_index,
                    xnn_operator_type_to_string(op->type),
                    xnn_microkernel_type_to_string(op->ukernel.type));
      return xnn_status_success;
  }

  uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;
  if (op->flags & XNN_FLAG_YIELD_WORKERS) {
    flags |= PTHREADPOOL_FLAG_YIELD_WORKERS;
  }
  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          flags);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool,
          op->compute.task_3d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool,
          op->compute.task_4d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_5d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool,
          op->compute.task_5d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          flags);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}