1 | /* |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
15 | |
16 | #include <qnnpack.h> |
17 | #include <qnnpack/operator.h> |
18 | #include <qnnpack/requantization.h> |
19 | #include <qnnpack/log.h> |
20 | #include <qnnpack/math.h> |
21 | #include <qnnpack/pack.h> |
22 | #include <qnnpack/params.h> |
23 | |
24 | |
25 | enum qnnp_status qnnp_create_fully_connected_nc_q8( |
26 | size_t input_channels, |
27 | size_t output_channels, |
28 | uint8_t input_zero_point, |
29 | float input_scale, |
30 | uint8_t kernel_zero_point, |
31 | float kernel_scale, |
32 | const uint8_t* kernel, |
33 | const int32_t* bias, |
34 | uint8_t output_zero_point, |
35 | float output_scale, |
36 | uint8_t output_min, |
37 | uint8_t output_max, |
38 | uint32_t flags, |
39 | qnnp_operator_t* fully_connected_out) |
40 | { |
41 | qnnp_operator_t fully_connected = NULL; |
42 | enum qnnp_status status = qnnp_status_uninitialized; |
43 | |
44 | if (!qnnp_params.initialized) { |
45 | qnnp_log_error("qnnp_create_fully_connected_nc_q8 failed because QNNPACK is not properly initialized" ); |
46 | goto error; |
47 | } |
48 | |
49 | status = qnnp_status_invalid_parameter; |
50 | |
51 | if (input_scale <= 0.0f || !isnormal(input_scale)) { |
52 | qnnp_log_error( |
53 | "failed to create fully connected operator with %.7g input scale: scale must be finite and positive" , input_scale); |
54 | goto error; |
55 | } |
56 | |
57 | if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) { |
58 | qnnp_log_error( |
59 | "failed to create fully connected operator with %.7g kernel scale: scale must be finite and positive" , kernel_scale); |
60 | goto error; |
61 | } |
62 | |
63 | if (output_scale <= 0.0f || !isnormal(output_scale)) { |
64 | qnnp_log_error( |
65 | "failed to create fully connected operator with %.7g output scale: scale must be finite and positive" , output_scale); |
66 | goto error; |
67 | } |
68 | |
69 | status = qnnp_status_unsupported_parameter; |
70 | |
71 | const float requantization_scale = input_scale * kernel_scale / output_scale; |
72 | if (requantization_scale >= 1.0f) { |
73 | qnnp_log_error( |
74 | "failed to create fully connected operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: " |
75 | "requantization scale %.7g is greater or equal to 1.0" , |
76 | input_scale, kernel_scale, output_scale, requantization_scale); |
77 | goto error; |
78 | } |
79 | |
80 | status = qnnp_status_out_of_memory; |
81 | |
82 | fully_connected = calloc(1, sizeof(struct qnnp_operator)); |
83 | if (fully_connected == NULL) { |
84 | qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure" , sizeof(struct qnnp_operator)); |
85 | goto error; |
86 | } |
87 | |
88 | const uint32_t nr = qnnp_params.q8conv.nr; |
89 | const uint32_t kr = qnnp_params.q8conv.kr; |
90 | |
91 | const uint32_t n_stride = (output_channels + (nr - 1)) & -nr; |
92 | const uint32_t k_stride = (input_channels + (kr - 1)) & -kr; |
93 | |
94 | fully_connected->packed_weights = malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); |
95 | if (fully_connected->packed_weights == NULL) { |
96 | qnnp_log_error("failed to allocate %zu bytes for packed weights" , |
97 | n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); |
98 | goto error; |
99 | } |
100 | memset(fully_connected->packed_weights, kernel_zero_point, n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); |
101 | |
102 | pack_q8gemm_w( |
103 | output_channels, input_channels, |
104 | nr, nr, kr, |
105 | input_zero_point, kernel_zero_point, |
106 | kernel, bias, |
107 | fully_connected->packed_weights); |
108 | |
109 | fully_connected->groups = 1; |
110 | fully_connected->group_input_channels = input_channels; |
111 | fully_connected->group_output_channels = output_channels; |
112 | |
113 | fully_connected->kernel_zero_point = kernel_zero_point; |
114 | |
115 | fully_connected->conv_quantization_params = |
116 | qnnp_compute_conv_quantization_params( |
117 | input_zero_point, kernel_zero_point, |
118 | requantization_scale, output_zero_point, output_min, output_max); |
119 | |
120 | fully_connected->ukernel_type = qnnp_ukernel_type_gemm; |
121 | fully_connected->format = qnnp_format_quint8; |
122 | |
123 | *fully_connected_out = fully_connected; |
124 | return qnnp_status_success; |
125 | |
126 | error: |
127 | qnnp_delete_operator(fully_connected); |
128 | return status; |
129 | } |
130 | |
131 | enum qnnp_status qnnp_setup_fully_connected_nc_q8( |
132 | qnnp_operator_t fully_connected, |
133 | size_t batch_size, |
134 | const uint8_t* input, |
135 | size_t input_stride, |
136 | uint8_t* output, |
137 | size_t output_stride) |
138 | { |
139 | if (!qnnp_params.initialized) { |
140 | qnnp_log_error("qnnp_setup_fully_connected_nc_q8 failed because QNNPACK is not properly initialized" ); |
141 | return qnnp_status_uninitialized; |
142 | } |
143 | |
144 | if (batch_size == 0) { |
145 | fully_connected->batch_size = 0; |
146 | return qnnp_status_success; |
147 | } |
148 | |
149 | fully_connected->batch_size = 1; |
150 | fully_connected->input_height = batch_size; |
151 | fully_connected->input_width = 1; |
152 | fully_connected->input = input; |
153 | fully_connected->input_pixel_stride = input_stride; |
154 | |
155 | fully_connected->output_height = batch_size; |
156 | fully_connected->output_width = 1; |
157 | fully_connected->output = output; |
158 | fully_connected->output_pixel_stride = output_stride; |
159 | |
160 | return qnnp_status_success; |
161 | } |
162 | |