1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // Copyright 2019 Google LLC |
5 | // |
6 | // This source code is licensed under the BSD-style license found in the |
7 | // LICENSE file in the root directory of this source tree. |
8 | |
9 | #include <assert.h> |
10 | #include <stddef.h> |
11 | #include <stdint.h> |
12 | #include <string.h> |
13 | |
14 | #include <xnnpack.h> |
15 | #include <xnnpack/allocator.h> |
16 | #include <xnnpack/operator.h> |
17 | #include <xnnpack/log.h> |
18 | #include <xnnpack/common.h> |
19 | #include <xnnpack/math.h> |
20 | #include <xnnpack/microkernel-type.h> |
21 | #include <xnnpack/params.h> |
22 | #include <xnnpack/compute.h> |
23 | |
24 | |
25 | void xnn_compute_transposec_2d( |
26 | const struct transpose_context* context, |
27 | size_t i, |
28 | size_t j, |
29 | size_t tile_i, |
30 | size_t tile_j) |
31 | { |
32 | const size_t ld_input = context->input_stride[1]; |
33 | const size_t ld_output = context->output_stride[0]; |
34 | context->const_size_ukernel( |
35 | (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1]), |
36 | (void*) ((uintptr_t) context->y + j * context->output_stride[1] + i * context->output_stride[0]), |
37 | ld_input, |
38 | ld_output, |
39 | tile_i, |
40 | tile_j, |
41 | &context->params); |
42 | } |
43 | |
44 | void xnn_compute_transposec_3d( |
45 | const struct transpose_context* context, |
46 | size_t i, |
47 | size_t j, |
48 | size_t k, |
49 | size_t tile_j, |
50 | size_t tile_k) |
51 | { |
52 | const size_t ld_input = context->input_stride[2]; |
53 | const size_t ld_output = context->output_stride[1]; |
54 | const void* x = (const void*) ((uintptr_t) context->x + |
55 | i * context->input_stride[0] + j * context->input_stride[1] + k * context->input_stride[2]); |
56 | void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
57 | k * context->output_stride[2]); |
58 | |
59 | context->const_size_ukernel( |
60 | x, |
61 | y, |
62 | ld_input, |
63 | ld_output, |
64 | tile_j, |
65 | tile_k, |
66 | &context->params); |
67 | } |
68 | |
69 | void xnn_compute_transposec_4d( |
70 | const struct transpose_context* context, |
71 | size_t i, |
72 | size_t j, |
73 | size_t k, |
74 | size_t l, |
75 | size_t tile_k, |
76 | size_t tile_l) |
77 | { |
78 | const size_t ld_input = context->input_stride[3]; |
79 | const size_t ld_output = context->output_stride[2]; |
80 | const void* x = (const void*) ((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
81 | k * context->input_stride[2] + l * context->input_stride[3]); |
82 | void* y = (void*) ((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
83 | k * context->output_stride[2] + l * context->output_stride[3]); |
84 | |
85 | context->const_size_ukernel( |
86 | x, |
87 | y, |
88 | ld_input, |
89 | ld_output, |
90 | tile_k, |
91 | tile_l, |
92 | &context->params); |
93 | } |
94 | |
95 | void xnn_compute_transposec_5d( |
96 | const struct transpose_context* context, |
97 | size_t i, |
98 | size_t j, |
99 | size_t k, |
100 | size_t l, |
101 | size_t m, |
102 | size_t tile_l, |
103 | size_t tile_m) |
104 | { |
105 | const size_t ld_input = context->input_stride[4]; |
106 | const size_t ld_output = context->output_stride[3]; |
107 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
108 | k * context->input_stride[2] + l * context->input_stride[3] + m * context->input_stride[4]); |
109 | void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
110 | k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4]); |
111 | |
112 | context->const_size_ukernel( |
113 | x, |
114 | y, |
115 | ld_input, |
116 | ld_output, |
117 | tile_l, |
118 | tile_m, |
119 | &context->params); |
120 | } |
121 | |
122 | void xnn_compute_transposec_6d( |
123 | const struct transpose_context* context, |
124 | size_t i, |
125 | size_t j, |
126 | size_t k, |
127 | size_t l, |
128 | size_t m, |
129 | size_t n, |
130 | size_t tile_m, |
131 | size_t tile_n) |
132 | { |
133 | const size_t ld_input = context->input_stride[5]; |
134 | const size_t ld_output = context->output_stride[4]; |
135 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
136 | k * context->input_stride[2] + l * context->input_stride[3] + |
137 | m * context->input_stride[4] + n * context->input_stride[5]); |
138 | void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
139 | k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] + |
140 | n * context->output_stride[5]); |
141 | |
142 | context->const_size_ukernel( |
143 | x, |
144 | y, |
145 | ld_input, |
146 | ld_output, |
147 | tile_m, |
148 | tile_n, |
149 | &context->params); |
150 | } |
151 | |
152 | void xnn_compute_transposev_2d( |
153 | const struct transpose_context* context, |
154 | size_t i, |
155 | size_t j, |
156 | size_t tile_i, |
157 | size_t tile_j) |
158 | { |
159 | const size_t element_size = context->output_stride[1]; |
160 | const size_t ld_input = context->input_stride[1]; |
161 | const size_t ld_output = context->output_stride[0]; |
162 | const void* x = (const void*) ((uintptr_t) context->x + |
163 | i * context->input_stride[0] + j * context->input_stride[1]); |
164 | void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]); |
165 | |
166 | context->variable_size_ukernel( |
167 | x, |
168 | y, |
169 | ld_input, |
170 | ld_output, |
171 | context->input_stride[0], |
172 | context->output_stride[1], |
173 | element_size, |
174 | tile_i, |
175 | tile_j); |
176 | } |
177 | |
178 | void xnn_compute_transposev_3d( |
179 | const struct transpose_context* context, |
180 | size_t i, |
181 | size_t j, |
182 | size_t k, |
183 | size_t tile_j, |
184 | size_t tile_k) |
185 | { |
186 | const size_t element_size = context->output_stride[2]; |
187 | const size_t ld_input = context->input_stride[2]; |
188 | const size_t ld_output = context->output_stride[1]; |
189 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
190 | k * context->input_stride[2]); |
191 | void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] + |
192 | k * context->output_stride[2]); |
193 | |
194 | context->variable_size_ukernel( |
195 | x, |
196 | y, |
197 | ld_input, |
198 | ld_output, |
199 | context->input_stride[1], |
200 | context->output_stride[2], |
201 | element_size, |
202 | tile_j, |
203 | tile_k); |
204 | } |
205 | |
206 | void xnn_compute_transposev_4d( |
207 | const struct transpose_context* context, |
208 | size_t i, |
209 | size_t j, |
210 | size_t k, |
211 | size_t l, |
212 | size_t tile_k, |
213 | size_t tile_l) |
214 | { |
215 | const size_t element_size = context->output_stride[3]; |
216 | const size_t ld_input = context->input_stride[3]; |
217 | const size_t ld_output = context->output_stride[2]; |
218 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
219 | k * context->input_stride[2] + l * context->input_stride[3]); |
220 | void* y = (void*)((uintptr_t)context->y + context->output_stride[3] * l + i * context->output_stride[0] + |
221 | j * context->output_stride[1] + k * context->output_stride[2]); |
222 | |
223 | context->variable_size_ukernel( |
224 | x, |
225 | y, |
226 | ld_input, |
227 | ld_output, |
228 | context->input_stride[2], |
229 | context->output_stride[3], |
230 | element_size, |
231 | tile_k, |
232 | tile_l); |
233 | } |
234 | |
235 | void xnn_compute_transposev_5d( |
236 | const struct transpose_context* context, |
237 | size_t i, |
238 | size_t j, |
239 | size_t k, |
240 | size_t l, |
241 | size_t m, |
242 | size_t tile_l, |
243 | size_t tile_m) |
244 | { |
245 | const size_t element_size = context->output_stride[4]; |
246 | const size_t ld_input = context->input_stride[4]; |
247 | const size_t ld_output = context->output_stride[3]; |
248 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
249 | k * context->input_stride[2] + l * context->input_stride[3] + m * context->input_stride[4]); |
250 | void* y = (void*)((uintptr_t)context->y + context->output_stride[4] * m + i * context->output_stride[0] + |
251 | j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]); |
252 | |
253 | context->variable_size_ukernel( |
254 | x, |
255 | y, |
256 | ld_input, |
257 | ld_output, |
258 | context->input_stride[3], |
259 | context->output_stride[4], |
260 | element_size, |
261 | tile_l, |
262 | tile_m); |
263 | } |
264 | |
265 | void xnn_compute_transposev_6d( |
266 | const struct transpose_context* context, |
267 | size_t i, |
268 | size_t j, |
269 | size_t k, |
270 | size_t l, |
271 | size_t m, |
272 | size_t n, |
273 | size_t tile_m, |
274 | size_t tile_n) |
275 | { |
276 | const size_t element_size = context->output_stride[5]; |
277 | const size_t ld_input = context->input_stride[5]; |
278 | const size_t ld_output = context->output_stride[4]; |
279 | const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] + |
280 | k * context->input_stride[2] + l * context->input_stride[3] + |
281 | m * context->input_stride[4] + n * context->input_stride[5]); |
282 | void* y = (void*)((uintptr_t)context->y + context->output_stride[5] * n + i * context->output_stride[0] + |
283 | j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] + |
284 | m * context->output_stride[4]); |
285 | |
286 | context->variable_size_ukernel( |
287 | x, |
288 | y, |
289 | ld_input, |
290 | ld_output, |
291 | context->input_stride[4], |
292 | context->output_stride[5], |
293 | element_size, |
294 | tile_m, |
295 | tile_n); |
296 | } |
297 | |
// Computes one [mr_block_size x nr_block_size] output tile of a grouped GEMM:
// the A, packed-weight and C base pointers are each advanced by a per-group
// stride before the tile is handed to the GEMM microkernel.
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      // A: row offset plus a per-group offset of k_scaled bytes.
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}
322 | |
// Computes one [mr_block_size x nr_block_size] output tile of a (non-grouped)
// GEMM. Unlike xnn_compute_grouped_gemm, this variant passes the operator's
// fused_params pointer to the microkernel instead of &context->params.
void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}
345 | |
// Computes a block of output rows of a sparse-matrix * dense-matrix product
// (SpMM) for one batch element.
// NOTE(review): mr_block_start is added directly to the byte addresses of
// input and output, so it appears to be pre-scaled in bytes by the caller —
// confirm against the parallelization setup.
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}
363 | |
// Computes one output tile of an indirect GEMM (IGEMM) for the given batch
// element and group. Input rows are read through the indirection buffer
// (ks pointers per output row); a_offset rebases those pointers onto the
// current group/batch slice of the input.
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Each output row owns ks consecutive pointers in the indirection buffer.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
390 | |
// Computes one output tile of an indirect GEMM (IGEMM) for the given group
// (single-batch variant of xnn_compute_grouped_batch_igemm).
void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Each output row owns ks consecutive pointers in the indirection buffer.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Rebases the indirection pointers onto the current group's input slice.
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}
416 | |
// Computes one output tile of an indirect GEMM (IGEMM) for the given batch
// element (non-grouped variant of xnn_compute_grouped_batch_igemm).
void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Each output row owns ks consecutive pointers in the indirection buffer.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Rebases the indirection pointers onto the current batch element.
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
442 | |
// Computes one output tile of an indirect GEMM (IGEMM); simplest variant with
// neither group nor batch offsets.
void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // Each output row owns ks consecutive pointers in the indirection buffer.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      // C: column offset is scaled by the element size via log2_csize.
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
467 | |
// Computes one tile of one subconvolution subkernel via a GEMM microkernel,
// for the given batch element and group. Tiles outside the subkernel's
// output slice are skipped.
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid is shared across subkernels; smaller subkernels
  // reject out-of-range tiles here.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile to the remaining width of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      // Weights and output base pointers are per-subkernel.
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}
505 | |
// Computes one tile of one subconvolution subkernel via a GEMM microkernel,
// for the given batch element (non-grouped variant of
// xnn_compute_grouped_subgemm2d). Tiles outside the subkernel's output slice
// are skipped.
void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid is shared across subkernels; smaller subkernels
  // reject out-of-range tiles here.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile to the remaining width of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      // Weights and output base pointers are per-subkernel.
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}
542 | |
// Computes one tile of one subconvolution subkernel via an IGEMM microkernel
// (input read through the subkernel's indirection buffer), for the given
// batch element and group. Tiles outside the subkernel's output slice are
// skipped.
void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid is shared across subkernels; smaller subkernels
  // reject out-of-range tiles here.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile to the remaining width of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      // Rebases the indirection pointers onto the current group/batch slice.
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
581 | |
// Computes one tile of one subconvolution subkernel via an IGEMM microkernel,
// for the given batch element (non-grouped variant of
// xnn_compute_grouped_subconv2d). Tiles outside the subkernel's output slice
// are skipped.
void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  // The parallelization grid is shared across subkernels; smaller subkernels
  // reject out-of-range tiles here.
  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile to the remaining width of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      // Rebases the indirection pointers onto the current batch element.
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
619 | |
// Computes output rows [output_y_start, output_y_start + output_y_slice) of a
// 2D convolution with HWC input layout and CHW output layout, for one batch
// element.
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,  // exclusive end of the output-row range
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}
641 | |
// Computes one output row of a depthwise convolution for one batch element
// using the single-pass (unipass) microkernel. Input pixels are read through
// the per-row slice of the indirection buffer; input_offset rebases the
// indirection pointers onto the current batch element.
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->groups, context->output_width,
    indirect_input, context->packed_weights, output,
    context->indirect_input_width_stride, context->output_increment,
    input_offset, context->zero,
    &context->params);
}
660 | |
// Computes one channel of a depthwise 2D convolution in CHW layout for one
// batch element.
void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  context->chw_ukernel(
    context->input_height,
    context->input_width,
    // Input, weights, and output are all advanced to the current channel.
    (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
    (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
    context->zero,
    (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
    context->input_padding_top,
    &context->params);
}
676 | |
// Computes one output row of argmax pooling (pooled values plus uint32 argmax
// indices) for one batch element, using the single-pass microkernel.
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  // Per-output-row slice of the indirection buffer; input_offset rebases the
  // indirection pointers onto the current batch element.
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output, index,
    context->input_increment, context->output_increment);
}
695 | |
// Computes one output row of argmax pooling for pooling windows too large for
// the single-pass microkernel. Uses stack-allocated scratch buffers (a float
// accumulator and a uint32 index buffer, both per-channel) for the
// intermediate passes.
void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  // Per-output-row slice of the indirection buffer; input_offset rebases the
  // indirection pointers onto the current batch element.
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  // XNN_EXTRA_BYTES padding allows the microkernel to over-read/over-write.
  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
    context->input_increment, context->output_increment);
}
717 | |
// Computes one output row of max pooling for one batch element.
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  // Per-output-row slice of the indirection buffer; input_offset rebases the
  // indirection pointers onto the current batch element.
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output,
    context->input_increment, context->output_increment,
    &context->params);
}
735 | |
// Runs the unpooling microkernel for input pixel (input_y, input_x), passing
// that pixel's values, its recorded pooling indices, and the corresponding
// slice of indirect output pointers.
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  // Indices recorded by the matching max-pooling pass.
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
    context->pooling_size,
    context->channels,
    context->fill_value,
    input, index, indirect_output);
}
755 | |
// Computes one output row of average pooling for one batch element, using the
// single-pass microkernel.
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  // Per-output-row slice of the indirection buffer; input_offset rebases the
  // indirection pointers onto the current batch element.
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, output,
    context->input_increment, context->output_increment,
    &context->params);
}
773 | |
// Computes one output row of average pooling for pooling windows too large
// for the single-pass microkernel; accumulates in a stack-allocated
// per-channel scratch buffer.
void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  // Per-output-row slice of the indirection buffer; input_offset rebases the
  // indirection pointers onto the current batch element.
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  // Scratch accumulator sized as int32 per channel, with XNN_EXTRA_BYTES
  // worth of int32-sized tail padding for microkernel over-access.
  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}
794 | |
795 | void xnn_compute_pixelwise_average_pooling_unipass( |
796 | const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
797 | size_t batch_index, |
798 | size_t output_y) |
799 | { |
800 | const void** indirect_input = |
801 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
802 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
803 | const void* pixelwise_buffer = |
804 | (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride); |
805 | void* output = (void*) ((uintptr_t) context->output + |
806 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
807 | |
808 | context->unipass_ukernel( |
809 | context->output_width, context->pooling_size, context->channels, |
810 | indirect_input, input_offset, context->zero, pixelwise_buffer, output, |
811 | context->input_increment, context->output_increment, |
812 | &context->params); |
813 | } |
814 | |
815 | void xnn_compute_pixelwise_average_pooling_multipass( |
816 | const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], |
817 | size_t batch_index, |
818 | size_t output_y) |
819 | { |
820 | const void** indirect_input = |
821 | (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride); |
822 | const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; |
823 | const void* pixelwise_buffer = |
824 | (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride); |
825 | void* output = (void*) ((uintptr_t) context->output + |
826 | batch_index * context->output_batch_stride + output_y * context->output_height_stride); |
827 | |
828 | void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t)); |
829 | |
830 | context->multipass_ukernel( |
831 | context->output_width, context->pooling_size, context->channels, |
832 | indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output, |
833 | context->input_increment, context->output_increment, |
834 | &context->params); |
835 | } |
836 | |
837 | void xnn_compute_global_average_pooling_nwc_unipass( |
838 | const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], |
839 | size_t batch_index) |
840 | { |
841 | const void* input = |
842 | (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); |
843 | void* output = |
844 | (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); |
845 | |
846 | context->unipass_ukernel( |
847 | context->input_elements, |
848 | context->channels, |
849 | input, |
850 | context->input_pixel_stride, |
851 | context->zero, |
852 | output, |
853 | &context->params); |
854 | } |
855 | |
856 | void xnn_compute_global_average_pooling_nwc_multipass( |
857 | const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], |
858 | size_t batch_index) |
859 | { |
860 | const void* input = |
861 | (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); |
862 | void* output = |
863 | (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); |
864 | |
865 | void* multipass_buffer = |
866 | XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t)); |
867 | |
868 | context->multipass_ukernel( |
869 | context->input_elements, |
870 | context->channels, |
871 | input, |
872 | context->input_pixel_stride, |
873 | context->zero, |
874 | multipass_buffer, |
875 | output, |
876 | &context->params); |
877 | } |
878 | |
879 | void xnn_compute_global_average_pooling_ncw( |
880 | const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)], |
881 | size_t batch_index, |
882 | size_t channels_start, |
883 | size_t channels_slice) |
884 | { |
885 | const void* input = (const void*) ((uintptr_t) context->input + |
886 | channels_start * context->input_channel_stride + batch_index * context->input_batch_stride); |
887 | void* output = (void*) ((uintptr_t) context->output + |
888 | channels_start * context->output_channel_stride + batch_index * context->output_batch_stride); |
889 | |
890 | context->ukernel( |
891 | context->input_elements, |
892 | channels_slice, |
893 | input, |
894 | output, |
895 | &context->params); |
896 | } |
897 | |
// Computes a range of bilinear-resized output pixels for one batch element
// (NHWC layout).
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
    pixel_range,
    context->scaled_channels,
    // 4 indirection pointers per output pixel (presumably the 2x2 input
    // neighborhood of each pixel -- verify against the indirection setup).
    context->indirect_input + pixel_start * 4,
    context->input_offset + batch_index * context->input_batch_stride,
    // Packed per-pixel interpolation weights; log2_wsize is the log2 of the
    // per-pixel weight entry size in bytes.
    (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
    output,
    // Bytes to skip after writing scaled_channels bytes of each output pixel.
    context->output_pixel_stride - context->scaled_channels);
}
916 | |
// Computes bilinear resize for a slice of channels of one batch element in
// CHW layout; all output pixels of the slice are produced by one ukernel call.
void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  // Byte offset applied to each indirection pointer; advances past earlier
  // batches and earlier channel planes.
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
    context->output_pixels,
    channel_range,
    context->indirect_input,
    input_offset,
    context->packed_weights,
    output,
    // Stride in bytes between consecutive input channel planes.
    context->input_channel_stride);
}
936 | |
937 | void xnn_compute_prelu( |
938 | const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)], |
939 | size_t batch_start, |
940 | size_t batch_range) |
941 | { |
942 | const size_t x_stride = context->x_stride; |
943 | const size_t y_stride = context->y_stride; |
944 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); |
945 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); |
946 | |
947 | context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride); |
948 | } |
949 | |
// Produces one innermost row of a padded 6D tensor at coordinates
// (i, j, k, l, m) over the five outer dimensions. If the row lies inside the
// input region, the pad microkernel copies it (padding only the innermost
// dimension); otherwise the fill microkernel writes the padding value.
// NOTE: strides are indexed in reverse (stride[4] for i ... stride[0] for m),
// while pre_paddings/input_size use indices [5]..[1]; entry [0] of each array
// belongs to the innermost dimension handled inside the ukernel.
void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
    i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  // Unsigned-wraparound trick: `x - pad < size` is a single comparison that
  // tests pad <= x < pad + size (a coordinate below pad wraps to a huge value).
  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
      l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
      1 /* rows */,
      context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
      input, 0 /* input stride */, output, 0 /* output stride */,
      context->padding_value);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
  }
}
983 | |
984 | void xnn_compute_slice_1d( |
985 | const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], |
986 | size_t i) |
987 | { |
988 | const void* input = (const void*) ((uintptr_t) context->input + i * context->input_stride[0]); |
989 | void* output = (void*) ((uintptr_t) context->output + i * context->output_stride[0]); |
990 | |
991 | context->ukernel(context->contiguous_size, input, output, NULL); |
992 | } |
993 | |
994 | void xnn_compute_slice_2d( |
995 | const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], |
996 | size_t i, size_t j) |
997 | { |
998 | const void* input = |
999 | (const void*) ((uintptr_t) context->input + |
1000 | i * context->input_stride[1] + |
1001 | j * context->input_stride[0]); |
1002 | void* output = |
1003 | (void*) ((uintptr_t) context->output + i * context->output_stride[1] + j * context->output_stride[0]); |
1004 | |
1005 | context->ukernel(context->contiguous_size, input, output, NULL); |
1006 | } |
1007 | |
1008 | void xnn_compute_slice_3d( |
1009 | const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], |
1010 | size_t i, size_t j, size_t k) |
1011 | { |
1012 | const void* input = |
1013 | (const void*) ((uintptr_t) context->input + |
1014 | i * context->input_stride[2] + |
1015 | j * context->input_stride[1] + |
1016 | k * context->input_stride[0]); |
1017 | void* output = |
1018 | (void*) ((uintptr_t) context->output + i * context->output_stride[2] + |
1019 | j * context->output_stride[1] + k * context->output_stride[0]); |
1020 | |
1021 | context->ukernel(context->contiguous_size, input, output, NULL); |
1022 | } |
1023 | |
1024 | void xnn_compute_slice_4d( |
1025 | const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], |
1026 | size_t i, size_t j, size_t k, size_t l) |
1027 | { |
1028 | const void* input = |
1029 | (const void*) ((uintptr_t) context->input + |
1030 | i * context->input_stride[3] + |
1031 | j * context->input_stride[2] + |
1032 | k * context->input_stride[1] + |
1033 | l * context->input_stride[0]); |
1034 | void* output = |
1035 | (void*) ((uintptr_t) context->output + i * context->output_stride[3] + |
1036 | j * context->output_stride[2] + k * context->output_stride[1] + l * context->output_stride[0]); |
1037 | |
1038 | context->ukernel(context->contiguous_size, input, output, NULL); |
1039 | } |
1040 | |
1041 | void xnn_compute_slice_5d( |
1042 | const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], |
1043 | size_t i, size_t j, size_t k, size_t l, size_t m) |
1044 | { |
1045 | const void* input = |
1046 | (const void* ) ((uintptr_t) context->input + |
1047 | i * context->input_stride[4] + |
1048 | j * context->input_stride[3] + |
1049 | k * context->input_stride[2] + |
1050 | l * context->input_stride[1] + |
1051 | m * context->input_stride[0]); |
1052 | void* output = |
1053 | (void*) ((uintptr_t) context->output + i * context->output_stride[4] + |
1054 | j * context->output_stride[3] + k * context->output_stride[2] + |
1055 | l * context->output_stride[1] + m * context->output_stride[0]); |
1056 | |
1057 | context->ukernel(context->contiguous_size, input, output, NULL); |
1058 | } |
1059 | |
1060 | void xnn_compute_elementwise_binary_1d( |
1061 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1062 | size_t i) |
1063 | { |
1064 | const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]); |
1065 | const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]); |
1066 | void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]); |
1067 | context->ukernel(context->elements, a, b, y, &context->params); |
1068 | } |
1069 | |
1070 | void xnn_compute_elementwise_binary_2d( |
1071 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1072 | size_t i, size_t j) |
1073 | { |
1074 | const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]); |
1075 | const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]); |
1076 | void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]); |
1077 | context->ukernel(context->elements, a, b, y, &context->params); |
1078 | } |
1079 | |
1080 | void xnn_compute_elementwise_binary_3d( |
1081 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1082 | size_t i, size_t j, size_t k) |
1083 | { |
1084 | const void* a = (const void*) ((uintptr_t) context->a + |
1085 | i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]); |
1086 | const void* b = (const void*) ((uintptr_t) context->b + |
1087 | i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]); |
1088 | void* y = (void*) ((uintptr_t) context->y + |
1089 | i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]); |
1090 | context->ukernel(context->elements, a, b, y, &context->params); |
1091 | } |
1092 | |
1093 | void xnn_compute_elementwise_binary_4d( |
1094 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1095 | size_t i, size_t j, size_t k, size_t l) |
1096 | { |
1097 | const void* a = (const void*) ((uintptr_t) context->a + |
1098 | i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]); |
1099 | const void* b = (const void*) ((uintptr_t) context->b + |
1100 | i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]); |
1101 | void* y = (void*) ((uintptr_t) context->y + |
1102 | i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]); |
1103 | context->ukernel(context->elements, a, b, y, &context->params); |
1104 | } |
1105 | |
1106 | void xnn_compute_elementwise_binary_5d( |
1107 | const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], |
1108 | size_t i, size_t j, size_t k, size_t l, size_t m) |
1109 | { |
1110 | const void* a = (const void*) ((uintptr_t) context->a + |
1111 | i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]); |
1112 | const void* b = (const void*) ((uintptr_t) context->b + |
1113 | i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]); |
1114 | void* y = (void*) ((uintptr_t) context->y + |
1115 | i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]); |
1116 | context->ukernel(context->elements, a, b, y, &context->params); |
1117 | } |
1118 | |
1119 | void xnn_compute_channel_shuffle_fixed( |
1120 | const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], |
1121 | size_t index) |
1122 | { |
1123 | const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); |
1124 | void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); |
1125 | |
1126 | context->fixed_ukernel(context->n, x, y); |
1127 | } |
1128 | |
1129 | void xnn_compute_channel_shuffle_variable( |
1130 | const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], |
1131 | size_t index) |
1132 | { |
1133 | const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); |
1134 | void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); |
1135 | |
1136 | context->variable_ukernel(context->n, context->m, x, y); |
1137 | } |
1138 | |
1139 | void xnn_compute_lut_strided( |
1140 | const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], |
1141 | size_t batch_index) |
1142 | { |
1143 | const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index); |
1144 | void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index); |
1145 | |
1146 | context->ukernel(context->n, x, y, context->t); |
1147 | } |
1148 | |
1149 | void xnn_compute_lut_contiguous( |
1150 | const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], |
1151 | size_t offset, |
1152 | size_t size) |
1153 | { |
1154 | const void* x = (const void*) ((uintptr_t) context->x + offset); |
1155 | void* y = (void*) ((uintptr_t) context->y + offset); |
1156 | |
1157 | context->ukernel(size, x, y, context->t); |
1158 | } |
1159 | |
1160 | void xnn_compute_univector_strided( |
1161 | const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)], |
1162 | size_t batch_index, |
1163 | size_t batch_range) |
1164 | { |
1165 | const size_t x_stride = context->x_stride; |
1166 | const size_t y_stride = context->y_stride; |
1167 | |
1168 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index); |
1169 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index); |
1170 | do { |
1171 | context->ukernel(context->n, x, y, &context->params); |
1172 | x = (const void*) ((uintptr_t) x + x_stride); |
1173 | y = (void*) ((uintptr_t) y + y_stride); |
1174 | } while (--batch_range != 0); |
1175 | } |
1176 | |
1177 | void xnn_compute_univector_contiguous( |
1178 | const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], |
1179 | size_t offset, |
1180 | size_t size) |
1181 | { |
1182 | const uint32_t log2_xsize = context->log2_xsize; |
1183 | const uint32_t log2_ysize = context->log2_ysize; |
1184 | const void* x = (const void*) ((uintptr_t) context->x + offset); |
1185 | void* y = (void*) ((uintptr_t) context->y + ((offset >> log2_xsize) << log2_ysize)); |
1186 | context->ukernel(size, x, y, &context->params); |
1187 | } |
1188 | |
// Computes softmax over one batch row of uint8 inputs with a two-microkernel
// pipeline: a reduce-max pass followed by a LUT-based normalization pass.
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  // x_max ^ 255 == 255 - x_max for 8-bit values. Advancing the table base by
  // this amount makes each lookup effectively relative to the row maximum
  // (presumably the standard max-subtraction softmax trick applied to a
  // precomputed exp table -- verify against the layout of context->t).
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}
1203 | |
// Computes softmax over one batch row with the classic three-pass scheme:
// (1) reduce-max, (2) sum of exp(x - max) with the exponentials stored to y,
// (3) scale y by the reciprocal of the sum.
// The scalar temporaries are unions because this path serves both f32 and
// f16 operators; each ukernel reads/writes the member matching the
// operator's datatype.
void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  union {
    float as_float;
    uint16_t as_half;
  } x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  union {
    float as_float;
    uint16_t as_half;
  } y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params);

  // Third pass: scale y
  union {
    float as_float;
    uint16_t as_half;
  } y_scale;
  // compute_reciprocal is a scalar helper: y_scale = 1 / y_sum.
  context->compute_reciprocal(&y_sum, &y_scale);
  context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params);
}
1234 | |
1235 | void xnn_compute_vmulcaddc( |
1236 | const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)], |
1237 | size_t batch_start, |
1238 | size_t batch_size) |
1239 | { |
1240 | const size_t x_stride = context->x_stride; |
1241 | const size_t y_stride = context->y_stride; |
1242 | |
1243 | const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); |
1244 | void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); |
1245 | |
1246 | context->ukernel( |
1247 | batch_size, |
1248 | context->n, |
1249 | x, x_stride, |
1250 | context->w, |
1251 | y, y_stride, |
1252 | &context->params); |
1253 | } |
1254 | |
1255 | #if XNN_MAX_UARCH_TYPES > 1 |
// Grouped GEMM over one MxN tile with heterogeneous-multiprocessing (HMP)
// dispatch: uarch_index selects the microkernel variant tuned for the
// calling thread's core microarchitecture.
void xnn_compute_hmp_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      k_scaled,
      // A advanced to this tile's rows and this group's K-slab.
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      // Packed weights advanced to this tile's N-columns within this group.
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      // Output tile origin; log2_csize is the log2 of the output element size.
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}
1281 | |
// GEMM over one MxN tile with HMP dispatch (non-grouped variant).
// NOTE(review): unlike the grouped variant above, this passes
// context->fused_params rather than &context->params -- presumably the
// params for a fused post-operation variant of the kernel; confirm against
// the gemm_context definition.
void xnn_compute_hmp_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      // Output tile origin; log2_csize is the log2 of the output element size.
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}
1305 | |
// Batched, grouped indirect GEMM (IGEMM) over one MxN tile with HMP dispatch.
// The A operand is a table of row pointers (indirection buffer) with ks
// pointers per output row; a_offset is added by the kernel to each pointer.
void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // ks indirection pointers per output row -> tile starts at row mr_block_start.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Byte offset applied to each indirection pointer for this group/batch.
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
1333 | |
// Grouped indirect GEMM (IGEMM) over one MxN tile with HMP dispatch
// (single-batch variant of xnn_compute_hmp_grouped_batch_igemm).
void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // ks indirection pointers per output row -> tile starts at row mr_block_start.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Byte offset applied to each indirection pointer for this group.
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}
1360 | |
// Batched (non-grouped) indirect GEMM (IGEMM) over one MxN tile with HMP
// dispatch.
void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // ks indirection pointers per output row -> tile starts at row mr_block_start.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Byte offset applied to each indirection pointer for this batch.
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
1387 | |
// Indirect GEMM (IGEMM) over one MxN tile with HMP dispatch (non-grouped,
// non-batched variant).
void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      // ks indirection pointers per output row -> tile starts at row mr_block_start.
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      // Byte offset applied to each indirection pointer.
      context->a_offset,
      context->zero,
      &context->params);
}
1413 | #endif // XNN_MAX_UARCH_TYPES > 1 |
1414 | |
1415 | |
// Runs a previously set-up operator on the given thread pool. Delegates to
// xnn_run_operator_with_index with default (0, 0) indices, which are used
// only for log messages.
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  return xnn_run_operator_with_index(op, 0, 0, threadpool);
}
1420 | |
1421 | enum xnn_status xnn_run_operator_with_index( |
1422 | xnn_operator_t op, |
1423 | size_t opdata_index, |
1424 | size_t operator_object_index, |
1425 | pthreadpool_t threadpool) |
1426 | { |
1427 | switch (op->state) { |
1428 | case xnn_run_state_invalid: |
1429 | xnn_log_error("failed to run operator: operator was not successfully setup" ); |
1430 | return xnn_status_invalid_state; |
1431 | case xnn_run_state_ready: |
1432 | xnn_log_debug("running operator %zu:%zu (%s %s)" , opdata_index, |
1433 | operator_object_index, |
1434 | xnn_operator_type_to_string(op->type), |
1435 | xnn_microkernel_type_to_string(op->ukernel.type)); |
1436 | break; |
1437 | case xnn_run_state_skip: |
1438 | xnn_log_debug("skip running operator %zu:%zu (%s %s)" , opdata_index, |
1439 | operator_object_index, |
1440 | xnn_operator_type_to_string(op->type), |
1441 | xnn_microkernel_type_to_string(op->ukernel.type)); |
1442 | return xnn_status_success; |
1443 | } |
1444 | |
1445 | uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS; |
1446 | if (op->flags & XNN_FLAG_YIELD_WORKERS) { |
1447 | flags |= PTHREADPOOL_FLAG_YIELD_WORKERS; |
1448 | } |
1449 | switch (op->compute.type) { |
1450 | case xnn_parallelization_type_invalid: |
1451 | break; |
1452 | case xnn_parallelization_type_1d: |
1453 | assert(op->compute.range[0] != 0); |
1454 | pthreadpool_parallelize_1d( |
1455 | threadpool, |
1456 | op->compute.task_1d, |
1457 | &op->context, |
1458 | op->compute.range[0], |
1459 | flags); |
1460 | break; |
1461 | case xnn_parallelization_type_1d_tile_1d: |
1462 | assert(op->compute.range[0] != 0); |
1463 | assert(op->compute.tile[0] != 0); |
1464 | pthreadpool_parallelize_1d_tile_1d( |
1465 | threadpool, |
1466 | op->compute.task_1d_tile_1d, |
1467 | &op->context, |
1468 | op->compute.range[0], |
1469 | op->compute.tile[0], |
1470 | flags); |
1471 | break; |
1472 | case xnn_parallelization_type_2d: |
1473 | assert(op->compute.range[0] != 0); |
1474 | assert(op->compute.range[1] != 0); |
1475 | pthreadpool_parallelize_2d( |
1476 | threadpool, |
1477 | op->compute.task_2d, |
1478 | &op->context, |
1479 | op->compute.range[0], op->compute.range[1], |
1480 | flags); |
1481 | break; |
1482 | case xnn_parallelization_type_2d_tile_1d: |
1483 | assert(op->compute.range[0] != 0); |
1484 | assert(op->compute.range[1] != 0); |
1485 | assert(op->compute.tile[0] != 0); |
1486 | pthreadpool_parallelize_2d_tile_1d( |
1487 | threadpool, |
1488 | op->compute.task_2d_tile_1d, |
1489 | &op->context, |
1490 | op->compute.range[0], op->compute.range[1], |
1491 | op->compute.tile[0], |
1492 | flags); |
1493 | break; |
1494 | case xnn_parallelization_type_2d_tile_2d: |
1495 | assert(op->compute.range[0] != 0); |
1496 | assert(op->compute.range[1] != 0); |
1497 | assert(op->compute.tile[0] != 0); |
1498 | assert(op->compute.tile[1] != 0); |
1499 | pthreadpool_parallelize_2d_tile_2d( |
1500 | threadpool, |
1501 | op->compute.task_2d_tile_2d, |
1502 | &op->context, |
1503 | op->compute.range[0], op->compute.range[1], |
1504 | op->compute.tile[0], op->compute.tile[1], |
1505 | flags); |
1506 | break; |
1507 | case xnn_parallelization_type_3d: |
1508 | assert(op->compute.range[0] != 0); |
1509 | assert(op->compute.range[1] != 0); |
1510 | assert(op->compute.range[2] != 0); |
1511 | pthreadpool_parallelize_3d( |
1512 | threadpool, |
1513 | op->compute.task_3d, |
1514 | &op->context, |
1515 | op->compute.range[0], op->compute.range[1], op->compute.range[2], |
1516 | flags); |
1517 | break; |
1518 | case xnn_parallelization_type_3d_tile_2d: |
1519 | assert(op->compute.range[0] != 0); |
1520 | assert(op->compute.range[1] != 0); |
1521 | assert(op->compute.range[2] != 0); |
1522 | assert(op->compute.tile[0] != 0); |
1523 | assert(op->compute.tile[1] != 0); |
1524 | pthreadpool_parallelize_3d_tile_2d( |
1525 | threadpool, |
1526 | op->compute.task_3d_tile_2d, |
1527 | &op->context, |
1528 | op->compute.range[0], op->compute.range[1], op->compute.range[2], |
1529 | op->compute.tile[0], op->compute.tile[1], |
1530 | flags); |
1531 | break; |
1532 | case xnn_parallelization_type_4d: |
1533 | assert(op->compute.range[0] != 0); |
1534 | assert(op->compute.range[1] != 0); |
1535 | assert(op->compute.range[2] != 0); |
1536 | assert(op->compute.range[3] != 0); |
1537 | pthreadpool_parallelize_4d( |
1538 | threadpool, |
1539 | op->compute.task_4d, |
1540 | &op->context, |
1541 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], |
1542 | flags); |
1543 | break; |
1544 | case xnn_parallelization_type_4d_tile_2d: |
1545 | assert(op->compute.range[0] != 0); |
1546 | assert(op->compute.range[1] != 0); |
1547 | assert(op->compute.range[2] != 0); |
1548 | assert(op->compute.range[3] != 0); |
1549 | assert(op->compute.tile[0] != 0); |
1550 | assert(op->compute.tile[1] != 0); |
1551 | pthreadpool_parallelize_4d_tile_2d( |
1552 | threadpool, |
1553 | op->compute.task_4d_tile_2d, |
1554 | &op->context, |
1555 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], |
1556 | op->compute.tile[0], op->compute.tile[1], |
1557 | flags); |
1558 | break; |
1559 | case xnn_parallelization_type_5d: |
1560 | assert(op->compute.range[0] != 0); |
1561 | assert(op->compute.range[1] != 0); |
1562 | assert(op->compute.range[2] != 0); |
1563 | assert(op->compute.range[3] != 0); |
1564 | assert(op->compute.range[4] != 0); |
1565 | pthreadpool_parallelize_5d( |
1566 | threadpool, |
1567 | op->compute.task_5d, |
1568 | &op->context, |
1569 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], |
1570 | flags); |
1571 | break; |
1572 | case xnn_parallelization_type_5d_tile_2d: |
1573 | assert(op->compute.range[0] != 0); |
1574 | assert(op->compute.range[1] != 0); |
1575 | assert(op->compute.range[2] != 0); |
1576 | assert(op->compute.range[3] != 0); |
1577 | assert(op->compute.range[4] != 0); |
1578 | assert(op->compute.tile[0] != 0); |
1579 | assert(op->compute.tile[1] != 0); |
1580 | pthreadpool_parallelize_5d_tile_2d( |
1581 | threadpool, |
1582 | op->compute.task_5d_tile_2d, |
1583 | &op->context, |
1584 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], |
1585 | op->compute.tile[0], op->compute.tile[1], |
1586 | flags); |
1587 | break; |
1588 | case xnn_parallelization_type_6d_tile_2d: |
1589 | assert(op->compute.range[0] != 0); |
1590 | assert(op->compute.range[1] != 0); |
1591 | assert(op->compute.range[2] != 0); |
1592 | assert(op->compute.range[3] != 0); |
1593 | assert(op->compute.range[4] != 0); |
1594 | assert(op->compute.range[5] != 0); |
1595 | assert(op->compute.tile[0] != 0); |
1596 | assert(op->compute.tile[1] != 0); |
1597 | pthreadpool_parallelize_6d_tile_2d( |
1598 | threadpool, |
1599 | op->compute.task_6d_tile_2d, |
1600 | &op->context, |
1601 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5], |
1602 | op->compute.tile[0], op->compute.tile[1], |
1603 | flags); |
1604 | break; |
1605 | #if XNN_MAX_UARCH_TYPES > 1 |
1606 | case xnn_parallelization_type_2d_tile_2d_with_uarch: |
1607 | assert(op->compute.range[0] != 0); |
1608 | assert(op->compute.range[1] != 0); |
1609 | assert(op->compute.tile[0] != 0); |
1610 | assert(op->compute.tile[1] != 0); |
1611 | pthreadpool_parallelize_2d_tile_2d_with_uarch( |
1612 | threadpool, |
1613 | op->compute.task_2d_tile_2d_with_id, |
1614 | &op->context, |
1615 | 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
1616 | op->compute.range[0], op->compute.range[1], |
1617 | op->compute.tile[0], op->compute.tile[1], |
1618 | flags); |
1619 | break; |
1620 | case xnn_parallelization_type_3d_tile_2d_with_uarch: |
1621 | assert(op->compute.range[0] != 0); |
1622 | assert(op->compute.range[1] != 0); |
1623 | assert(op->compute.range[2] != 0); |
1624 | assert(op->compute.tile[0] != 0); |
1625 | assert(op->compute.tile[1] != 0); |
1626 | pthreadpool_parallelize_3d_tile_2d_with_uarch( |
1627 | threadpool, |
1628 | op->compute.task_3d_tile_2d_with_id, |
1629 | &op->context, |
1630 | 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
1631 | op->compute.range[0], op->compute.range[1], op->compute.range[2], |
1632 | op->compute.tile[0], op->compute.tile[1], |
1633 | flags); |
1634 | break; |
1635 | case xnn_parallelization_type_4d_tile_2d_with_uarch: |
1636 | assert(op->compute.range[0] != 0); |
1637 | assert(op->compute.range[1] != 0); |
1638 | assert(op->compute.range[2] != 0); |
1639 | assert(op->compute.range[3] != 0); |
1640 | assert(op->compute.tile[0] != 0); |
1641 | assert(op->compute.tile[1] != 0); |
1642 | pthreadpool_parallelize_4d_tile_2d_with_uarch( |
1643 | threadpool, |
1644 | op->compute.task_4d_tile_2d_with_id, |
1645 | &op->context, |
1646 | 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
1647 | op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], |
1648 | op->compute.tile[0], op->compute.tile[1], |
1649 | flags); |
1650 | break; |
1651 | #endif // XNN_MAX_UARCH_TYPES > 1 |
1652 | default: |
1653 | XNN_UNREACHABLE; |
1654 | } |
1655 | return xnn_status_success; |
1656 | } |
1657 | |