1 | #include <stddef.h> |
2 | |
3 | #include <nnpack.h> |
4 | #include <nnpack/macros.h> |
5 | #include <nnpack/utils.h> |
6 | #include <nnpack/system.h> |
7 | |
8 | #include <nnpack/hwinfo.h> |
9 | #include <nnpack/validation.h> |
10 | |
11 | |
/*
 * Shared arguments for the pack_input_matrix tasks dispatched over the
 * thread pool (the "outer" dimension is the batch for the input matrix).
 */
struct NNP_CACHE_ALIGN input_packing_context {
	const float* matrix;       /* source matrix, indexed [row * input_channels + channel] */
	float* packed_matrix;      /* destination buffer in the blocked layout */

	size_t input_channels;     /* column count / row stride of the source matrix */
	size_t outer_subblock_max; /* max rows per packed sub-block (caller sets this to sgemm mr) */
};
19 | |
20 | static void pack_input_matrix( |
21 | const struct input_packing_context context[restrict static 1], |
22 | size_t outer_block_start, size_t input_channels_block_start, |
23 | size_t outer_block_size, size_t input_channels_block_size) |
24 | { |
25 | const float* matrix = context->matrix; |
26 | float* packed_matrix = context->packed_matrix; |
27 | const size_t input_channels = context->input_channels; |
28 | const size_t outer_subblock_max = context->outer_subblock_max; |
29 | |
30 | for (size_t outer_subblock_start = 0; outer_subblock_start < outer_block_size; outer_subblock_start += outer_subblock_max) { |
31 | const size_t outer_subblock_size = min(outer_block_size - outer_subblock_start, outer_subblock_max); |
32 | for (size_t input_channels_block_offset = 0; input_channels_block_offset < input_channels_block_size; input_channels_block_offset += 1) { |
33 | const size_t input_channel = input_channels_block_start + input_channels_block_offset; |
34 | for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) { |
35 | const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel; |
36 | const size_t packed_index = outer_block_start * input_channels + input_channels_block_start * outer_block_size + |
37 | outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_size + outer_subblock_offset; |
38 | packed_matrix[packed_index] = matrix[index]; |
39 | } |
40 | } |
41 | } |
42 | } |
43 | |
/*
 * Shared arguments for the pack_kernel_matrix tasks (the "outer" dimension is
 * output channels). Only one input-channels block is packed per dispatch; the
 * packed buffer is reused for each block.
 */
struct NNP_CACHE_ALIGN kernel_packing_context {
	const float* matrix;       /* source kernel, indexed [output_channel * input_channels + input_channel] */
	float* packed_matrix;      /* destination buffer for the current input-channels block */

	size_t simd_width;         /* sub-block row counts are rounded up to a multiple of this */
	size_t input_channels;     /* row stride of the source matrix */
	size_t outer_subblock_max; /* max rows per packed sub-block (caller sets this to sgemm nr) */
	size_t input_channels_block_start; /* first input channel of the block being packed */
	size_t input_channels_block_size;  /* number of input channels in the block */
};
54 | |
55 | static void pack_kernel_matrix( |
56 | const struct kernel_packing_context context[restrict static 1], |
57 | size_t outer_block_start, size_t outer_block_size) |
58 | { |
59 | const float* matrix = context->matrix; |
60 | float* packed_matrix = context->packed_matrix; |
61 | const size_t input_channels = context->input_channels; |
62 | const size_t outer_subblock_max = context->outer_subblock_max; |
63 | const size_t input_channels_block_start = context->input_channels_block_start; |
64 | const size_t input_channels_block_size = context->input_channels_block_size; |
65 | const size_t simd_width = context->simd_width; |
66 | |
67 | for (size_t outer_subblock_start = 0; outer_subblock_start < outer_block_size; outer_subblock_start += outer_subblock_max) { |
68 | const size_t outer_subblock_size = min(outer_block_size - outer_subblock_start, outer_subblock_max); |
69 | const size_t outer_subblock_stride = round_up(outer_subblock_size, simd_width); |
70 | for (size_t input_channels_block_offset = 0; input_channels_block_offset < input_channels_block_size; input_channels_block_offset += 1) { |
71 | const size_t input_channel = input_channels_block_start + input_channels_block_offset; |
72 | for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) { |
73 | const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel; |
74 | const size_t packed_index = (outer_block_start + outer_subblock_start) * input_channels_block_size + |
75 | input_channels_block_offset * outer_subblock_stride + outer_subblock_offset; |
76 | packed_matrix[packed_index] = matrix[index]; |
77 | } |
78 | } |
79 | } |
80 | } |
81 | |
/*
 * Shared arguments for the compute_matrix_multiplication tasks. The caller
 * reuses one instance across dispatches, rewriting the per-block fields
 * (input_channels_block_*, batch_block_*) between pthreadpool calls.
 */
struct NNP_CACHE_ALIGN matrix_multiplication_context {
	const float* input;  /* packed input matrix (whole batch) */
	const float* kernel; /* packed kernel for the current input-channels block */
	float* output;       /* output matrix, indexed [batch * output_channels + output_channel] */
	size_t input_channels;
	size_t output_channels;
	size_t batch_block_start;
	size_t batch_block_size;
	size_t input_channels_block_start; /* also forwarded to sgemm as its update/accumulate argument */
	size_t input_channels_block_size;
	size_t output_channels_subblock_max; /* sgemm nr */
	size_t batch_subblock_max;           /* sgemm mr */
	size_t simd_width; /* NOTE(review): set by the caller but unused in compute_matrix_multiplication */
	nnp_fast_sgemm_function fast_sgemm_function; /* fixed mr x nr micro-kernel */
	nnp_full_sgemm_function full_sgemm_function; /* variable-size (up to mr x nr) micro-kernel */
};
98 | |
99 | static void compute_matrix_multiplication( |
100 | const struct matrix_multiplication_context context[restrict static 1], |
101 | size_t output_channels_block_start, size_t batch_subblock_start, |
102 | size_t output_channels_block_size, size_t batch_subblock_size) |
103 | { |
104 | const float* input = context->input; |
105 | const float* kernel = context->kernel; |
106 | float* output = context->output; |
107 | const size_t input_channels = context->input_channels; |
108 | const size_t output_channels = context->output_channels; |
109 | const size_t input_channels_block_start = context->input_channels_block_start; |
110 | const size_t input_channels_block_size = context->input_channels_block_size; |
111 | const size_t batch_block_start = context->batch_block_start; |
112 | const size_t batch_block_size = context->batch_block_size; |
113 | const size_t output_channels_subblock_max = context->output_channels_subblock_max; |
114 | const size_t batch_subblock_max = context->batch_subblock_max; |
115 | const size_t simd_width = context->simd_width; |
116 | const nnp_fast_sgemm_function fast_sgemm = context->fast_sgemm_function; |
117 | const nnp_full_sgemm_function full_sgemm = context->full_sgemm_function; |
118 | |
119 | for (size_t output_channels_subblock_start = 0; output_channels_subblock_start < output_channels_block_size; output_channels_subblock_start += output_channels_subblock_max) { |
120 | const size_t output_channels_subblock_size = min(output_channels_block_size - output_channels_subblock_start, output_channels_subblock_max); |
121 | if ((batch_subblock_size == batch_subblock_max) && (output_channels_subblock_size == output_channels_subblock_max)) { |
122 | fast_sgemm( |
123 | input_channels_block_size, input_channels_block_start, |
124 | &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size], |
125 | &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size], |
126 | &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)], |
127 | output_channels); |
128 | } else { |
129 | full_sgemm( |
130 | batch_subblock_size, output_channels_subblock_size, |
131 | input_channels_block_size, input_channels_block_start, |
132 | &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size], |
133 | &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size], |
134 | &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)], |
135 | output_channels); |
136 | } |
137 | } |
138 | } |
139 | |
140 | static void compute_fully_connected_output( |
141 | size_t simd_width, |
142 | size_t batch_size, |
143 | size_t batch_block_max, |
144 | size_t batch_subblock_max, |
145 | size_t input_channels, |
146 | size_t input_channels_block_max, |
147 | size_t output_channels, |
148 | size_t output_channels_block_max, |
149 | size_t output_channels_subblock_max, |
150 | const float* input, const float* kernel, float* output, |
151 | float* packed_input, float* packed_kernel, |
152 | pthreadpool_t threadpool, |
153 | struct nnp_profile* profile) |
154 | { |
155 | NNP_INPUT_TRANSFORM_START(profile) |
156 | struct input_packing_context input_packing_context = { |
157 | .matrix = input, |
158 | .packed_matrix = packed_input, |
159 | .input_channels = input_channels, |
160 | .outer_subblock_max = batch_subblock_max, |
161 | }; |
162 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
163 | (pthreadpool_task_2d_tile_2d_t) pack_input_matrix, |
164 | &input_packing_context, |
165 | batch_size, input_channels, |
166 | batch_block_max, input_channels_block_max, |
167 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
168 | NNP_INPUT_TRANSFORM_END(profile) |
169 | |
170 | struct matrix_multiplication_context matrix_multiplication_context = { |
171 | .input = packed_input, |
172 | .kernel = packed_kernel, |
173 | .output = output, |
174 | .input_channels = input_channels, |
175 | .output_channels = output_channels, |
176 | .output_channels_subblock_max = output_channels_subblock_max, |
177 | .batch_subblock_max = batch_subblock_max, |
178 | .simd_width = simd_width, |
179 | .fast_sgemm_function = nnp_hwinfo.sgemm.only_mr_x_nr, |
180 | .full_sgemm_function = nnp_hwinfo.sgemm.upto_mr_x_nr, |
181 | }; |
182 | for (size_t input_channels_block_start = 0; input_channels_block_start < input_channels; input_channels_block_start += input_channels_block_max) { |
183 | const size_t input_channels_block_size = min(input_channels - input_channels_block_start, input_channels_block_max); |
184 | |
185 | NNP_KERNEL_TRANSFORM_START(profile) |
186 | struct kernel_packing_context kernel_packing_context = { |
187 | .matrix = kernel, |
188 | .packed_matrix = packed_kernel, |
189 | .simd_width = simd_width, |
190 | .input_channels = input_channels, |
191 | .outer_subblock_max = output_channels_subblock_max, |
192 | .input_channels_block_start = input_channels_block_start, |
193 | .input_channels_block_size = input_channels_block_size, |
194 | }; |
195 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
196 | (pthreadpool_task_1d_tile_1d_t) pack_kernel_matrix, |
197 | &kernel_packing_context, |
198 | output_channels, output_channels_block_max, |
199 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
200 | NNP_KERNEL_TRANSFORM_END(profile) |
201 | |
202 | NNP_BLOCK_MULTIPLICATION_START(profile) |
203 | matrix_multiplication_context.input_channels_block_start = input_channels_block_start; |
204 | matrix_multiplication_context.input_channels_block_size = input_channels_block_size; |
205 | for (size_t batch_block_start = 0; batch_block_start < batch_size; batch_block_start += batch_block_max) { |
206 | const size_t batch_block_size = min(batch_size - batch_block_start, batch_block_max); |
207 | |
208 | matrix_multiplication_context.batch_block_start = batch_block_start; |
209 | matrix_multiplication_context.batch_block_size = batch_block_size; |
210 | pthreadpool_parallelize_2d_tile_2d(threadpool, |
211 | (pthreadpool_task_2d_tile_2d_t) compute_matrix_multiplication, |
212 | &matrix_multiplication_context, |
213 | output_channels, batch_block_size, |
214 | output_channels_block_max, batch_subblock_max, |
215 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
216 | } |
217 | NNP_BLOCK_MULTIPLICATION_END(profile) |
218 | } |
219 | } |
220 | |
221 | enum nnp_status nnp_fully_connected_output( |
222 | size_t batch_size, |
223 | size_t input_channels, |
224 | size_t output_channels, |
225 | const float input[], |
226 | const float kernel[], |
227 | float output[], |
228 | pthreadpool_t threadpool, |
229 | struct nnp_profile* profile) |
230 | { |
231 | void* memory_block = NULL; |
232 | NNP_TOTAL_START(profile) |
233 | |
234 | /* Basic validation of parameters. This check detects invalid, but not unsupported parameters. */ |
235 | enum nnp_status status = validate_fully_connected_arguments(batch_size, input_channels, output_channels); |
236 | if (status != nnp_status_success) { |
237 | goto cleanup; |
238 | } |
239 | |
240 | const size_t cache_elements_l1 = nnp_hwinfo.blocking.l1 / sizeof(float); |
241 | const size_t cache_elements_l2 = nnp_hwinfo.blocking.l2 / sizeof(float); |
242 | const size_t cache_elements_l3 = nnp_hwinfo.blocking.l3 / sizeof(float); |
243 | |
244 | const size_t simd_width = nnp_hwinfo.simd_width; |
245 | const size_t batch_subblock_max = nnp_hwinfo.sgemm.mr; |
246 | const size_t output_channels_subblock_max = nnp_hwinfo.sgemm.nr; |
247 | |
248 | const size_t input_channels_block_max = cache_elements_l1 / (batch_subblock_max + output_channels_subblock_max); |
249 | const size_t batch_block_max = round_down(cache_elements_l3 / input_channels_block_max, batch_subblock_max); |
250 | const size_t output_channels_block_max = round_down(cache_elements_l2 / input_channels_block_max, output_channels_subblock_max); |
251 | |
252 | /* Calculate memory footprint and allocate memory */ |
253 | const size_t packed_input_size = round_up(batch_size, batch_subblock_max) * input_channels * sizeof(float); |
254 | /* Extra alignment on 64 is needed to ensure that packed_kernel is always SIMD-aligned */ |
255 | const size_t packed_kernel_offset = round_up(packed_input_size, 64); |
256 | const size_t packed_kernel_size = round_up(output_channels, output_channels_subblock_max) * input_channels_block_max * sizeof(float); |
257 | const size_t memory_size = packed_kernel_offset + packed_kernel_size; |
258 | |
259 | memory_block = allocate_memory(memory_size); |
260 | if (memory_block == NULL) { |
261 | status = nnp_status_out_of_memory; |
262 | goto cleanup; |
263 | } |
264 | |
265 | float* packed_input = memory_block; |
266 | float* packed_kernel = memory_block + packed_kernel_offset; |
267 | |
268 | /* Do the computation */ |
269 | compute_fully_connected_output( |
270 | simd_width, |
271 | batch_size, batch_block_max, batch_subblock_max, |
272 | input_channels, input_channels_block_max, |
273 | output_channels, output_channels_block_max, output_channels_subblock_max, |
274 | input, kernel, output, |
275 | packed_input, packed_kernel, |
276 | threadpool, |
277 | profile); |
278 | |
279 | cleanup: |
280 | release_memory(memory_block, memory_size); |
281 | NNP_TOTAL_END(profile) |
282 | return status; |
283 | } |
284 | |