1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <qnnpack.h>
#include <qnnpack/operator.h>
#include <qnnpack/requantization.h>
#include <qnnpack/log.h>
#include <qnnpack/params.h>
20
21
22enum qnnp_status qnnp_create_add_nc_q8(
23 size_t channels,
24 uint8_t a_zero_point,
25 float a_scale,
26 uint8_t b_zero_point,
27 float b_scale,
28 uint8_t sum_zero_point,
29 float sum_scale,
30 uint8_t sum_min,
31 uint8_t sum_max,
32 uint32_t flags,
33 qnnp_operator_t* add_out)
34{
35 qnnp_operator_t add_op = NULL;
36 enum qnnp_status status = qnnp_status_uninitialized;
37
38 if (!qnnp_params.initialized) {
39 qnnp_log_error("qnnp_create_add_nc_q8 failed because QNNPACK is not properly initialized");
40 goto error;
41 }
42
43 status = qnnp_status_invalid_parameter;
44
45 if (channels == 0) {
46 qnnp_log_error(
47 "failed to create add operator with %zu channels: number of channels must be non-zero", channels);
48 goto error;
49 }
50
51 if (a_scale <= 0.0f || !isnormal(a_scale)) {
52 qnnp_log_error(
53 "failed to create add operator with %.7g A scale: scale must be finite and positive", a_scale);
54 goto error;
55 }
56
57 if (b_scale <= 0.0f || !isnormal(b_scale)) {
58 qnnp_log_error(
59 "failed to create add operator with %.7g B scale: scale must be finite and positive", b_scale);
60 goto error;
61 }
62
63 if (sum_scale <= 0.0f || !isnormal(sum_scale)) {
64 qnnp_log_error(
65 "failed to create add operator with %.7g output scale: scale must be finite and positive", sum_scale);
66 goto error;
67 }
68
69 if (sum_min >= sum_max) {
70 qnnp_log_error(
71 "failed to create add operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
72 sum_min, sum_max);
73 goto error;
74 }
75
76 status = qnnp_status_unsupported_parameter;
77
78 const float a_output_scale = a_scale / sum_scale;
79 if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) {
80 qnnp_log_error(
81 "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
82 a_output_scale);
83 goto error;
84 }
85
86 const float b_output_scale = b_scale / sum_scale;
87 if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) {
88 qnnp_log_error(
89 "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
90 b_output_scale);
91 goto error;
92 }
93
94 status = qnnp_status_out_of_memory;
95
96 add_op = calloc(1, sizeof(struct qnnp_operator));
97 if (add_op == NULL) {
98 qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure", sizeof(struct qnnp_operator));
99 goto error;
100 }
101
102 add_op->channels = channels;
103 add_op->add_quantization_params =
104 qnnp_compute_add_quantization_params(
105 a_zero_point, b_zero_point, sum_zero_point,
106 a_scale / sum_scale, b_scale / sum_scale,
107 sum_min, sum_max);
108
109 add_op->ukernel_type = qnnp_ukernel_type_add;
110 add_op->format = qnnp_format_quint8;
111
112 *add_out = add_op;
113 return qnnp_status_success;
114
115error:
116 qnnp_delete_operator(add_op);
117 return status;
118}
119
120enum qnnp_status qnnp_setup_add_nc_q8(
121 qnnp_operator_t add_op,
122 size_t batch_size,
123 const uint8_t* a,
124 size_t a_stride,
125 const uint8_t* b,
126 size_t b_stride,
127 uint8_t* sum,
128 size_t sum_stride)
129{
130 if (!qnnp_params.initialized) {
131 qnnp_log_error("qnnp_setup_add_nc_q8 failed because QNNPACK is not properly initialized");
132 return qnnp_status_uninitialized;
133 }
134
135 if (batch_size == 0) {
136 add_op->batch_size = 0;
137 return qnnp_status_success;
138 }
139
140 add_op->batch_size = batch_size;
141 add_op->input = a;
142 add_op->input_pixel_stride = a_stride;
143 add_op->input2 = b;
144 add_op->input2_pixel_stride = b_stride;
145 add_op->output = sum;
146 add_op->output_pixel_stride = sum_stride;
147
148 return qnnp_status_success;
149}
150