// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


#ifndef XNN_ENABLE_SPARSE
  #error "XNN_ENABLE_SPARSE not defined"
#endif

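// A minimal usage sketch (illustrative only, not part of this file): XNNPACK must be
// initialized before a subgraph is created, and the subgraph is released with
// xnn_delete_subgraph once it is no longer needed.
//
//   xnn_initialize(NULL /* allocator */);
//   xnn_subgraph_t subgraph = NULL;
//   xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
//   // ... define Values and Nodes, create and run a runtime ...
//   xnn_delete_subgraph(subgraph);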
enum xnn_status xnn_create_subgraph(
    uint32_t external_value_ids,
    uint32_t flags,
    xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values",
      (size_t) external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}


struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
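    // Growth policy (sketch of the formula below): double the capacity, but grow by at
    // least 64 and by at most 512 Values per reallocation, e.g. 0 -> 64, 64 -> 128,
    // 512 -> 1024, 1024 -> 1536.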
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  memset(value, 0, sizeof(struct xnn_value));
}

void xnn_value_copy(
    struct xnn_value* dst_value,
    const struct xnn_value* src_value)
{
  // Note: Value ID stays unchanged

  dst_value->type = src_value->type;
  dst_value->datatype = src_value->datatype;
  dst_value->quantization = src_value->quantization;
  dst_value->shape = src_value->shape;
  dst_value->flags = src_value->flags;
  dst_value->data = src_value->data;
  dst_value->producer = src_value->producer;
  dst_value->first_consumer = src_value->first_consumer;
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + num_nodes) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64));
    assert(new_capacity >= size + num_nodes);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + num_nodes;
  struct xnn_node* new_nodes = nodes + size;
  for (size_t i = 0; i < num_nodes; i++) {
    new_nodes[i].id = size + i;
  }
}

void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyze Nodes' inputs and outputs and update Values' producer/consumer fields
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      // Persistent values can be produced by multiple nodes, e.g. copy nodes writing to the same persistent value.
      assert(xnn_value_is_persistent(&subgraph->values[output_id]) ||
             subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (xnn_value_is_external_output(value)) {
      value->num_consumers += 1;
    }
  }
}

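// Per-Node layout flags assigned by xnn_check_nchw_compatibility (summary of how they are
// used by xnn_subgraph_rewrite_for_nchw below):
// - COMPATIBLE_NCHW: the Node can operate directly on NCHW tensors.
// - COMPATIBLE_NHWC2NCHW: the Node consumes NHWC inputs and can produce NCHW outputs
//   (e.g. a 3x3 stride-2 Convolution with 3 input channels).
// - COMPATIBLE_NCHW2NHWC: the Node consumes NCHW inputs and produces NHWC outputs
//   (e.g. Depth-to-Space, Global Average Pooling).
// - INCOMPATIBLE_CLUSTER: the Node belongs to a cluster that cannot be rewritten to NCHW.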
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  if (node->compute_type != xnn_compute_type_fp32) {
    if (node->type != xnn_node_type_invalid) {
      xnn_log_info(
        "Node %s compute type %d is incompatible with sparse inference",
        xnn_node_type_to_string(node->type), node->compute_type);
    }
    return 0;
  }

  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        xnn_log_info("Node %s groups (%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.convolution_2d.groups);
        return 0;
      }
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width);
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0) {
          xnn_log_info("Node %s (1x1 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.input_padding_top,
            node->params.convolution_2d.input_padding_right,
            node->params.convolution_2d.input_padding_bottom,
            node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          xnn_log_info("Node %s (1x1 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.subsampling_height,
            node->params.convolution_2d.subsampling_width);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1) {
          xnn_log_info("Node %s (3x3 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.input_padding_top,
            node->params.convolution_2d.input_padding_right,
            node->params.convolution_2d.input_padding_bottom,
            node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          xnn_log_info("Node %s (3x3 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.subsampling_height,
            node->params.convolution_2d.subsampling_width);
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          xnn_log_info("Node %s (3x3 kernel) input channels (%zu) "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.convolution_2d.group_input_channels);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.dilation_height,
          node->params.depthwise_convolution_2d.dilation_width);
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        xnn_log_info("Node %s flags (%" PRIu32 ") has padding incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->flags);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        xnn_log_info("Node %s depth_multiplier (%" PRIu32 ") is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.depth_multiplier);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        xnn_log_info("Node %s subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.subsampling_height,
          node->params.depthwise_convolution_2d.subsampling_width);
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          xnn_log_info("Node %s subsampling_height (%" PRIu32 ") "
            "is incompatible with sparse inference",
            xnn_node_type_to_string(node->type),
            node->params.depthwise_convolution_2d.subsampling_height);
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        xnn_log_info("Node %s kernel (height=%" PRIu32 ", width=%" PRIu32 ") "
          "is incompatible with sparse inference",
          xnn_node_type_to_string(node->type),
          node->params.depthwise_convolution_2d.kernel_height,
          node->params.depthwise_convolution_2d.kernel_width);
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          if (node->params.depthwise_convolution_2d.input_padding_top == 1 &&
              node->params.depthwise_convolution_2d.input_padding_right == 1 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
              node->params.depthwise_convolution_2d.input_padding_left == 1) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (3x3 kernel) padding "
              "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
              "is incompatible with sparse inference",
              xnn_node_type_to_string(node->type),
              node->params.depthwise_convolution_2d.input_padding_top,
              node->params.depthwise_convolution_2d.input_padding_right,
              node->params.depthwise_convolution_2d.input_padding_bottom,
              node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        case 5:
          if (node->params.depthwise_convolution_2d.input_padding_top == 2 &&
              node->params.depthwise_convolution_2d.input_padding_right == 2 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
              node->params.depthwise_convolution_2d.input_padding_left == 2) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (5x5 kernel) padding "
              "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
              "is incompatible with sparse inference",
              xnn_node_type_to_string(node->type),
              node->params.depthwise_convolution_2d.input_padding_top,
              node->params.depthwise_convolution_2d.input_padding_right,
              node->params.depthwise_convolution_2d.input_padding_bottom,
              node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      if (subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
          subgraph->values[node->inputs[0]].shape.dim[2] > 1) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims == 4) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
          xnn_node_type_to_string(node->type));
        return 0;
      }
    default:
      return 0;
  }
}

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected-components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and propagate them as cluster leaders
  // to their producer Nodes.
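  // Sketch of the propagation: every Node starts as its own cluster leader; whenever a
  // compatible producer/consumer pair have different leaders, both adopt the larger of the
  // two leader ids, and the pass below is repeated until no leader changes.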
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewriting cannot take place,
  // so bail out early.
  if (!update) {
    return;
  }
  // Propagate the cluster leaders to the other Nodes in the graph until no more Nodes
  // are updated (i.e. a fixed point is reached).
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that all Values consumed by NCHW-compatible clusters don't have NCHW-incompatible consumers
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with 2/3 or fewer zero weights in the 1x1 Convolution filters
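  // Worked example of the threshold applied below: with num_params = 100, sparse inference
  // stays enabled only if num_zeroes * 3 > num_params * 2, i.e. at least 67 of the 100
  // weights are zero.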
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%" PRIu32 " layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}

bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
{
  xnn_log_info("Analyzing subgraph for FP16 compatibility");

  // Convert tensors and operators in the subgraph to FP16
  // 1. Check that all operators in the subgraph are supported in FP16.
  // 2. Indicate values that must be converted to FP16.
  // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs.
  // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs.
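  //
  // For illustration, after a successful rewrite a graph such as
  //   [external FP32 input] -> Conv2D -> HardSwish -> [external FP32 output]
  // becomes
  //   [external FP32 input] -> Convert(FP32->FP16) -> Conv2D -> HardSwish ->
  //   Convert(FP16->FP32) -> [external FP32 output]
  // with all intermediate tensors stored in FP16.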

  // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    if (node->compute_type != xnn_compute_type_fp32) {
      xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
      return false;
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      if (subgraph->values[node->inputs[i]].layout == xnn_layout_type_nchw) {
        xnn_log_warning(
          "FP16 rewrite aborted: input #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
          i, node->inputs[i], n, xnn_node_type_to_string(node->type));
        return false;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      if (subgraph->values[node->outputs[o]].layout == xnn_layout_type_nchw) {
        xnn_log_warning(
          "FP16 rewrite aborted: output #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
          o, node->outputs[o], n, xnn_node_type_to_string(node->type));
        return false;
      }
    }
    switch (node->type) {
      case xnn_node_type_abs:
      case xnn_node_type_add2:
      case xnn_node_type_divide:
      case xnn_node_type_maximum2:
      case xnn_node_type_minimum2:
      case xnn_node_type_multiply2:
      case xnn_node_type_concatenate2:
      case xnn_node_type_concatenate3:
      case xnn_node_type_concatenate4:
      case xnn_node_type_squared_difference:
      case xnn_node_type_subtract:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          if (subgraph->values[node->inputs[i]].data != NULL) {
            xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %" PRIu32,
              n, xnn_node_type_to_string(node->type), i);
            return false;
          }
        }
        break;
      case xnn_node_type_average_pooling_2d:
      case xnn_node_type_bankers_rounding:
      case xnn_node_type_ceiling:
      case xnn_node_type_clamp:
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_depth_to_space:
      case xnn_node_type_elu:
      case xnn_node_type_even_split2:
      case xnn_node_type_even_split3:
      case xnn_node_type_even_split4:
      case xnn_node_type_floor:
      case xnn_node_type_fully_connected:
      case xnn_node_type_global_average_pooling_2d:
      case xnn_node_type_hardswish:
      case xnn_node_type_leaky_relu:
      case xnn_node_type_max_pooling_2d:
      case xnn_node_type_negate:
      case xnn_node_type_prelu:
      case xnn_node_type_sigmoid:
      case xnn_node_type_softmax:
      case xnn_node_type_static_constant_pad:
      case xnn_node_type_static_reshape:
      case xnn_node_type_static_resize_bilinear_2d:
      case xnn_node_type_static_transpose:
      case xnn_node_type_square:
      case xnn_node_type_square_root:
        break;
      default:
        xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
          n, xnn_node_type_to_string(node->type));
        return false;
    }
  }

  // Annotate Values to be converted to FP16 as FP16-compatible.
  // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32;
  // they will be converted to FP16 during weight repacking when the operator is created.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    switch (node->type) {
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_fully_connected:
      case xnn_node_type_prelu:
        subgraph->values[node->inputs[0]].fp16_compatible = true;
        subgraph->values[node->outputs[0]].fp16_compatible = true;
        break;
      default:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          subgraph->values[node->inputs[i]].fp16_compatible = true;
        }
        for (uint32_t o = 0; o < node->num_outputs; o++) {
          subgraph->values[node->outputs[o]].fp16_compatible = true;
        }
        break;
    }
  }

  // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values.
  // FP32 Values that are not external inputs or outputs are converted to FP16 in place;
  // for external inputs and outputs we create same-shaped FP16 Values and use those instead.
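  // Cross-linking convention used below: for an external FP32 Value v, a new internal FP16
  // Value w is created with v.fp16_id = w.id and w.fp32_id = v.id; Nodes are then redirected
  // to w, and v is only touched by the Convert Nodes inserted at the end of this function.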
  const uint32_t num_original_values = subgraph->num_values;
  xnn_subgraph_analyze_consumers_and_producers(subgraph);
  for (uint32_t n = 0; n < num_original_values; n++) {
    struct xnn_value* value = &subgraph->values[n];
    value->fp16_id = XNN_INVALID_VALUE_ID;
    value->fp32_id = XNN_INVALID_VALUE_ID;
    if (value->fp16_compatible) {
      assert(value->data == NULL);
      assert(value->datatype == xnn_datatype_fp32);
      if (xnn_value_is_external(value)) {
        struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph);

        // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value
        value = &subgraph->values[n];
        xnn_value_copy(fp16_value, value);
        fp16_value->datatype = xnn_datatype_fp16;

        fp16_value->producer = value->producer;
        fp16_value->num_consumers = value->num_consumers;
        fp16_value->first_consumer = value->first_consumer;
        value->producer = XNN_INVALID_NODE_ID;
        value->num_consumers = 0;
        value->first_consumer = XNN_INVALID_NODE_ID;

        // Clear external input/output flags
        fp16_value->flags = 0;
        xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);

        value->fp16_id = fp16_value->id;
        fp16_value->fp32_id = n;
      } else {
        xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
        value->datatype = xnn_datatype_fp16;
      }
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    assert(node->compute_type == xnn_compute_type_fp32);
    node->compute_type = xnn_compute_type_fp16;
    if (node->type == xnn_node_type_static_constant_pad) {
      node->params.static_pad.padding_value =
        fp16_ieee_from_fp32_value(uint32_as_float(node->params.static_pad.padding_value));
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]);
        node->inputs[i] = fp16_id;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]);
        node->outputs[o] = fp16_id;
      }
    }
  }

  // Count the number of external inputs and outputs which require Convert nodes
  uint32_t num_external_inputs = 0;
  uint32_t num_external_outputs = 0;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    const struct xnn_node* node = &subgraph->nodes[n];
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) {
        assert(value->data == NULL);
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        // This value isn't always an external input: it could be an external output of the current subgraph
        // (due to partitioning) that is simultaneously consumed by the current node.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          num_external_inputs += 1;
        }
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        assert(xnn_value_is_external_output(&subgraph->values[value->fp32_id]));
        num_external_outputs += 1;
      }
    }
  }
  xnn_log_debug("Discovered %" PRIu32 " external inputs and %" PRIu32 " external outputs",
    num_external_inputs, num_external_outputs);

  const uint32_t num_original_nodes = subgraph->num_nodes;
  xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs);
  struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1;
  for (uint32_t n = num_original_nodes; n != 0; n--) {
    const struct xnn_node* node = &subgraph->nodes[n - 1];
    // Insert Convert nodes for outputs
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%" PRIu32 " to tensor #%" PRIu32,
          value->id, value->fp32_id);
        const uint32_t output_node_id = output_node->id;
        assert(output_node >= subgraph->nodes);
        xnn_node_clear(output_node);
        output_node->id = output_node_id;
        xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */);
        output_node -= 1;
      }
    }
    // Move the Node to the new location
    if (output_node != node) {
      const uint32_t output_node_id = output_node->id;
      assert(output_node >= subgraph->nodes);
      memcpy(output_node, node, sizeof(struct xnn_node));
      output_node->id = output_node_id;
      output_node -= 1;
    }
    // Insert Convert nodes for inputs
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) {
        // Only insert Convert Nodes if the value actually is an external input. The value could instead be an
        // external output; in that case a Convert Node has already been inserted in the output loop above.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%" PRIu32 " to tensor #%" PRIu32,
            value->fp32_id, value->id);
          const uint32_t output_node_id = output_node->id;
          assert(output_node >= subgraph->nodes);
          xnn_node_clear(output_node);
          output_node->id = output_node_id;
          xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */);
          output_node -= 1;
        }
      }
    }
  }

  return true;
}

enum xnn_status xnn_subgraph_fusion(
    xnn_subgraph_t subgraph)
{
  // Fuse Nodes where possible
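  // Two fusion patterns are handled below:
  // 1. A Clamp consumer is fused upstream into its producer by tightening the producer's
  //    output_min/output_max activation range.
  // 2. A Constant Pad producer with spatial zero padding is fused downstream into a
  //    [Depthwise] Convolution 2D consumer by folding the padding into the convolution's
  //    input padding.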
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%" PRIu32 " into upstream Node #%" PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        const bool is_spatial_2d_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0;
        const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype;
        const uint32_t padding_value = producer->params.static_pad.padding_value;
        const bool is_zero_padding =
          (padding_datatype == xnn_datatype_fp32 && padding_value == 0) ||
          ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) &&
           padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point);
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
1038 xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
1039 consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
1062 xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
1063 consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  return xnn_status_success;
}

enum xnn_status xnn_subgraph_optimize(
    xnn_subgraph_t subgraph,
    uint32_t flags)
{
  xnn_subgraph_analyze_consumers_and_producers(subgraph);

  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (!xnn_value_is_external_input(value) && value->num_consumers == 0 && !xnn_value_is_persistent(value)) {
      xnn_value_clear(value);
    }
  }

  if (!(flags & XNN_FLAG_NO_OPERATOR_FUSION)) {
    xnn_subgraph_fusion(subgraph);
  }

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  if ((flags & XNN_FLAG_FORCE_FP16_INFERENCE) && !(xnn_params.init_flags & XNN_INIT_FLAG_F16)) {
    xnn_log_error("failed to force FP16 inference: hardware supports neither native nor emulated FP16 operators");
    return xnn_status_unsupported_hardware;
  }
  #ifndef XNN_NO_F16_OPERATORS
    const bool try_native_fp16 =
      (flags & XNN_FLAG_HINT_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16_NATIVE);
    const bool force_fp16 = (flags & XNN_FLAG_FORCE_FP16_INFERENCE);
    if (try_native_fp16 || force_fp16) {
      const bool fp16_rewrite_succeeded = xnn_subgraph_rewrite_for_fp16(subgraph);
      if (force_fp16 && !fp16_rewrite_succeeded) {
        xnn_log_error("failed to force FP16 inference: subgraph is incompatible with FP16 operators");
        return xnn_status_unsupported_parameter;
      }
    }
  #endif  // XNN_NO_F16_OPERATORS

  return xnn_status_success;
}

enum xnn_status xnn_delete_subgraph(
    xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}