1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // Copyright 2019 Google LLC |
5 | // |
6 | // This source code is licensed under the BSD-style license found in the |
7 | // LICENSE file in the root directory of this source tree. |
8 | |
9 | #include <assert.h> |
10 | #include <stddef.h> |
11 | #include <stdint.h> |
12 | #include <string.h> |
13 | |
14 | #include <xnnpack.h> |
15 | #include <xnnpack/allocator.h> |
16 | #include <xnnpack/operator.h> |
17 | #include <xnnpack/log.h> |
18 | #include <xnnpack/common.h> |
19 | #include <xnnpack/math.h> |
20 | #include <xnnpack/params.h> |
21 | #include <xnnpack/compute.h> |
22 | |
23 | |
24 | void xnn_compute_transposec_2d( |
25 | const struct transpose_context* context, |
26 | size_t i, |
27 | size_t j, |
28 | size_t tile_i, |
29 | size_t tile_j) |
30 | { |
31 | const size_t log2_element_size = context->log2_element_size; |
32 | |
33 | context->const_size_ukernel( |
34 | (const void*) ((uintptr_t) context->x + (i << log2_element_size) + j * context->input_stride[1]), |
35 | (void*) ((uintptr_t) context->y + (j << log2_element_size) + i * context->output_stride[0]), |
36 | context->input_stride[1], |
37 | context->output_stride[0], |
38 | tile_i, |
39 | tile_j); |
40 | } |
41 | |
42 | void xnn_compute_transposec_3d( |
43 | const struct transpose_context* context, |
44 | size_t i, |
45 | size_t j, |
46 | size_t k, |
47 | size_t tile_j, |
48 | size_t tile_k) |
49 | { |
50 | const size_t log2_element_size = context->log2_element_size; |
51 | const size_t ld_input = context->input_stride[2]; |
52 | const size_t ld_output = context->output_stride[1]; |
53 | const void* x = (const void*) ((uintptr_t) context->x + |
54 | (i * context->input_stride[0] + j * context->input_stride[1]) + k * ld_input); |
55 | void* y = (void*) ((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
56 | (k << log2_element_size)); |
57 | |
58 | context->const_size_ukernel( |
59 | x, |
60 | y, |
61 | ld_input, |
62 | ld_output, |
63 | tile_j, |
64 | tile_k); |
65 | } |
66 | |
// Computes one (tile_k x tile_l) tile of a 4D constant-element-size transpose.
// i/j select outer slices; k/l index the two dimensions the microkernel tiles.
void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t log2_element_size = context->log2_element_size;
  // Row strides handed to the microkernel: input rows advance by
  // input_stride[3], output rows by output_stride[2].
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  // Byte offsets: each index scales by its stride on the input; on the output
  // the innermost index l is contiguous (scaled by element size).
  const void* x = (const void*) ((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * ld_input);
  void* y = (void*) ((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     k * context->output_stride[2] + (l << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_k,
      tile_l);
}
92 | |
// Computes one (tile_l x tile_m) tile of a 5D constant-element-size transpose.
void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t log2_element_size = context->log2_element_size;
  // Row strides for the microkernel: input rows advance by input_stride[4],
  // output rows by output_stride[3].
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  // Each index scales by its byte stride on the input; on the output the
  // innermost index m is contiguous (scaled by element size).
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                    k * context->output_stride[2] + l * context->output_stride[3] + (m << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_l,
      tile_m);
}
119 | |
// Computes one (tile_m x tile_n) tile of a 6D constant-element-size transpose.
void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t log2_element_size = context->log2_element_size;
  // Row strides for the microkernel: input rows advance by input_stride[5],
  // output rows by output_stride[4].
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  // Each index scales by its byte stride on the input; on the output the
  // innermost index n is contiguous (scaled by element size).
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * context->input_stride[2] + l * context->input_stride[3] +
                                m * context->input_stride[4] + n * ld_input);
  void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                    k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] +
                    (n << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_m,
      tile_n);
}
149 | |
150 | void xnn_compute_transposev_2d( |
151 | const struct transpose_context* context, |
152 | size_t i, |
153 | size_t j, |
154 | size_t tile_i, |
155 | size_t tile_j) |
156 | { |
157 | const size_t element_size = context->element_size; |
158 | const size_t ld_input = context->input_stride[1]; |
159 | const size_t ld_output = context->output_stride[0]; |
160 | const void* x = (const void*) ((uintptr_t) context->x + |
161 | i * context->input_stride[0] + j * ld_input); |
162 | void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]); |
163 | |
164 | context->variable_size_ukernel( |
165 | x, |
166 | y, |
167 | ld_input, |
168 | ld_output, |
169 | context->input_stride[0], |
170 | context->output_stride[1], |
171 | element_size, |
172 | tile_i, |
173 | tile_j); |
174 | } |
175 | |
// Computes one (tile_j x tile_k) tile of a 3D variable-element-size transpose.
void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t element_size = context->element_size;
  // Row strides for the microkernel: input rows advance by input_stride[2],
  // output rows by output_stride[1].
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  // Every index scales by its byte stride on both buffers; unlike the
  // const-size variants, the innermost output stride is passed explicitly.
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * ld_input);
  void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                    k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[1],
      context->output_stride[2],
      element_size,
      tile_j,
      tile_k);
}
203 | |
// Computes one (tile_k x tile_l) tile of a 4D variable-element-size transpose.
void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t element_size = context->element_size;
  // Row strides for the microkernel: input rows advance by input_stride[3],
  // output rows by output_stride[2].
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  // Every index scales by its byte stride on both buffers.
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * context->input_stride[2] + l * ld_input);
  void* y = (void*)((uintptr_t)context->y + context->output_stride[3] * l + i * context->output_stride[0] +
                    j * context->output_stride[1] + k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[2],
      context->output_stride[3],
      element_size,
      tile_k,
      tile_l);
}
232 | |
// Computes one (tile_l x tile_m) tile of a 5D variable-element-size transpose.
void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t element_size = context->element_size;
  // Row strides for the microkernel: input rows advance by input_stride[4],
  // output rows by output_stride[3].
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  // Every index scales by its byte stride on both buffers.
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*)((uintptr_t)context->y + context->output_stride[4] * m + i * context->output_stride[0] +
                    j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[3],
      context->output_stride[4],
      element_size,
      tile_l,
      tile_m);
}
262 | |
// Computes one (tile_m x tile_n) tile of a 6D variable-element-size transpose.
void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t element_size = context->element_size;
  // Row strides for the microkernel: input rows advance by input_stride[5],
  // output rows by output_stride[4].
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  // Every index scales by its byte stride on both buffers.
  const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                k * context->input_stride[2] + l * context->input_stride[3] +
                                m * context->input_stride[4] + n * ld_input);
  void* y = (void*)((uintptr_t)context->y + context->output_stride[5] * n + i * context->output_stride[0] +
                    j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] +
                    m * context->output_stride[4]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[4],
      context->output_stride[5],
      element_size,
      tile_m,
      tile_n);
}
295 | |
296 | void xnn_compute_grouped_gemm( |
297 | const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
298 | size_t group_index, |
299 | size_t mr_block_start, |
300 | size_t nr_block_start, |
301 | size_t mr_block_size, |
302 | size_t nr_block_size) |
303 | { |
304 | const size_t k_scaled = context->k_scaled; |
305 | const size_t a_stride = context->a_stride; |
306 | const size_t cm_stride = context->cm_stride; |
307 | |
308 | context->ukernel.function[XNN_UARCH_DEFAULT]( |
309 | mr_block_size, |
310 | nr_block_size, |
311 | k_scaled, |
312 | (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled), |
313 | a_stride, |
314 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride), |
315 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride), |
316 | cm_stride, |
317 | context->cn_stride, |
318 | &context->params); |
319 | } |
320 | |
321 | void xnn_compute_gemm( |
322 | const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
323 | size_t mr_block_start, |
324 | size_t nr_block_start, |
325 | size_t mr_block_size, |
326 | size_t nr_block_size) |
327 | { |
328 | const size_t a_stride = context->a_stride; |
329 | const size_t cm_stride = context->cm_stride; |
330 | |
331 | context->ukernel.function[XNN_UARCH_DEFAULT]( |
332 | mr_block_size, |
333 | nr_block_size, |
334 | context->k_scaled, |
335 | (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), |
336 | a_stride, |
337 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), |
338 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
339 | cm_stride, |
340 | context->cn_stride, |
341 | context->fused_params); |
342 | } |
343 | |
344 | void xnn_compute_spmm( |
345 | const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)], |
346 | size_t batch_index, |
347 | size_t mr_block_start, |
348 | size_t mr_block_size) |
349 | { |
350 | context->ukernel( |
351 | mr_block_size, |
352 | context->n, |
353 | (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start), |
354 | context->nonzero_weights, |
355 | context->input_increments, |
356 | context->output_channel_nonzeros, |
357 | (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start), |
358 | context->scaled_m, |
359 | &context->params); |
360 | } |
361 | |
// Runs the indirect GEMM (IGEMM) microkernel on one (mr x nr) block for a
// specific (batch, group) pair.
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Indirection buffer holds ks input pointers per output row.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Batch/group selection of A is done via this byte offset that the
      // microkernel adds to each indirection pointer.
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
388 | |
// Runs the indirect GEMM (IGEMM) microkernel on one (mr x nr) block of the
// given group (single-batch variant of xnn_compute_grouped_batch_igemm).
void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Indirection buffer holds ks input pointers per output row.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Group selection of A is a byte offset added to each indirection pointer.
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}
414 | |
// Runs the indirect GEMM (IGEMM) microkernel on one (mr x nr) block of the
// given batch element (ungrouped variant of xnn_compute_grouped_batch_igemm).
void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Indirection buffer holds ks input pointers per output row.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Batch selection of A is a byte offset added to each indirection pointer.
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
440 | |
441 | void xnn_compute_igemm( |
442 | const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
443 | size_t mr_block_start, |
444 | size_t nr_block_start, |
445 | size_t mr_block_size, |
446 | size_t nr_block_size) |
447 | { |
448 | const size_t ks = context->ks; |
449 | const size_t cm_stride = context->cm_stride; |
450 | |
451 | context->ukernel.function[XNN_UARCH_DEFAULT]( |
452 | mr_block_size, |
453 | nr_block_size, |
454 | context->kc, |
455 | context->ks_scaled, |
456 | (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), |
457 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), |
458 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
459 | cm_stride, |
460 | context->cn_stride, |
461 | context->a_offset, |
462 | context->zero, |
463 | &context->params); |
464 | } |
465 | |
// Runs the GEMM microkernel for one subkernel of a grouped 2D subconvolution
// (deconvolution decomposed into dense GEMMs). Slices outside the subkernel's
// output region are skipped.
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid covers the largest subkernel; smaller subkernels
  // bail out on out-of-range slices.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the horizontal extent to what remains of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}
503 | |
504 | void xnn_compute_subgemm2d( |
505 | const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
506 | size_t batch_index, |
507 | size_t subkernel_index, |
508 | size_t slice_y, |
509 | size_t slice_x_start, |
510 | size_t nc_block_start, |
511 | size_t slice_x_max, |
512 | size_t nc_block_size) |
513 | { |
514 | const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index]; |
515 | |
516 | if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) { |
517 | return; |
518 | } |
519 | |
520 | const size_t slice_width = subconvolution_params->slice_width; |
521 | if XNN_UNLIKELY(slice_x_start >= slice_width) { |
522 | return; |
523 | } |
524 | const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); |
525 | |
526 | const size_t ax_stride = context->ax_stride; |
527 | const size_t cx_stride = context->cx_stride; |
528 | context->ukernel.function[XNN_UARCH_DEFAULT]( |
529 | slice_x_size, |
530 | nc_block_size, |
531 | context->kc, |
532 | (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride), |
533 | ax_stride, |
534 | (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride), |
535 | (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)), |
536 | cx_stride, |
537 | context->cn_stride, |
538 | &context->params); |
539 | } |
540 | |
// Runs the IGEMM microkernel for one subkernel of a grouped 2D subconvolution
// (indirect variant: input rows come from the subkernel's indirection buffer).
void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid covers the largest subkernel; smaller subkernels
  // bail out on out-of-range slices.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the horizontal extent to what remains of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      // Batch/group selection of the input is a byte offset the microkernel
      // adds to each indirection pointer.
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
579 | |
// Runs the IGEMM microkernel for one subkernel of an ungrouped 2D
// subconvolution (indirect variant of xnn_compute_subgemm2d).
void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid covers the largest subkernel; smaller subkernels
  // bail out on out-of-range slices.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the horizontal extent to what remains of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      // Batch selection of the input is a byte offset the microkernel adds to
      // each indirection pointer.
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
617 | |
// Runs the HWC-input/CHW-output convolution microkernel on a horizontal band
// of output rows [output_y_start, output_y_start + output_y_slice) for one
// batch element.
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,  // exclusive end of the output-row band
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}
639 | |
640 | void xnn_compute_dwconv_unipass( |
641 | const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], |
642 | size_t batch_index, |
643 | size_t output_y) |
644 | { |
645 | const void** indirect_input = |
646 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
647 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
648 | void* output = (void*) ((uintptr_t) context->output + |
649 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
650 | |
651 | context->unipass_ukernel( |
652 | context->groups, context->output_width, |
653 | indirect_input, context->packed_weights, output, |
654 | context->indirect_input_width_stride, context->output_increment, |
655 | input_offset, context->zero, |
656 | &context->params); |
657 | } |
658 | |
659 | void xnn_compute_dwconv2d_chw( |
660 | const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)], |
661 | size_t batch_index, |
662 | size_t channel) |
663 | { |
664 | context->chw_ukernel( |
665 | context->input_height, |
666 | context->input_width, |
667 | (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride), |
668 | (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride), |
669 | context->zero, |
670 | (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride), |
671 | context->input_padding_top, |
672 | &context->params); |
673 | } |
674 | |
675 | void xnn_compute_argmax_pooling_unipass( |
676 | const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
677 | size_t batch_index, |
678 | size_t output_y) |
679 | { |
680 | const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + |
681 | output_y * context->indirect_input_height_stride); |
682 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
683 | void* output = (void*) ((uintptr_t) context->output + |
684 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
685 | uint32_t* index = (uint32_t*) ((uintptr_t) context->index + |
686 | batch_index * context->index_batch_stride + output_y * context->index_height_stride); |
687 | |
688 | context->unipass_ukernel( |
689 | context->output_width, context->pooling_size, context->channels, |
690 | indirect_input, input_offset, output, index, |
691 | context->input_increment, context->output_increment); |
692 | } |
693 | |
// Runs the multi-pass argmax-pooling microkernel on one output row of one
// batch element, using stack scratch buffers for intermediate accumulation.
void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  // Batch selection is a byte offset added to the indirection pointers.
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  // SIMD-aligned scratch for running maxima (float) and their indices
  // (uint32_t), one lane per channel plus over-read padding.
  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
    context->input_increment, context->output_increment);
}
715 | |
716 | void xnn_compute_max_pooling( |
717 | const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
718 | size_t batch_index, |
719 | size_t output_y) |
720 | { |
721 | const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + |
722 | output_y * context->indirect_input_height_stride); |
723 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
724 | void* output = (void*) ((uintptr_t) context->output + |
725 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
726 | |
727 | context->ukernel( |
728 | context->output_width, context->pooling_size, context->channels, |
729 | indirect_input, input_offset, output, |
730 | context->input_increment, context->output_increment, |
731 | &context->params); |
732 | } |
733 | |
734 | void xnn_compute_unpooling( |
735 | const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
736 | size_t input_y, |
737 | size_t input_x) |
738 | { |
739 | const void* input = (const void*) ((uintptr_t) context->input + |
740 | input_y * context->input_height_stride + input_x * context->input_width_stride); |
741 | const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index + |
742 | input_y * context->index_height_stride + input_x * context->index_width_stride); |
743 | void** indirect_output = |
744 | (void**) ((uintptr_t) context->indirect_output + |
745 | input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride); |
746 | |
747 | context->ukernel( |
748 | context->pooling_size, |
749 | context->channels, |
750 | context->fill_value, |
751 | input, index, indirect_output); |
752 | } |
753 | |
754 | void xnn_compute_average_pooling_unipass( |
755 | const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
756 | size_t batch_index, |
757 | size_t output_y) |
758 | { |
759 | const void** indirect_input = |
760 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
761 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
762 | void* output = (void*) ((uintptr_t) context->output + |
763 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
764 | |
765 | context->unipass_ukernel( |
766 | context->output_width, context->pooling_size, context->channels, |
767 | indirect_input, input_offset, context->zero, output, |
768 | context->input_increment, context->output_increment, |
769 | &context->params); |
770 | } |
771 | |
// Runs the multi-pass average-pooling microkernel on one output row of one
// batch element, accumulating into a stack scratch buffer.
void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  // Batch selection is a byte offset added to the indirection pointers.
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  // SIMD-aligned scratch, sized for int32_t accumulators per channel; the
  // XNN_EXTRA_BYTES term is scaled to accumulator width for over-read padding.
  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}
792 | |
793 | void xnn_compute_pixelwise_average_pooling_unipass( |
794 | const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
795 | size_t batch_index, |
796 | size_t output_y) |
797 | { |
798 | const void** indirect_input = |
799 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
800 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
801 | const void* pixelwise_buffer = |
802 | (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride); |
803 | void* output = (void*) ((uintptr_t) context->output + |
804 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
805 | |
806 | context->unipass_ukernel( |
807 | context->output_width, context->pooling_size, context->channels, |
808 | indirect_input, input_offset, context->zero, pixelwise_buffer, output, |
809 | context->input_increment, context->output_increment, |
810 | &context->params); |
811 | } |
812 | |
813 | void xnn_compute_pixelwise_average_pooling_multipass( |
814 | const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
815 | size_t batch_index, |
816 | size_t output_y) |
817 | { |
818 | const void** indirect_input = |
819 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
820 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
821 | const void* pixelwise_buffer = |
822 | (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride); |
823 | void* output = (void*) ((uintptr_t) context->output + |
824 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
825 | |
826 | void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t)); |
827 | |
828 | context->multipass_ukernel( |
829 | context->output_width, context->pooling_size, context->channels, |
830 | indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output, |
831 | context->input_increment, context->output_increment, |
832 | &context->params); |
833 | } |
834 | |
835 | void xnn_compute_global_average_pooling_nwc_unipass( |
836 | const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], |
837 | size_t batch_index) |
838 | { |
839 | const void* input = |
840 | (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); |
841 | void* output = |
842 | (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); |
843 | |
844 | context->unipass_ukernel( |
845 | context->input_elements, |
846 | context->channels, |
847 | input, |
848 | context->input_pixel_stride, |
849 | context->zero, |
850 | output, |
851 | &context->params); |
852 | } |
853 | |
854 | void xnn_compute_global_average_pooling_nwc_multipass( |
855 | const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], |
856 | size_t batch_index) |
857 | { |
858 | const void* input = |
859 | (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); |
860 | void* output = |
861 | (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); |
862 | |
863 | void* multipass_buffer = |
864 | XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t)); |
865 | |
866 | context->multipass_ukernel( |
867 | context->input_elements, |
868 | context->channels, |
869 | input, |
870 | context->input_pixel_stride, |
871 | context->zero, |
872 | multipass_buffer, |
873 | output, |
874 | &context->params); |
875 | } |
876 | |
877 | void xnn_compute_global_average_pooling_ncw( |
878 | const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)], |
879 | size_t batch_index, |
880 | size_t channels_start, |
881 | size_t channels_slice) |
882 | { |
883 | const void* input = (const void*) ((uintptr_t) context->input + |
884 | channels_start * context->input_channel_stride + batch_index * context->input_batch_stride); |
885 | void* output = (void*) ((uintptr_t) context->output + |
886 | channels_start * context->output_channel_stride + batch_index * context->output_batch_stride); |
887 | |
888 | context->ukernel( |
889 | context->input_elements, |
890 | channels_slice, |
891 | input, |
892 | output, |
893 | &context->params); |
894 | } |
895 | |
896 | void xnn_compute_resize_bilinear( |
897 | const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)], |
898 | size_t batch_index, |
899 | size_t pixel_start, |
900 | size_t pixel_range) |
901 | { |
902 | void* output = |
903 | (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride); |
904 | |
905 | context->ukernel( |
906 | pixel_range, |
907 | context->scaled_channels, |
908 | context->indirect_input + pixel_start * 4, |
909 | context->input_offset + batch_index * context->input_batch_stride, |
910 | (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)), |
911 | output, |
912 | context->output_pixel_stride - context->scaled_channels); |
913 | } |
914 | |
915 | void xnn_compute_resize_bilinear_chw( |
916 | const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)], |
917 | size_t batch_index, |
918 | size_t channel_start, |
919 | size_t channel_range) |
920 | { |
921 | void* output = |
922 | (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride); |
923 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride; |
924 | |
925 | context->ukernel( |
926 | context->output_pixels, |
927 | channel_range, |
928 | context->indirect_input, |
929 | input_offset, |
930 | context->packed_weights, |
931 | output, |
932 | context->input_channel_stride); |
933 | } |
934 | |
935 | void xnn_compute_prelu( |
936 | const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)], |
937 | size_t batch_start, |
938 | size_t batch_range) |
939 | { |
940 | const size_t x_stride = context->x_stride; |
941 | const size_t y_stride = context->y_stride; |
942 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); |
943 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); |
944 | |
945 | context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride); |
946 | } |
947 | |
// Processes one innermost row of a padded 6D tensor: if the (i,j,k,l,m)
// coordinate lies inside the (pre-padded) input region, copy the row with
// edge padding via pad_ukernel; otherwise fill the whole output row with the
// padding value. Note the stride/padding/size arrays are indexed in reverse:
// index 5 corresponds to the outermost dimension i, index 0 to the innermost.
void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
    i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  // Each `x - x_padding < x_size` check exploits unsigned wraparound: when
  // x < x_padding the subtraction wraps to a huge value, so one comparison
  // covers both `x >= x_padding` and `x < x_padding + x_size`.
  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
                l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
      1 /* rows */,
      context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
      input, 0 /* input stride */, output, 0 /* output stride */,
      context->padding_value);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
  }
}
981 | |
982 | void xnn_compute_elementwise_binary_1d( |
983 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
984 | size_t i) |
985 | { |
986 | const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]); |
987 | const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]); |
988 | void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]); |
989 | context->ukernel(context->elements, a, b, y, &context->params); |
990 | } |
991 | |
992 | void xnn_compute_elementwise_binary_2d( |
993 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
994 | size_t i, size_t j) |
995 | { |
996 | const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]); |
997 | const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]); |
998 | void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]); |
999 | context->ukernel(context->elements, a, b, y, &context->params); |
1000 | } |
1001 | |
1002 | void xnn_compute_elementwise_binary_3d( |
1003 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1004 | size_t i, size_t j, size_t k) |
1005 | { |
1006 | const void* a = (const void*) ((uintptr_t) context->a + |
1007 | i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]); |
1008 | const void* b = (const void*) ((uintptr_t) context->b + |
1009 | i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]); |
1010 | void* y = (void*) ((uintptr_t) context->y + |
1011 | i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]); |
1012 | context->ukernel(context->elements, a, b, y, &context->params); |
1013 | } |
1014 | |
1015 | void xnn_compute_elementwise_binary_4d( |
1016 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1017 | size_t i, size_t j, size_t k, size_t l) |
1018 | { |
1019 | const void* a = (const void*) ((uintptr_t) context->a + |
1020 | i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]); |
1021 | const void* b = (const void*) ((uintptr_t) context->b + |
1022 | i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]); |
1023 | void* y = (void*) ((uintptr_t) context->y + |
1024 | i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]); |
1025 | context->ukernel(context->elements, a, b, y, &context->params); |
1026 | } |
1027 | |
1028 | void xnn_compute_elementwise_binary_5d( |
1029 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1030 | size_t i, size_t j, size_t k, size_t l, size_t m) |
1031 | { |
1032 | const void* a = (const void*) ((uintptr_t) context->a + |
1033 | i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]); |
1034 | const void* b = (const void*) ((uintptr_t) context->b + |
1035 | i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]); |
1036 | void* y = (void*) ((uintptr_t) context->y + |
1037 | i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]); |
1038 | context->ukernel(context->elements, a, b, y, &context->params); |
1039 | } |
1040 | |
1041 | void xnn_compute_channel_shuffle_fixed( |
1042 | const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], |
1043 | size_t index) |
1044 | { |
1045 | const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); |
1046 | void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); |
1047 | |
1048 | context->fixed_ukernel(context->n, x, y); |
1049 | } |
1050 | |
1051 | void xnn_compute_channel_shuffle_variable( |
1052 | const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], |
1053 | size_t index) |
1054 | { |
1055 | const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); |
1056 | void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); |
1057 | |
1058 | context->variable_ukernel(context->n, context->m, x, y); |
1059 | } |
1060 | |
1061 | void xnn_compute_lut_strided( |
1062 | const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], |
1063 | size_t batch_index) |
1064 | { |
1065 | const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index); |
1066 | void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index); |
1067 | |
1068 | context->ukernel(context->n, x, y, context->t); |
1069 | } |
1070 | |
1071 | void xnn_compute_lut_contiguous( |
1072 | const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], |
1073 | size_t offset, |
1074 | size_t size) |
1075 | { |
1076 | const void* x = (const void*) ((uintptr_t) context->x + offset); |
1077 | void* y = (void*) ((uintptr_t) context->y + offset); |
1078 | |
1079 | context->ukernel(size, x, y, context->t); |
1080 | } |
1081 | |
1082 | void xnn_compute_univector_strided( |
1083 | const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)], |
1084 | size_t batch_index, |
1085 | size_t batch_range) |
1086 | { |
1087 | const size_t x_stride = context->x_stride; |
1088 | const size_t y_stride = context->y_stride; |
1089 | |
1090 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index); |
1091 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index); |
1092 | do { |
1093 | context->ukernel(context->n, x, y, &context->params); |
1094 | x = (const void*) ((uintptr_t) x + x_stride); |
1095 | y = (void*) ((uintptr_t) y + y_stride); |
1096 | } while (--batch_range != 0); |
1097 | } |
1098 | |
1099 | void xnn_compute_univector_contiguous( |
1100 | const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], |
1101 | size_t offset, |
1102 | size_t size) |
1103 | { |
1104 | const uint32_t log2_xsize = context->log2_xsize; |
1105 | const uint32_t log2_ysize = context->log2_ysize; |
1106 | const void* x = (const void*) ((uintptr_t) context->x + offset); |
1107 | void* y = (void*) ((uintptr_t) context->y + ((offset >> log2_xsize) << log2_ysize)); |
1108 | context->ukernel(size, x, y, &context->params); |
1109 | } |
1110 | |
1111 | void xnn_compute_u8_softmax( |
1112 | const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], |
1113 | size_t batch_index) |
1114 | { |
1115 | const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index); |
1116 | uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index); |
1117 | const size_t n = context->n; |
1118 | |
1119 | uint8_t x_max = 0; |
1120 | context->rmax_ukernel(n, x, &x_max); |
1121 | const size_t adjustment = x_max ^ 255; |
1122 | const uint32_t* t = (const uint32_t*) context->t + adjustment; |
1123 | context->lut_norm_ukernel(n, x, t, y); |
1124 | } |
1125 | |
1126 | void xnn_compute_floating_point_softmax( |
1127 | const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], |
1128 | size_t batch_index) |
1129 | { |
1130 | const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index); |
1131 | void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index); |
1132 | const size_t n = context->n; |
1133 | |
1134 | // First pass: reduce-max |
1135 | union { |
1136 | float as_float; |
1137 | uint16_t as_half; |
1138 | } x_max; |
1139 | context->rmax_ukernel(n, x, &x_max); |
1140 | |
1141 | // Second pass: reduce-add & store exp(x-x_max) |
1142 | union { |
1143 | float as_float; |
1144 | uint16_t as_half; |
1145 | } y_sum; |
1146 | context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params); |
1147 | |
1148 | // Third pass: scale y |
1149 | union { |
1150 | float as_float; |
1151 | uint16_t as_half; |
1152 | } y_scale; |
1153 | context->compute_reciprocal(&y_sum, &y_scale); |
1154 | context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params); |
1155 | } |
1156 | |
1157 | void xnn_compute_vmulcaddc( |
1158 | const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)], |
1159 | size_t batch_start, |
1160 | size_t batch_size) |
1161 | { |
1162 | const size_t x_stride = context->x_stride; |
1163 | const size_t y_stride = context->y_stride; |
1164 | |
1165 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); |
1166 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); |
1167 | |
1168 | context->ukernel( |
1169 | batch_size, |
1170 | context->n, |
1171 | x, x_stride, |
1172 | context->w, |
1173 | y, y_stride, |
1174 | &context->params); |
1175 | } |
1176 | |
1177 | #if XNN_MAX_UARCH_TYPES > 1 |
1178 | void xnn_compute_hmp_grouped_gemm( |
1179 | const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1180 | uint32_t uarch_index, |
1181 | size_t group_index, |
1182 | size_t mr_block_start, |
1183 | size_t nr_block_start, |
1184 | size_t mr_block_size, |
1185 | size_t nr_block_size) |
1186 | { |
1187 | const size_t k_scaled = context->k_scaled; |
1188 | const size_t a_stride = context->a_stride; |
1189 | const size_t cm_stride = context->cm_stride; |
1190 | |
1191 | context->ukernel.function[uarch_index]( |
1192 | mr_block_size, |
1193 | nr_block_size, |
1194 | k_scaled, |
1195 | (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled), |
1196 | a_stride, |
1197 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride), |
1198 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride), |
1199 | cm_stride, |
1200 | context->cn_stride, |
1201 | &context->params); |
1202 | } |
1203 | |
1204 | void xnn_compute_hmp_gemm( |
1205 | const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1206 | uint32_t uarch_index, |
1207 | size_t mr_block_start, |
1208 | size_t nr_block_start, |
1209 | size_t mr_block_size, |
1210 | size_t nr_block_size) |
1211 | { |
1212 | const size_t a_stride = context->a_stride; |
1213 | const size_t cm_stride = context->cm_stride; |
1214 | |
1215 | context->ukernel.function[uarch_index]( |
1216 | mr_block_size, |
1217 | nr_block_size, |
1218 | context->k_scaled, |
1219 | (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), |
1220 | a_stride, |
1221 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), |
1222 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
1223 | cm_stride, |
1224 | context->cn_stride, |
1225 | context->fused_params); |
1226 | } |
1227 | |
1228 | void xnn_compute_hmp_grouped_batch_igemm( |
1229 | const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1230 | uint32_t uarch_index, |
1231 | size_t batch_index, |
1232 | size_t group_index, |
1233 | size_t mr_block_start, |
1234 | size_t nr_block_start, |
1235 | size_t mr_block_size, |
1236 | size_t nr_block_size) |
1237 | { |
1238 | const size_t ks = context->ks; |
1239 | const size_t cm_stride = context->cm_stride; |
1240 | |
1241 | context->ukernel.function[uarch_index]( |
1242 | mr_block_size, |
1243 | nr_block_size, |
1244 | context->kc, |
1245 | context->ks_scaled, |
1246 | (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), |
1247 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride), |
1248 | (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
1249 | cm_stride, |
1250 | context->cn_stride, |
1251 | context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride, |
1252 | context->zero, |
1253 | &context->params); |
1254 | } |
1255 | |
1256 | void xnn_compute_hmp_grouped_igemm( |
1257 | const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1258 | uint32_t uarch_index, |
1259 | size_t group_index, |
1260 | size_t mr_block_start, |
1261 | size_t nr_block_start, |
1262 | size_t mr_block_size, |
1263 | size_t nr_block_size) |
1264 | { |
1265 | const size_t ks = context->ks; |
1266 | const size_t cm_stride = context->cm_stride; |
1267 | |
1268 | context->ukernel.function[uarch_index]( |
1269 | mr_block_size, |
1270 | nr_block_size, |
1271 | context->kc, |
1272 | context->ks_scaled, |
1273 | (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), |
1274 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride), |
1275 | (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
1276 | cm_stride, |
1277 | context->cn_stride, |
1278 | context->a_offset + group_index * context->ga_stride, |
1279 | context->zero, |
1280 | &context->params); |
1281 | } |
1282 | |
1283 | void xnn_compute_batch_hmp_igemm( |
1284 | const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1285 | uint32_t uarch_index, |
1286 | size_t batch_index, |
1287 | size_t mr_block_start, |
1288 | size_t nr_block_start, |
1289 | size_t mr_block_size, |
1290 | size_t nr_block_size) |
1291 | { |
1292 | const size_t ks = context->ks; |
1293 | const size_t cm_stride = context->cm_stride; |
1294 | |
1295 | context->ukernel.function[uarch_index]( |
1296 | mr_block_size, |
1297 | nr_block_size, |
1298 | context->kc, |
1299 | context->ks_scaled, |
1300 | (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), |
1301 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), |
1302 | (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
1303 | cm_stride, |
1304 | context->cn_stride, |
1305 | context->a_offset + batch_index * context->ba_stride, |
1306 | context->zero, |
1307 | &context->params); |
1308 | } |
1309 | |
1310 | void xnn_compute_hmp_igemm( |
1311 | const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], |
1312 | uint32_t uarch_index, |
1313 | size_t mr_block_start, |
1314 | size_t nr_block_start, |
1315 | size_t mr_block_size, |
1316 | size_t nr_block_size) |
1317 | { |
1318 | const size_t ks = context->ks; |
1319 | const size_t cm_stride = context->cm_stride; |
1320 | |
1321 | context->ukernel.function[uarch_index]( |
1322 | mr_block_size, |
1323 | nr_block_size, |
1324 | context->kc, |
1325 | context->ks_scaled, |
1326 | (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), |
1327 | (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), |
1328 | (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), |
1329 | cm_stride, |
1330 | context->cn_stride, |
1331 | context->a_offset, |
1332 | context->zero, |
1333 | &context->params); |
1334 | } |
1335 | #endif // XNN_MAX_UARCH_TYPES > 1 |
1336 | |
// Executes a previously set-up operator on the given thread pool.
// Validates library initialization and operator state, then dispatches the
// operator's compute task to the pthreadpool_parallelize_* call matching its
// parallelization type. Returns xnn_status_success, or an error status when
// XNNPACK is uninitialized or the operator was never successfully set up.
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized" );
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully setup" );
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      // Setup determined there is no work to do (e.g. empty batch).
      return xnn_status_success;
  }

  // Denormal floats are flushed on worker threads; workers optionally yield
  // (rather than spin) when the operator requests it.
  uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;
  if (op->flags & XNN_FLAG_YIELD_WORKERS) {
    flags |= PTHREADPOOL_FLAG_YIELD_WORKERS;
  }
  // Dispatch on the parallelization geometry chosen at setup time. Each case
  // forwards the operator's context, iteration ranges, and tile sizes to the
  // corresponding pthreadpool entry point.
  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          flags);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool,
          op->compute.task_3d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool,
          op->compute.task_4d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_5d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool,
          op->compute.task_5d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          flags);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    // Heterogeneous-multiprocessing variants: tasks also receive the worker
    // core's microarchitecture index to select a per-uarch microkernel.
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}
1565 | |