// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/node-type.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


#ifndef XNN_ENABLE_SPARSE
  #error "XNN_ENABLE_SPARSE not defined"
#endif

enum xnn_status xnn_create_subgraph(
  uint32_t external_value_ids,
  uint32_t flags,
  xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values",
      (size_t) external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}
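
// Illustrative usage sketch (comment only, not part of the library): the typical
// call sequence from the public API is to initialize XNNPACK once, create a
// subgraph sized for the external tensors the caller plans to define, and delete
// it when done. Error handling is abbreviated.
//
//   enum xnn_status status = xnn_initialize(/*allocator=*/NULL);
//   assert(status == xnn_status_success);
//
//   xnn_subgraph_t subgraph = NULL;
//   status = xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
//   assert(status == xnn_status_success);
//
//   // ... define Values and Nodes, create a runtime, run inference ...
//
//   xnn_delete_subgraph(subgraph);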


struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}
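
// Growth policy, worked through (comment only): new_capacity adds
// max(min(capacity, 512), 64) slots on top of the current capacity, i.e. the
// array roughly doubles while it is small and grows by a fixed 512 slots once
// it is large:
//
//   capacity    0 -> new_capacity   64
//   capacity   64 -> new_capacity  128
//   capacity  200 -> new_capacity  400
//   capacity 1024 -> new_capacity 1536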

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  memset(value, 0, sizeof(struct xnn_value));
}

void xnn_value_copy(
  struct xnn_value* dst_value,
  const struct xnn_value* src_value)
{
  // Note: Value ID stays unchanged

  dst_value->type = src_value->type;
  dst_value->datatype = src_value->datatype;
  dst_value->quantization = src_value->quantization;
  dst_value->shape = src_value->shape;
  dst_value->flags = src_value->flags;
  dst_value->data = src_value->data;
  dst_value->producer = src_value->producer;
  dst_value->first_consumer = src_value->first_consumer;
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + num_nodes) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64));
    assert(new_capacity >= size + num_nodes);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + num_nodes;
  struct xnn_node* new_nodes = nodes + size;
  for (size_t i = 0; i < num_nodes; i++) {
    new_nodes[i].id = size + i;
  }
}

void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyze Nodes' inputs and outputs and update Values' producer/consumer fields.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      // Persistent values can be produced by multiple nodes, e.g. copy nodes writing to the same persistent value.
      assert(xnn_value_is_persistent(&subgraph->values[output_id]) ||
             subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (xnn_value_is_external_output(value)) {
      value->num_consumers += 1;
    }
  }
}
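
// Worked example (comment only): for a two-Node chain
//
//   [external input v0] --> (Node 0) --v1--> (Node 1) --> [external output v2]
//
// this pass leaves the Values annotated as follows:
//
//   v0: producer = XNN_INVALID_NODE_ID, first_consumer = 0, num_consumers = 1
//   v1: producer = 0,                   first_consumer = 1, num_consumers = 1
//   v2: producer = 1,                   first_consumer = XNN_INVALID_NODE_ID,
//       num_consumers = 1 (the extra consumer counted for external outputs)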

#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8
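
// These flags form a bitmask, so a single Node may report several of them at
// once. For example, xnn_check_nchw_compatibility below returns
// XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC for
// Global Average Pooling, and callers test individual bits with expressions like
//
//   (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) != 0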

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  if (node->compute_type != xnn_compute_type_fp16 && node->compute_type != xnn_compute_type_fp32) {
    if (node->type != xnn_node_type_invalid) {
      xnn_log_info(
        "Node %s compute type %d is incompatible with sparse inference",
        xnn_node_type_to_string(node->type), node->compute_type);
    }
    return 0;
  }

  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        xnn_log_info("Node %s groups (%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.convolution_2d.groups);
        return 0;
      }
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width);
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0) {
          xnn_log_info("Node %s (1x1 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.input_padding_top,
            node->params.convolution_2d.input_padding_right,
            node->params.convolution_2d.input_padding_bottom,
            node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          xnn_log_info("Node %s (1x1 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.subsampling_height,
            node->params.convolution_2d.subsampling_width);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1) {
          xnn_log_info("Node %s (3x3 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.input_padding_top,
            node->params.convolution_2d.input_padding_right,
            node->params.convolution_2d.input_padding_bottom,
            node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          xnn_log_info("Node %s (3x3 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.subsampling_height,
            node->params.convolution_2d.subsampling_width);
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          xnn_log_info("Node %s (3x3 kernel) input channels (%zu) "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.group_input_channels);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.dilation_height,
          node->params.depthwise_convolution_2d.dilation_width);
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        xnn_log_info("Node %s flags (%" PRIu32 ") has padding incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->flags);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        xnn_log_info("Node %s depth_multiplier (%" PRIu32 ") is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.depth_multiplier);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        xnn_log_info("Node %s subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.subsampling_height,
          node->params.depthwise_convolution_2d.subsampling_width);
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          xnn_log_info("Node %s subsampling_height (%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.depthwise_convolution_2d.subsampling_height);
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        xnn_log_info("Node %s kernel (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.kernel_height,
          node->params.depthwise_convolution_2d.kernel_width);
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          if (node->params.depthwise_convolution_2d.input_padding_top == 1 &&
              node->params.depthwise_convolution_2d.input_padding_right == 1 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
              node->params.depthwise_convolution_2d.input_padding_left == 1) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (3x3 kernel) padding "
              "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
              "is incompatible with sparse inference",
              xnn_node_type_to_string(node->type),
              node->params.depthwise_convolution_2d.input_padding_top,
              node->params.depthwise_convolution_2d.input_padding_right,
              node->params.depthwise_convolution_2d.input_padding_bottom,
              node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        case 5:
          if (node->params.depthwise_convolution_2d.input_padding_top == 2 &&
              node->params.depthwise_convolution_2d.input_padding_right == 2 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
              node->params.depthwise_convolution_2d.input_padding_left == 2) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (5x5 kernel) padding "
              "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
              "is incompatible with sparse inference",
              xnn_node_type_to_string(node->type),
              node->params.depthwise_convolution_2d.input_padding_top,
              node->params.depthwise_convolution_2d.input_padding_right,
              node->params.depthwise_convolution_2d.input_padding_bottom,
              node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      if (subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
          subgraph->values[node->inputs[0]].shape.dim[2] > 1) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims == 4) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }
    default:
      return 0;
  }
}
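
// Examples (comment only) of how the checks above classify common Nodes:
//
//   Convolution 2D, 1x1 kernel, stride 1, no padding, groups = 1
//     -> XNN_LAYOUT_FLAG_COMPATIBLE_NCHW
//   Convolution 2D, 3x3 kernel, stride 2, padding 1, 3 input channels
//     -> XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW (typical first layer of an image model)
//   Depthwise Convolution 2D, 3x3 kernel, stride 1, padding 1, depth multiplier 1
//     -> XNN_LAYOUT_FLAG_COMPATIBLE_NCHW
//   Global Average Pooling 2D
//     -> XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC
//   Convolution 2D, 3x3 kernel, stride 1, 16 input channels
//     -> 0 (not sparse-compatible, so it breaks any surrounding cluster)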

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and propagate them as cluster
  // leaders to their producer Nodes.
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewrite cannot happen.
  if (!update) {
    return;
  }
  // Propagate the cluster leaders to the other Nodes in the graph until no more
  // Nodes in any cluster are updated.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that all Values consumed by an NCHW-compatible cluster don't have NCHW-incompatible consumers
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with less than 2/3rd of zeroes in 1x1 Convolution filters
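
// Worked example of the profitability threshold below (comment only): a cluster
// stays sparse only when num_zeroes * 3 > num_params * 2, i.e. strictly more than
// 2/3 of its 1x1 Convolution weights are zero. For a cluster whose 1x1
// Convolutions hold 1000 weights in total:
//
//   600 zeroes: 600 * 3 = 1800 <= 1000 * 2 = 2000 -> sparse rewriting disabled
//   700 zeroes: 700 * 3 = 2100 >  1000 * 2 = 2000 -> sparse rewriting allowed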
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%" PRIu32 " layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}
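
// Illustrative shape of a rewritten cluster (comment only): in a typical image
// model the pass produces
//
//   input (NHWC) -> Conv 3x3 s2 (NHWC2NCHW) -> 1x1 Convs / DW Convs (NCHW) -> ...
//     ... -> Global Average Pooling (NCHW2NHWC) -> output (NHWC)
//
// The leading 3x3 stride-2 Convolution consumes NHWC data and produces NCHW data,
// the interior of the cluster stays in NCHW (where the sparse 1x1 kernels run),
// and the Global Average Pooling Node converts back to NHWC, so no explicit
// layout-conversion Nodes need to be inserted.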

bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
{
  xnn_log_info("Analyzing subgraph for FP16 compatibility");

  // Convert tensors and operators in the subgraph to FP16:
  // 1. Check that all operators in the subgraph are supported in FP16.
  // 2. Mark Values that must be converted to FP16.
  // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs.
  // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs.

  // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    if (node->compute_type != xnn_compute_type_fp32) {
      xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
      return false;
    }
    switch (node->type) {
      case xnn_node_type_abs:
      case xnn_node_type_add2:
      case xnn_node_type_divide:
      case xnn_node_type_maximum2:
      case xnn_node_type_minimum2:
      case xnn_node_type_multiply2:
      case xnn_node_type_concatenate2:
      case xnn_node_type_concatenate3:
      case xnn_node_type_concatenate4:
      case xnn_node_type_squared_difference:
      case xnn_node_type_subtract:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          if (subgraph->values[node->inputs[i]].data != NULL) {
            xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %" PRIu32,
              n, xnn_node_type_to_string(node->type), i);
            return false;
          }
        }
        break;
      case xnn_node_type_average_pooling_2d:
      case xnn_node_type_bankers_rounding:
      case xnn_node_type_ceiling:
      case xnn_node_type_clamp:
      case xnn_node_type_copy:
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_depth_to_space:
      case xnn_node_type_elu:
      case xnn_node_type_even_split2:
      case xnn_node_type_even_split3:
      case xnn_node_type_even_split4:
      case xnn_node_type_floor:
      case xnn_node_type_fully_connected:
      case xnn_node_type_global_average_pooling_2d:
      case xnn_node_type_hardswish:
      case xnn_node_type_leaky_relu:
      case xnn_node_type_max_pooling_2d:
      case xnn_node_type_negate:
      case xnn_node_type_prelu:
      case xnn_node_type_sigmoid:
      case xnn_node_type_softmax:
      case xnn_node_type_static_constant_pad:
      case xnn_node_type_static_reshape:
      case xnn_node_type_static_resize_bilinear_2d:
      case xnn_node_type_static_transpose:
      case xnn_node_type_square:
      case xnn_node_type_square_root:
        break;
      default:
        xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
          n, xnn_node_type_to_string(node->type));
        return false;
    }
  }

  // Annotate Values to be converted to FP16 as FP16-compatible.
  // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32;
  // they will be converted to FP16 during weight repacking when the operator is created.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    switch (node->type) {
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_fully_connected:
      case xnn_node_type_prelu:
        subgraph->values[node->inputs[0]].fp16_compatible = true;
        subgraph->values[node->outputs[0]].fp16_compatible = true;
        break;
      default:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          subgraph->values[node->inputs[i]].fp16_compatible = true;
        }
        for (uint32_t o = 0; o < node->num_outputs; o++) {
          subgraph->values[node->outputs[o]].fp16_compatible = true;
        }
        break;
    }
  }

  // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values.
  // FP32 Values that are not external inputs or outputs are converted to FP16 in-place;
  // for external inputs and outputs we create same-shaped FP16 Values and use those instead.
  const uint32_t num_original_values = subgraph->num_values;
  xnn_subgraph_analyze_consumers_and_producers(subgraph);
  for (uint32_t n = 0; n < num_original_values; n++) {
    struct xnn_value* value = &subgraph->values[n];
    value->fp16_id = XNN_INVALID_VALUE_ID;
    value->fp32_id = XNN_INVALID_VALUE_ID;
    if (value->fp16_compatible) {
      assert(value->data == NULL);
      assert(value->datatype == xnn_datatype_fp32);
      if (xnn_value_is_external(value)) {
        struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph);

        // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value
        value = &subgraph->values[n];
        xnn_value_copy(fp16_value, value);
        fp16_value->datatype = xnn_datatype_fp16;

        fp16_value->producer = value->producer;
        fp16_value->num_consumers = value->num_consumers;
        fp16_value->first_consumer = value->first_consumer;
        value->producer = XNN_INVALID_NODE_ID;
        value->num_consumers = 0;
        value->first_consumer = XNN_INVALID_NODE_ID;

        // Clear external input/output flags
        fp16_value->flags = 0;
        xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);

        value->fp16_id = fp16_value->id;
        fp16_value->fp32_id = n;
      } else {
        xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
        value->datatype = xnn_datatype_fp16;
      }
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    assert(node->compute_type == xnn_compute_type_fp32);
    node->compute_type = xnn_compute_type_fp16;
    if (node->type == xnn_node_type_static_constant_pad) {
      node->params.static_pad.padding_value =
        fp16_ieee_from_fp32_value(uint32_as_float(node->params.static_pad.padding_value));
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]);
        node->inputs[i] = fp16_id;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]);
        node->outputs[o] = fp16_id;
      }
    }
  }

  // Count the number of external inputs and outputs which require Convert nodes
  uint32_t num_external_inputs = 0;
  uint32_t num_external_outputs = 0;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    const struct xnn_node* node = &subgraph->nodes[n];
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) {
        assert(value->data == NULL);
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        // This value isn't always an external input: it could be an external output of the current subgraph (due to
        // partitioning) that is simultaneously consumed by the current node.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          num_external_inputs += 1;
        }
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        assert(xnn_value_is_external_output(&subgraph->values[value->fp32_id]));
        num_external_outputs += 1;
      }
    }
  }
  xnn_log_debug("Discovered %" PRIu32 " external inputs and %" PRIu32 " external outputs",
    num_external_inputs, num_external_outputs);

  const uint32_t num_original_nodes = subgraph->num_nodes;
  xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs);
  struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1;
  for (uint32_t n = num_original_nodes; n != 0; n--) {
    const struct xnn_node* node = &subgraph->nodes[n - 1];
    // Insert Convert nodes for outputs
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%" PRIu32 " to tensor #%" PRIu32,
          value->id, value->fp32_id);
        const uint32_t output_node_id = output_node->id;
        assert(output_node >= subgraph->nodes);
        xnn_node_clear(output_node);
        output_node->id = output_node_id;
        xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */);
        output_node -= 1;
      }
    }
    // Move the Node to the new location
    if (output_node != node) {
      const uint32_t output_node_id = output_node->id;
      assert(output_node >= subgraph->nodes);
      memcpy(output_node, node, sizeof(struct xnn_node));
      output_node->id = output_node_id;
      output_node -= 1;
    }
    // Insert Convert nodes for inputs
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) {
        // Only insert Convert nodes if the value actually is an external input. This value could be an external
        // output; if that's the case, we have already inserted a Convert node in the loop above for outputs.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%" PRIu32 " to tensor #%" PRIu32,
            value->fp32_id, value->id);
          const uint32_t output_node_id = output_node->id;
          assert(output_node >= subgraph->nodes);
          xnn_node_clear(output_node);
          output_node->id = output_node_id;
          xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */);
          output_node -= 1;
        }
      }
    }
  }

  return true;
}
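
// Illustrative usage sketch (comment only): callers do not invoke this rewrite
// directly. They pass a flag when creating the runtime and xnn_subgraph_optimize
// below applies the rewrite when the hardware and the subgraph allow it. Assuming
// a subgraph built as in the sketch after xnn_create_subgraph:
//
//   xnn_runtime_t runtime = NULL;
//   enum xnn_status status = xnn_create_runtime_v2(
//     subgraph, /*threadpool=*/NULL, XNN_FLAG_HINT_FP16_INFERENCE, &runtime);
//
// With XNN_FLAG_FORCE_FP16_INFERENCE instead of the hint, runtime creation fails
// when the subgraph cannot be rewritten for FP16 (see xnn_subgraph_optimize).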

static void xnn_node_replace_output(struct xnn_node* node, uint32_t old_output_id, uint32_t new_output_id)
{
  for (size_t i = 0; i < node->num_outputs; i++) {
    if (node->outputs[i] == old_output_id) {
      node->outputs[i] = new_output_id;
    }
  }
}

enum xnn_status xnn_subgraph_fusion(
  xnn_subgraph_t subgraph)
{
  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%" PRIu32 " into upstream Node #%" PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad Node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        const bool is_spatial_2d_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0;
        const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype;
        const uint32_t padding_value = producer->params.static_pad.padding_value;
        const bool is_zero_padding =
          (padding_datatype == xnn_datatype_fp32 && padding_value == 0) ||
          ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) &&
           padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point);
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%" PRIu32 " into Convolution 2D Node #%" PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%" PRIu32 " into Depthwise Convolution 2D Node #%" PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }

      // Try to fuse Copy upstream. Copy can be fused upstream as long as this value is internal.
      // E.g. ---> (N1) --- value ---> (Copy) ---> v1
      // If value is persistent or external, fusing the Copy upstream into N1 would skip the write to value: N1 would
      // write to v1 instead, which is wrong.
      if (consumer->type == xnn_node_type_copy && xnn_value_is_valid(value) && xnn_value_is_internal(value)) {
        xnn_log_info(
          "value %" PRIu32 ": fuse Copy Node #%" PRIu32 " into upstream %s Node #%" PRIu32, value->id, consumer->id,
          xnn_node_type_to_string(producer->type), producer->id);
        assert(consumer->num_inputs == 1);
        assert(consumer->num_outputs == 1);
        const uint32_t fused_output_id = consumer->outputs[0];
        assert(fused_output_id < subgraph->num_values);
        subgraph->values[fused_output_id].producer = producer_id;
        xnn_node_replace_output(producer, value->id, fused_output_id);
        xnn_node_clear(consumer);
        xnn_value_clear(value);
      }

      // Try to fuse Copy downstream.
      // E.g. --- v1 ---> (Copy) --- value ---> (N2)
      // If value is external or persistent, we cannot simply remove the Copy, since we need to write to value.
      if (producer->type == xnn_node_type_copy && xnn_value_is_valid(value) && xnn_value_is_internal(value)) {
        // We need to check that value is valid here because value could have been cleared by a previous optimization;
        // this can happen if we have a chain of Copy(s), e.g.:
        // ---v1--> (Copy1) ---v2--> (Copy2) ---v3--> (Copy3) ---v4-->
        // v2 could have been cleared when we fused Copy2 upstream into Copy1, so v2 isn't valid anymore, but since
        // v2's producer is also a Copy, we would incorrectly try to fuse Copy1 downstream into Copy2 (again).
        xnn_log_info(
          "value %" PRIu32 ": fuse Copy Node #%" PRIu32 " into downstream %s Node #%" PRIu32, value->id, producer->id,
          xnn_node_type_to_string(consumer->type), consumer->id);
        assert(producer->num_outputs == 1);
        assert(producer->num_inputs == 1);
        const uint32_t copy_input_id = producer->inputs[0];
        const uint32_t copy_output_id = producer->outputs[0];
        bool found_consumer_input = false;
        for (size_t i = 0; i < consumer->num_inputs; i++) {
          if (consumer->inputs[i] == copy_output_id) {
            consumer->inputs[i] = copy_input_id;
            found_consumer_input = true;
            // TODO(b/254734644): A consumer can only consume this value once, since we asserted earlier that value
            // has only 1 consumer, so we can break here as there will be no other consumer inputs with the same id.
            break;
          }
        }
        (void) found_consumer_input;  // Silence unused variable warning in non-debug builds.
        assert(found_consumer_input);

        if (subgraph->values[copy_input_id].first_consumer == producer_id) {
          subgraph->values[copy_input_id].first_consumer = consumer_id;
        }
        xnn_node_clear(producer);
        xnn_value_clear(value);
      }
    }
  }

  return xnn_status_success;
}
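
// Worked example of the Clamp fusion above (comment only): given
//
//   (Convolution 2D, output_min = -inf, output_max = +inf) --v--> (Clamp [0, 6]) --> w
//
// where v has exactly one consumer, the Clamp Node is cleared, the Convolution
// takes over the Clamp's output tensor w, and its activation range becomes
//
//   output_min = max(-inf, 0) = 0,   output_max = min(+inf, 6) = 6,
//
// i.e. the Convolution now computes a fused Convolution + ReLU6 in a single Node.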

enum xnn_status xnn_subgraph_optimize(
  xnn_subgraph_t subgraph,
  uint32_t flags)
{
  xnn_subgraph_analyze_consumers_and_producers(subgraph);

  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (!xnn_value_is_external_input(value) && value->num_consumers == 0 && !xnn_value_is_persistent(value)) {
      xnn_value_clear(value);
    }
  }

  if (!(flags & XNN_FLAG_NO_OPERATOR_FUSION)) {
    xnn_subgraph_fusion(subgraph);
  }

  if ((flags & XNN_FLAG_FORCE_FP16_INFERENCE) && !(xnn_params.init_flags & XNN_INIT_FLAG_F16)) {
    xnn_log_error("failed to force FP16 inference: hardware supports neither native nor emulated FP16 operators");
    return xnn_status_unsupported_hardware;
  }
  #ifndef XNN_NO_F16_OPERATORS
    const bool try_native_fp16 =
      (flags & XNN_FLAG_HINT_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16_NATIVE);
    const bool force_fp16 = (flags & XNN_FLAG_FORCE_FP16_INFERENCE);
    if (try_native_fp16 || force_fp16) {
      const bool fp16_rewrite_succeeded = xnn_subgraph_rewrite_for_fp16(subgraph);
      if (force_fp16 && !fp16_rewrite_succeeded) {
        xnn_log_error("failed to force FP16 inference: subgraph is incompatible with FP16 operators");
        return xnn_status_unsupported_parameter;
      }
    }
  #endif  // XNN_NO_F16_OPERATORS

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  return xnn_status_success;
}
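
// Pass-ordering sketch (comment only): xnn_subgraph_optimize above is invoked by
// the runtime-creation path with the caller's flags, e.g.
//
//   xnn_subgraph_optimize(subgraph, XNN_FLAG_HINT_SPARSE_INFERENCE);
//
// and, in order, (1) refreshes producer/consumer links, (2) clears unreferenced
// Values, (3) runs Node fusion unless XNN_FLAG_NO_OPERATOR_FUSION is set,
// (4) attempts the FP16 rewrite when it is hinted or forced and the hardware
// supports FP16, and (5) attempts the NCHW rewrite when sparse inference is
// hinted and the CHW microkernels are available.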

enum xnn_status xnn_delete_subgraph(
  xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    if (subgraph->nodes != NULL) {
      memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
      xnn_release_memory(subgraph->nodes);
    }

    if (subgraph->values != NULL) {
      memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
      xnn_release_memory(subgraph->values);
    }

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}