1#include <stddef.h>
2
3#include <nnpack.h>
4#include <nnpack/macros.h>
5#include <nnpack/utils.h>
6#include <nnpack/system.h>
7
8#include <nnpack/hwinfo.h>
9#include <nnpack/validation.h>
10
11
/*
 * Arguments for pack_input_matrix(), shared (read-only) across pthreadpool
 * workers. Describes the repacking of the row-major input matrix into the
 * blocked/subblocked layout consumed by the SGEMM micro-kernels.
 */
struct NNP_CACHE_ALIGN input_packing_context {
	const float* matrix;       /* source matrix, row-major with row stride input_channels */
	float* packed_matrix;      /* destination buffer in blocked layout */

	size_t input_channels;     /* row stride (number of columns) of the source matrix */
	size_t outer_subblock_max; /* sub-tile size along the outer (batch) dimension; caller passes sgemm mr */
};
19
20static void pack_input_matrix(
21 const struct input_packing_context context[restrict static 1],
22 size_t outer_block_start, size_t input_channels_block_start,
23 size_t outer_block_size, size_t input_channels_block_size)
24{
25 const float* matrix = context->matrix;
26 float* packed_matrix = context->packed_matrix;
27 const size_t input_channels = context->input_channels;
28 const size_t outer_subblock_max = context->outer_subblock_max;
29
30 for (size_t outer_subblock_start = 0; outer_subblock_start < outer_block_size; outer_subblock_start += outer_subblock_max) {
31 const size_t outer_subblock_size = min(outer_block_size - outer_subblock_start, outer_subblock_max);
32 for (size_t input_channels_block_offset = 0; input_channels_block_offset < input_channels_block_size; input_channels_block_offset += 1) {
33 const size_t input_channel = input_channels_block_start + input_channels_block_offset;
34 for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) {
35 const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel;
36 const size_t packed_index = outer_block_start * input_channels + input_channels_block_start * outer_block_size +
37 outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_size + outer_subblock_offset;
38 packed_matrix[packed_index] = matrix[index];
39 }
40 }
41 }
42}
43
/*
 * Arguments for pack_kernel_matrix(), shared (read-only) across pthreadpool
 * workers. Repacks one input-channels block of the kernel matrix into the
 * blocked layout read by the SGEMM micro-kernels.
 */
struct NNP_CACHE_ALIGN kernel_packing_context {
	const float* matrix;       /* source kernel, row-major with row stride input_channels */
	float* packed_matrix;      /* destination buffer in blocked layout */

	size_t simd_width;         /* packed per-channel runs are padded (round_up) to this multiple */
	size_t input_channels;     /* row stride of the source matrix */
	size_t outer_subblock_max; /* sub-tile size along output channels; caller passes sgemm nr */
	size_t input_channels_block_start; /* first input channel of the block being packed */
	size_t input_channels_block_size;  /* number of input channels in the block */
};
54
55static void pack_kernel_matrix(
56 const struct kernel_packing_context context[restrict static 1],
57 size_t outer_block_start, size_t outer_block_size)
58{
59 const float* matrix = context->matrix;
60 float* packed_matrix = context->packed_matrix;
61 const size_t input_channels = context->input_channels;
62 const size_t outer_subblock_max = context->outer_subblock_max;
63 const size_t input_channels_block_start = context->input_channels_block_start;
64 const size_t input_channels_block_size = context->input_channels_block_size;
65 const size_t simd_width = context->simd_width;
66
67 for (size_t outer_subblock_start = 0; outer_subblock_start < outer_block_size; outer_subblock_start += outer_subblock_max) {
68 const size_t outer_subblock_size = min(outer_block_size - outer_subblock_start, outer_subblock_max);
69 const size_t outer_subblock_stride = round_up(outer_subblock_size, simd_width);
70 for (size_t input_channels_block_offset = 0; input_channels_block_offset < input_channels_block_size; input_channels_block_offset += 1) {
71 const size_t input_channel = input_channels_block_start + input_channels_block_offset;
72 for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) {
73 const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel;
74 const size_t packed_index = (outer_block_start + outer_subblock_start) * input_channels_block_size +
75 input_channels_block_offset * outer_subblock_stride + outer_subblock_offset;
76 packed_matrix[packed_index] = matrix[index];
77 }
78 }
79 }
80}
81
/*
 * Arguments for compute_matrix_multiplication(), shared across pthreadpool
 * workers. input/kernel point to the packed (blocked) matrices produced by
 * pack_input_matrix/pack_kernel_matrix; output is the plain row-major
 * batch_size x output_channels result.
 */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	const float* input;      /* packed input (layout of pack_input_matrix) */
	const float* kernel;     /* packed kernel (layout of pack_kernel_matrix) */
	float* output;           /* row-major output, row stride = output_channels */
	size_t input_channels;   /* row stride of the original input matrix */
	size_t output_channels;  /* row stride of the output matrix */
	size_t batch_block_start;  /* batch block currently processed (sized from L3 in the caller) */
	size_t batch_block_size;
	size_t input_channels_block_start; /* reduction block (sized from L1 in the caller) */
	size_t input_channels_block_size;
	size_t output_channels_subblock_max; /* sgemm nr */
	size_t batch_subblock_max;           /* sgemm mr */
	size_t simd_width;
	nnp_fast_sgemm_function fast_sgemm_function; /* fixed mr x nr micro-kernel */
	nnp_full_sgemm_function full_sgemm_function; /* variable-size (edge tile) micro-kernel */
};
98
99static void compute_matrix_multiplication(
100 const struct matrix_multiplication_context context[restrict static 1],
101 size_t output_channels_block_start, size_t batch_subblock_start,
102 size_t output_channels_block_size, size_t batch_subblock_size)
103{
104 const float* input = context->input;
105 const float* kernel = context->kernel;
106 float* output = context->output;
107 const size_t input_channels = context->input_channels;
108 const size_t output_channels = context->output_channels;
109 const size_t input_channels_block_start = context->input_channels_block_start;
110 const size_t input_channels_block_size = context->input_channels_block_size;
111 const size_t batch_block_start = context->batch_block_start;
112 const size_t batch_block_size = context->batch_block_size;
113 const size_t output_channels_subblock_max = context->output_channels_subblock_max;
114 const size_t batch_subblock_max = context->batch_subblock_max;
115 const size_t simd_width = context->simd_width;
116 const nnp_fast_sgemm_function fast_sgemm = context->fast_sgemm_function;
117 const nnp_full_sgemm_function full_sgemm = context->full_sgemm_function;
118
119 for (size_t output_channels_subblock_start = 0; output_channels_subblock_start < output_channels_block_size; output_channels_subblock_start += output_channels_subblock_max) {
120 const size_t output_channels_subblock_size = min(output_channels_block_size - output_channels_subblock_start, output_channels_subblock_max);
121 if ((batch_subblock_size == batch_subblock_max) && (output_channels_subblock_size == output_channels_subblock_max)) {
122 fast_sgemm(
123 input_channels_block_size, input_channels_block_start,
124 &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size],
125 &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size],
126 &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)],
127 output_channels);
128 } else {
129 full_sgemm(
130 batch_subblock_size, output_channels_subblock_size,
131 input_channels_block_size, input_channels_block_start,
132 &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size],
133 &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size],
134 &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)],
135 output_channels);
136 }
137 }
138}
139
/*
 * Core of the fully-connected layer: output[b][oc] = sum over ic of
 * input[b][ic] * kernel[oc][ic], computed as cache-blocked packed SGEMM.
 *
 * Structure:
 *   1. Pack the whole input matrix once (parallel over batch x channels tiles).
 *   2. For each input-channels block (the SGEMM reduction dimension):
 *        a. re-pack the corresponding kernel block (packed_kernel holds only
 *           one such block at a time and is reused across iterations);
 *        b. multiply every batch block against it (parallel over output
 *           channel blocks x batch subblocks).
 *
 * NOTE(review): correctness across iterations assumes the sgemm micro-kernels
 * use input_channels_block_start to decide initialize-vs-accumulate on the
 * output - confirm against the micro-kernel implementations.
 */
static void compute_fully_connected_output(
	size_t simd_width,
	size_t batch_size,
	size_t batch_block_max,
	size_t batch_subblock_max,
	size_t input_channels,
	size_t input_channels_block_max,
	size_t output_channels,
	size_t output_channels_block_max,
	size_t output_channels_subblock_max,
	const float* input, const float* kernel, float* output,
	float* packed_input, float* packed_kernel,
	pthreadpool_t threadpool,
	struct nnp_profile* profile)
{
	/* Step 1: pack the entire input matrix up front. */
	NNP_INPUT_TRANSFORM_START(profile)
	struct input_packing_context input_packing_context = {
		.matrix = input,
		.packed_matrix = packed_input,
		.input_channels = input_channels,
		.outer_subblock_max = batch_subblock_max,
	};
	pthreadpool_parallelize_2d_tile_2d(threadpool,
		(pthreadpool_task_2d_tile_2d_t) pack_input_matrix,
		&input_packing_context,
		batch_size, input_channels,
		batch_block_max, input_channels_block_max,
		PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	NNP_INPUT_TRANSFORM_END(profile)

	/* Loop-invariant fields of the multiplication context; the per-block
	 * fields are filled in inside the loops below. */
	struct matrix_multiplication_context matrix_multiplication_context = {
		.input = packed_input,
		.kernel = packed_kernel,
		.output = output,
		.input_channels = input_channels,
		.output_channels = output_channels,
		.output_channels_subblock_max = output_channels_subblock_max,
		.batch_subblock_max = batch_subblock_max,
		.simd_width = simd_width,
		.fast_sgemm_function = nnp_hwinfo.sgemm.only_mr_x_nr,
		.full_sgemm_function = nnp_hwinfo.sgemm.upto_mr_x_nr,
	};
	for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) {
		const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max);

		/* Step 2a: re-pack the kernel block for this reduction block. */
		NNP_KERNEL_TRANSFORM_START(profile)
		struct kernel_packing_context kernel_packing_context = {
			.matrix = kernel,
			.packed_matrix = packed_kernel,
			.simd_width = simd_width,
			.input_channels = input_channels,
			.outer_subblock_max = output_channels_subblock_max,
			.input_channels_block_start = input_channels_block_start,
			.input_channels_block_size = input_channels_block_size,
		};
		pthreadpool_parallelize_1d_tile_1d(threadpool,
			(pthreadpool_task_1d_tile_1d_t) pack_kernel_matrix,
			&kernel_packing_context,
			output_channels, output_channels_block_max,
			PTHREADPOOL_FLAG_DISABLE_DENORMALS);
		NNP_KERNEL_TRANSFORM_END(profile)

		/* Step 2b: multiply every batch block against the packed kernel block. */
		NNP_BLOCK_MULTIPLICATION_START(profile)
		matrix_multiplication_context.input_channels_block_start = input_channels_block_start;
		matrix_multiplication_context.input_channels_block_size = input_channels_block_size;
		for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) {
			const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max);

			matrix_multiplication_context.batch_block_start = batch_block_start;
			matrix_multiplication_context.batch_block_size = batch_block_size;
			pthreadpool_parallelize_2d_tile_2d(threadpool,
				(pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication,
				&matrix_multiplication_context,
				output_channels, batch_block_size,
				output_channels_block_max, batch_subblock_max,
				PTHREADPOOL_FLAG_DISABLE_DENORMALS);
		}
		NNP_BLOCK_MULTIPLICATION_END(profile)
	}
}
220
221enum nnp_status nnp_fully_connected_output(
222 size_t batch_size,
223 size_t input_channels,
224 size_t output_channels,
225 const float input[],
226 const float kernel[],
227 float output[],
228 pthreadpool_t threadpool,
229 struct nnp_profile* profile)
230{
231 void* memory_block = NULL;
232 NNP_TOTAL_START(profile)
233
234 /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */
235 enum nnp_status status = validate_fully_connected_arguments(batch_size, input_channels, output_channels);
236 if (status != nnp_status_success) {
237 goto cleanup;
238 }
239
240 const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / sizeof(float);
241 const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / sizeof(float);
242 const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / sizeof(float);
243
244 const size_t simd_width = nnp_hwinfo.simd_width;
245 const size_t batch_subblock_max = nnp_hwinfo.sgemm.mr;
246 const size_t output_channels_subblock_max = nnp_hwinfo.sgemm.nr;
247
248 const size_t input_channels_block_max = cache_elements_l1 / (batch_subblock_max + output_channels_subblock_max);
249 const size_t batch_block_max = round_down(cache_elements_l3 / input_channels_block_max, batch_subblock_max);
250 const size_t output_channels_block_max = round_down(cache_elements_l2 / input_channels_block_max, output_channels_subblock_max);
251
252 /* Calculate memory footprint and allocate memory */
253 const size_t packed_input_size = round_up(batch_size, batch_subblock_max) * input_channels * sizeof(float);
254 /* Extra alignment on 64 is needed to ensure that packed_kernel is always SIMD-aligned */
255 const size_t packed_kernel_offset = round_up(packed_input_size, 64);
256 const size_t packed_kernel_size = round_up(output_channels, output_channels_subblock_max) * input_channels_block_max * sizeof(float);
257 const size_t memory_size = packed_kernel_offset + packed_kernel_size;
258
259 memory_block = allocate_memory(memory_size);
260 if (memory_block == NULL) {
261 status = nnp_status_out_of_memory;
262 goto cleanup;
263 }
264
265 float* packed_input = memory_block;
266 float* packed_kernel = memory_block + packed_kernel_offset;
267
268 /* Do the computation */
269 compute_fully_connected_output(
270 simd_width,
271 batch_size, batch_block_max, batch_subblock_max,
272 input_channels, input_channels_block_max,
273 output_channels, output_channels_block_max, output_channels_subblock_max,
274 input, kernel, output,
275 packed_input, packed_kernel,
276 threadpool,
277 profile);
278
279cleanup:
280 release_memory(memory_block, memory_size);
281 NNP_TOTAL_END(profile)
282 return status;
283}
284