1 | /* |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include <assert.h> |
10 | #include <math.h> |
11 | #include <stddef.h> |
12 | #include <stdint.h> |
13 | #include <stdlib.h> |
14 | |
15 | #include <qnnpack.h> |
16 | #include <qnnpack/operator.h> |
17 | #include <qnnpack/requantization.h> |
18 | #include <qnnpack/log.h> |
19 | #include <qnnpack/params.h> |
20 | |
21 | |
22 | enum qnnp_status qnnp_create_add_nc_q8( |
23 | size_t channels, |
24 | uint8_t a_zero_point, |
25 | float a_scale, |
26 | uint8_t b_zero_point, |
27 | float b_scale, |
28 | uint8_t sum_zero_point, |
29 | float sum_scale, |
30 | uint8_t sum_min, |
31 | uint8_t sum_max, |
32 | uint32_t flags, |
33 | qnnp_operator_t* add_out) |
34 | { |
35 | qnnp_operator_t add_op = NULL; |
36 | enum qnnp_status status = qnnp_status_uninitialized; |
37 | |
38 | if (!qnnp_params.initialized) { |
39 | qnnp_log_error("qnnp_create_add_nc_q8 failed because QNNPACK is not properly initialized" ); |
40 | goto error; |
41 | } |
42 | |
43 | status = qnnp_status_invalid_parameter; |
44 | |
45 | if (channels == 0) { |
46 | qnnp_log_error( |
47 | "failed to create add operator with %zu channels: number of channels must be non-zero" , channels); |
48 | goto error; |
49 | } |
50 | |
51 | if (a_scale <= 0.0f || !isnormal(a_scale)) { |
52 | qnnp_log_error( |
53 | "failed to create add operator with %.7g A scale: scale must be finite and positive" , a_scale); |
54 | goto error; |
55 | } |
56 | |
57 | if (b_scale <= 0.0f || !isnormal(b_scale)) { |
58 | qnnp_log_error( |
59 | "failed to create add operator with %.7g B scale: scale must be finite and positive" , b_scale); |
60 | goto error; |
61 | } |
62 | |
63 | if (sum_scale <= 0.0f || !isnormal(sum_scale)) { |
64 | qnnp_log_error( |
65 | "failed to create add operator with %.7g output scale: scale must be finite and positive" , sum_scale); |
66 | goto error; |
67 | } |
68 | |
69 | if (sum_min >= sum_max) { |
70 | qnnp_log_error( |
71 | "failed to create add operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max" , |
72 | sum_min, sum_max); |
73 | goto error; |
74 | } |
75 | |
76 | status = qnnp_status_unsupported_parameter; |
77 | |
78 | const float a_output_scale = a_scale / sum_scale; |
79 | if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) { |
80 | qnnp_log_error( |
81 | "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range" , |
82 | a_output_scale); |
83 | goto error; |
84 | } |
85 | |
86 | const float b_output_scale = b_scale / sum_scale; |
87 | if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) { |
88 | qnnp_log_error( |
89 | "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range" , |
90 | b_output_scale); |
91 | goto error; |
92 | } |
93 | |
94 | status = qnnp_status_out_of_memory; |
95 | |
96 | add_op = calloc(1, sizeof(struct qnnp_operator)); |
97 | if (add_op == NULL) { |
98 | qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure" , sizeof(struct qnnp_operator)); |
99 | goto error; |
100 | } |
101 | |
102 | add_op->channels = channels; |
103 | add_op->add_quantization_params = |
104 | qnnp_compute_add_quantization_params( |
105 | a_zero_point, b_zero_point, sum_zero_point, |
106 | a_scale / sum_scale, b_scale / sum_scale, |
107 | sum_min, sum_max); |
108 | |
109 | add_op->ukernel_type = qnnp_ukernel_type_add; |
110 | add_op->format = qnnp_format_quint8; |
111 | |
112 | *add_out = add_op; |
113 | return qnnp_status_success; |
114 | |
115 | error: |
116 | qnnp_delete_operator(add_op); |
117 | return status; |
118 | } |
119 | |
120 | enum qnnp_status qnnp_setup_add_nc_q8( |
121 | qnnp_operator_t add_op, |
122 | size_t batch_size, |
123 | const uint8_t* a, |
124 | size_t a_stride, |
125 | const uint8_t* b, |
126 | size_t b_stride, |
127 | uint8_t* sum, |
128 | size_t sum_stride) |
129 | { |
130 | if (!qnnp_params.initialized) { |
131 | qnnp_log_error("qnnp_setup_add_nc_q8 failed because QNNPACK is not properly initialized" ); |
132 | return qnnp_status_uninitialized; |
133 | } |
134 | |
135 | if (batch_size == 0) { |
136 | add_op->batch_size = 0; |
137 | return qnnp_status_success; |
138 | } |
139 | |
140 | add_op->batch_size = batch_size; |
141 | add_op->input = a; |
142 | add_op->input_pixel_stride = a_stride; |
143 | add_op->input2 = b; |
144 | add_op->input2_pixel_stride = b_stride; |
145 | add_op->output = sum; |
146 | add_op->output_pixel_stride = sum_stride; |
147 | |
148 | return qnnp_status_success; |
149 | } |
150 | |