1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include <qnnpack.h>
#include <qnnpack/operator.h>
#include <qnnpack/requantization.h>
#include <qnnpack/log.h>
#include <qnnpack/math.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>
23
24
25enum qnnp_status qnnp_create_fully_connected_nc_q8(
26 size_t input_channels,
27 size_t output_channels,
28 uint8_t input_zero_point,
29 float input_scale,
30 uint8_t kernel_zero_point,
31 float kernel_scale,
32 const uint8_t* kernel,
33 const int32_t* bias,
34 uint8_t output_zero_point,
35 float output_scale,
36 uint8_t output_min,
37 uint8_t output_max,
38 uint32_t flags,
39 qnnp_operator_t* fully_connected_out)
40{
41 qnnp_operator_t fully_connected = NULL;
42 enum qnnp_status status = qnnp_status_uninitialized;
43
44 if (!qnnp_params.initialized) {
45 qnnp_log_error("qnnp_create_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
46 goto error;
47 }
48
49 status = qnnp_status_invalid_parameter;
50
51 if (input_scale <= 0.0f || !isnormal(input_scale)) {
52 qnnp_log_error(
53 "failed to create fully connected operator with %.7g input scale: scale must be finite and positive", input_scale);
54 goto error;
55 }
56
57 if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
58 qnnp_log_error(
59 "failed to create fully connected operator with %.7g kernel scale: scale must be finite and positive", kernel_scale);
60 goto error;
61 }
62
63 if (output_scale <= 0.0f || !isnormal(output_scale)) {
64 qnnp_log_error(
65 "failed to create fully connected operator with %.7g output scale: scale must be finite and positive", output_scale);
66 goto error;
67 }
68
69 status = qnnp_status_unsupported_parameter;
70
71 const float requantization_scale = input_scale * kernel_scale / output_scale;
72 if (requantization_scale >= 1.0f) {
73 qnnp_log_error(
74 "failed to create fully connected operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
75 "requantization scale %.7g is greater or equal to 1.0",
76 input_scale, kernel_scale, output_scale, requantization_scale);
77 goto error;
78 }
79
80 status = qnnp_status_out_of_memory;
81
82 fully_connected = calloc(1, sizeof(struct qnnp_operator));
83 if (fully_connected == NULL) {
84 qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure", sizeof(struct qnnp_operator));
85 goto error;
86 }
87
88 const uint32_t nr = qnnp_params.q8conv.nr;
89 const uint32_t kr = qnnp_params.q8conv.kr;
90
91 const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
92 const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;
93
94 fully_connected->packed_weights = malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
95 if (fully_connected->packed_weights == NULL) {
96 qnnp_log_error("failed to allocate %zu bytes for packed weights",
97 n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
98 goto error;
99 }
100 memset(fully_connected->packed_weights, kernel_zero_point, n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
101
102 pack_q8gemm_w(
103 output_channels, input_channels,
104 nr, nr, kr,
105 input_zero_point, kernel_zero_point,
106 kernel, bias,
107 fully_connected->packed_weights);
108
109 fully_connected->groups = 1;
110 fully_connected->group_input_channels = input_channels;
111 fully_connected->group_output_channels = output_channels;
112
113 fully_connected->kernel_zero_point = kernel_zero_point;
114
115 fully_connected->conv_quantization_params =
116 qnnp_compute_conv_quantization_params(
117 input_zero_point, kernel_zero_point,
118 requantization_scale, output_zero_point, output_min, output_max);
119
120 fully_connected->ukernel_type = qnnp_ukernel_type_gemm;
121 fully_connected->format = qnnp_format_quint8;
122
123 *fully_connected_out = fully_connected;
124 return qnnp_status_success;
125
126error:
127 qnnp_delete_operator(fully_connected);
128 return status;
129}
130
131enum qnnp_status qnnp_setup_fully_connected_nc_q8(
132 qnnp_operator_t fully_connected,
133 size_t batch_size,
134 const uint8_t* input,
135 size_t input_stride,
136 uint8_t* output,
137 size_t output_stride)
138{
139 if (!qnnp_params.initialized) {
140 qnnp_log_error("qnnp_setup_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
141 return qnnp_status_uninitialized;
142 }
143
144 if (batch_size == 0) {
145 fully_connected->batch_size = 0;
146 return qnnp_status_success;
147 }
148
149 fully_connected->batch_size = 1;
150 fully_connected->input_height = batch_size;
151 fully_connected->input_width = 1;
152 fully_connected->input = input;
153 fully_connected->input_pixel_stride = input_stride;
154
155 fully_connected->output_height = batch_size;
156 fully_connected->output_width = 1;
157 fully_connected->output = output;
158 fully_connected->output_pixel_stride = output_stride;
159
160 return qnnp_status_success;
161}
162