/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example cnn_inference_f32.c
/// @copybrief cnn_inference_f32_c

/// @page cnn_inference_f32_c CNN f32 inference example
/// This C API example demonstrates how to build an AlexNet neural
/// network topology for forward-pass inference.
///
/// Some key take-aways include:
///
/// * How tensors are implemented and submitted to primitives.
/// * How primitives are created.
/// * How primitives are sequentially submitted to the network, with the
///   output of one primitive passed as the input to the next. This ordering
///   expresses the dependency between primitive input and output data.
/// * Specific 'inference-only' configurations.
/// * Limiting the number of reorders performed, since they are detrimental
///   to performance.
///
/// The example implements the first AlexNet layer group as individual
/// primitives (convolution, ReLU, LRN, and max-pooling).
///
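/// At a high level the example proceeds as follows: create memory descriptors
/// (using dnnl_format_tag_any where the library should pick the layout),
/// create primitive descriptors and query the memory formats they prefer,
/// insert reorders where the user data layout differs, create the primitives
/// and record their execution arguments, and finally execute the whole net
/// on a stream.
///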
/// @include cnn_inference_f32.c

// Required for posix_memalign
#define _POSIX_C_SOURCE 200112L

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "oneapi/dnnl/dnnl.h"

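// example_utils.h provides the CHECK() macro, write_to_dnnl_memory(),
// parse_engine_kind(), and engine_kind2str_upper() helpers used below.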
#include "example_utils.h"

#define BATCH 8
#define IC 3
#define OC 96
#define CONV_IH 227
#define CONV_IW 227
#define CONV_OH 55
#define CONV_OW 55
#define CONV_STRIDE 4
#define CONV_PAD 0
#define POOL_OH 27
#define POOL_OW 27
#define POOL_STRIDE 2
#define POOL_PAD 0

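// Compute the total number of elements described by a dims array.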
static size_t product(dnnl_dim_t *arr, size_t size) {
    size_t prod = 1;
    for (size_t i = 0; i < size; ++i)
        prod *= arr[i];
    return prod;
}

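// Fill a 1D (dim == 1) or 4D (dim == 4) tensor with a deterministic pattern;
// the data is only used to exercise the primitives.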
static void init_net_data(float *data, uint32_t dim, const dnnl_dim_t *dims) {
    if (dim == 1) {
        for (dnnl_dim_t i = 0; i < dims[0]; ++i) {
            data[i] = (float)(i % 1637);
        }
    } else if (dim == 4) {
        for (dnnl_dim_t in = 0; in < dims[0]; ++in)
            for (dnnl_dim_t ic = 0; ic < dims[1]; ++ic)
                for (dnnl_dim_t ih = 0; ih < dims[2]; ++ih)
                    for (dnnl_dim_t iw = 0; iw < dims[3]; ++iw) {
                        dnnl_dim_t indx = in * dims[1] * dims[2] * dims[3]
                                + ic * dims[2] * dims[3] + ih * dims[3] + iw;
                        data[indx] = (float)(indx % 1637);
                    }
    }
}

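// Execution arguments (index/memory pairs) for one primitive in the net.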
typedef struct {
    int nargs;
    dnnl_exec_arg_t *args;
} args_t;

static void prepare_arg_node(args_t *node, int nargs) {
    node->args = (dnnl_exec_arg_t *)malloc(sizeof(dnnl_exec_arg_t) * nargs);
    node->nargs = nargs;
}
static void free_arg_node(args_t *node) {
    free(node->args);
}

static void set_arg(dnnl_exec_arg_t *arg, int arg_idx, dnnl_memory_t memory) {
    arg->arg = arg_idx;
    arg->memory = memory;
}

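// Create a memory object with the given format tag on the engine and copy
// the user data into it.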
static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
        dnnl_format_tag_t user_tag, dnnl_engine_t engine, float *data,
        dnnl_memory_t *memory) {
    dnnl_memory_desc_t user_md;
    CHECK(dnnl_memory_desc_create_with_tag(
            &user_md, dim, dims, dnnl_f32, user_tag));
    CHECK(dnnl_memory_create(memory, user_md, engine, DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_desc_destroy(user_md));
    write_to_dnnl_memory(data, *memory);
}

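// If the user memory layout differs from the layout the primitive expects,
// create an intermediate memory object and a reorder primitive between the
// two and append the reorder to the net; otherwise both outputs are NULL.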
dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in
        const_dnnl_memory_desc_t prim_memory_md, // in
        dnnl_engine_t prim_engine, // in: primitive's engine
        int dir_is_user_to_prim, // in: user -> prim or prim -> user
        dnnl_memory_t *prim_memory, // out: primitive's memory created
        dnnl_primitive_t *reorder, // out: reorder primitive created
        uint32_t *net_index, // primitive index in net (inc if reorder created)
        dnnl_primitive_t *net, args_t *net_args) { // net params
    const_dnnl_memory_desc_t user_memory_md;
    dnnl_memory_get_memory_desc(*user_memory, &user_memory_md);

    dnnl_engine_t user_mem_engine;
    dnnl_memory_get_engine(*user_memory, &user_mem_engine);

    if (!dnnl_memory_desc_equal(user_memory_md, prim_memory_md)) {
        CHECK(dnnl_memory_create(prim_memory, prim_memory_md, prim_engine,
                DNNL_MEMORY_ALLOCATE));

        dnnl_primitive_desc_t reorder_pd;
        if (dir_is_user_to_prim) {
            CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
                    user_memory_md, user_mem_engine, prim_memory_md,
                    prim_engine, NULL));
        } else {
            CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
                    prim_memory_md, prim_engine, user_memory_md,
                    user_mem_engine, NULL));
        }
        CHECK(dnnl_primitive_create(reorder, reorder_pd));
        CHECK(dnnl_primitive_desc_destroy(reorder_pd));

        net[*net_index] = *reorder;
        prepare_arg_node(&net_args[*net_index], 2);
        set_arg(&net_args[*net_index].args[0], DNNL_ARG_FROM,
                dir_is_user_to_prim ? *user_memory : *prim_memory);
        set_arg(&net_args[*net_index].args[1], DNNL_ARG_TO,
                dir_is_user_to_prim ? *prim_memory : *user_memory);
        (*net_index)++;
    } else {
        *prim_memory = NULL;
        *reorder = NULL;
    }

    return dnnl_success;
}

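// Build and execute the AlexNet fragment conv -> relu -> lrn -> pool on the
// requested engine.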
void simple_net(dnnl_engine_kind_t engine_kind) {
    dnnl_engine_t engine;
    CHECK(dnnl_engine_create(&engine, engine_kind, 0));

    // build a simple net
    uint32_t n = 0;
    dnnl_primitive_t net[10];
    args_t net_args[10];

    const int ndims = 4;
    dnnl_dims_t net_src_sizes = {BATCH, IC, CONV_IH, CONV_IW};
    dnnl_dims_t net_dst_sizes = {BATCH, OC, POOL_OH, POOL_OW};

    float *net_src
            = (float *)malloc(product(net_src_sizes, ndims) * sizeof(float));
    float *net_dst
            = (float *)malloc(product(net_dst_sizes, ndims) * sizeof(float));

    init_net_data(net_src, ndims, net_src_sizes);
    memset(net_dst, 0, product(net_dst_sizes, ndims) * sizeof(float));

    // AlexNet: conv
    // {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
    // {BATCH, OC, CONV_OH, CONV_OW}
    // strides: {CONV_STRIDE, CONV_STRIDE}
    dnnl_dims_t conv_user_src_sizes;
    for (int i = 0; i < ndims; i++)
        conv_user_src_sizes[i] = net_src_sizes[i];
    dnnl_dims_t conv_user_weights_sizes = {OC, IC, 11, 11};
    dnnl_dims_t conv_bias_sizes = {OC};
    dnnl_dims_t conv_user_dst_sizes = {BATCH, OC, CONV_OH, CONV_OW};
    dnnl_dims_t conv_strides = {CONV_STRIDE, CONV_STRIDE};
    dnnl_dims_t conv_dilation = {0, 0};
    dnnl_dims_t conv_padding = {CONV_PAD, CONV_PAD};

    float *conv_src = net_src;
    float *conv_weights = (float *)malloc(
            product(conv_user_weights_sizes, ndims) * sizeof(float));
    float *conv_bias
            = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));

    init_net_data(conv_weights, ndims, conv_user_weights_sizes);
    init_net_data(conv_bias, 1, conv_bias_sizes);

    // create memory for user data
    dnnl_memory_t conv_user_src_memory, conv_user_weights_memory,
            conv_user_bias_memory;
    init_data_memory(ndims, conv_user_src_sizes, dnnl_nchw, engine, conv_src,
            &conv_user_src_memory);
    init_data_memory(ndims, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_weights, &conv_user_weights_memory);
    init_data_memory(1, conv_bias_sizes, dnnl_x, engine, conv_bias,
            &conv_user_bias_memory);

    // create data descriptors for convolution with no specified format;
    // dnnl_format_tag_any lets the library choose the optimal memory layout
    dnnl_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md;
    CHECK(dnnl_memory_desc_create_with_tag(&conv_src_md, ndims,
            conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
    CHECK(dnnl_memory_desc_create_with_tag(&conv_weights_md, ndims,
            conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
    CHECK(dnnl_memory_desc_create_with_tag(
            &conv_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
    CHECK(dnnl_memory_desc_create_with_tag(&conv_dst_md, ndims,
            conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));

    // create a convolution primitive descriptor
    dnnl_primitive_desc_t conv_pd;
    CHECK(dnnl_convolution_forward_primitive_desc_create(&conv_pd, engine,
            dnnl_forward, dnnl_convolution_direct, conv_src_md, conv_weights_md,
            conv_bias_md, conv_dst_md, conv_strides, conv_dilation,
            conv_padding, conv_padding, NULL));

    dnnl_memory_t conv_internal_src_memory, conv_internal_weights_memory,
            conv_internal_dst_memory;

    // create memory for dst data; we don't need to reorder it to user data
    const_dnnl_memory_desc_t dst_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    CHECK(dnnl_memory_create(
            &conv_internal_dst_memory, dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // create reorder primitives between user data and convolution srcs
    // if required
    dnnl_primitive_t conv_reorder_src, conv_reorder_weights;

    const_dnnl_memory_desc_t src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_user_src_memory, src_md, engine, 1,
            &conv_internal_src_memory, &conv_reorder_src, &n, net, net_args));

    const_dnnl_memory_desc_t weights_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
    CHECK(prepare_reorder(&conv_user_weights_memory, weights_md, engine, 1,
            &conv_internal_weights_memory, &conv_reorder_weights, &n, net,
            net_args));

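    // use the reordered memory if a reorder was created, the user memory
    // otherwise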
    dnnl_memory_t conv_src_memory = conv_internal_src_memory
            ? conv_internal_src_memory
            : conv_user_src_memory;
    dnnl_memory_t conv_weights_memory = conv_internal_weights_memory
            ? conv_internal_weights_memory
            : conv_user_weights_memory;

    // finally create a convolution primitive
    dnnl_primitive_t conv;
    CHECK(dnnl_primitive_create(&conv, conv_pd));
    net[n] = conv;
    prepare_arg_node(&net_args[n], 4);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, conv_src_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_WEIGHTS, conv_weights_memory);
    set_arg(&net_args[n].args[2], DNNL_ARG_BIAS, conv_user_bias_memory);
    set_arg(&net_args[n].args[3], DNNL_ARG_DST, conv_internal_dst_memory);
    n++;

    // AlexNet: relu
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    float negative_slope = 0.0f;

    // create relu src memory descriptor based on the dst memory descriptor
    // from the previous primitive
    const_dnnl_memory_desc_t relu_src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    const_dnnl_memory_desc_t relu_dst_md = relu_src_md;

    // create a relu primitive descriptor
    dnnl_primitive_desc_t relu_pd;
    CHECK(dnnl_eltwise_forward_primitive_desc_create(&relu_pd, engine,
            dnnl_forward, dnnl_eltwise_relu, relu_src_md, relu_dst_md,
            negative_slope, 0, NULL));

    dnnl_memory_t relu_dst_memory;
    CHECK(dnnl_memory_create(
            &relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create a relu primitive
    dnnl_primitive_t relu;
    CHECK(dnnl_primitive_create(&relu, relu_pd));
    net[n] = relu;
    prepare_arg_node(&net_args[n], 2);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, conv_internal_dst_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_DST, relu_dst_memory);
    n++;

    // AlexNet: lrn
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    // local size: 5
    // alpha: 0.0001
    // beta: 0.75
    // k: 1.0
    uint32_t local_size = 5;
    float alpha = 0.0001f;
    float beta = 0.75f;
    float k = 1.0f;

    // create lrn src memory descriptor using the dst memory descriptor
    // from the previous primitive
    const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
    const_dnnl_memory_desc_t lrn_dst_md = lrn_src_md;

    // create an lrn primitive descriptor
    dnnl_primitive_desc_t lrn_pd;
    CHECK(dnnl_lrn_forward_primitive_desc_create(&lrn_pd, engine, dnnl_forward,
            dnnl_lrn_across_channels, lrn_src_md, lrn_dst_md, local_size, alpha,
            beta, k, NULL));

    // create memory for lrn dst data and workspace
    dnnl_memory_t lrn_dst_memory;
    CHECK(dnnl_memory_create(
            &lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));
    dnnl_memory_t lrn_ws_memory;
    const_dnnl_memory_desc_t lrn_ws_md
            = dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create an lrn primitive
    dnnl_primitive_t lrn;
    CHECK(dnnl_primitive_create(&lrn, lrn_pd));
    net[n] = lrn;
    prepare_arg_node(&net_args[n], 3);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_DST, lrn_dst_memory);
    set_arg(&net_args[n].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    n++;

    // AlexNet: pool
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
    // kernel: {3, 3}
    // strides: {POOL_STRIDE, POOL_STRIDE}
    // dilation: {0, 0}
    dnnl_dims_t pool_dst_sizes;
    for (int i = 0; i < ndims; i++)
        pool_dst_sizes[i] = net_dst_sizes[i];
    dnnl_dims_t pool_kernel = {3, 3};
    dnnl_dims_t pool_strides = {POOL_STRIDE, POOL_STRIDE};
    dnnl_dims_t pool_padding = {POOL_PAD, POOL_PAD};
    dnnl_dims_t pool_dilation = {0, 0};

    // create pooling src memory descriptor using the dst descriptor
    // from the previous primitive
    const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;

    // create descriptors for dst pooling data
    dnnl_memory_desc_t pool_dst_any_md;
    CHECK(dnnl_memory_desc_create_with_tag(&pool_dst_any_md, ndims,
            pool_dst_sizes, dnnl_f32, dnnl_format_tag_any));

    // create memory for user data
    dnnl_memory_t pool_user_dst_memory;
    init_data_memory(ndims, pool_dst_sizes, dnnl_nchw, engine, net_dst,
            &pool_user_dst_memory);

    // create a pooling primitive descriptor
    dnnl_primitive_desc_t pool_pd;
    CHECK(dnnl_pooling_forward_primitive_desc_create(&pool_pd, engine,
            dnnl_forward, dnnl_pooling_max, pool_src_md, pool_dst_any_md,
            pool_strides, pool_kernel, pool_dilation, pool_padding,
            pool_padding, NULL));

    // create memory for workspace
    dnnl_memory_t pool_ws_memory;
    const_dnnl_memory_desc_t pool_ws_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    dnnl_memory_t pool_dst_memory;

    // create reorder primitives between user data and pooling dsts
    // if required
    dnnl_primitive_t pool_reorder_dst;
    dnnl_memory_t pool_internal_dst_memory;
    const_dnnl_memory_desc_t pool_dst_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
    // the pooling primitive must run before the dst reorder, so reserve its
    // slot in the net here and let prepare_reorder() fill the next one
    n += 1;
    CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0,
            &pool_internal_dst_memory, &pool_reorder_dst, &n, net, net_args));
    // step back to the reserved pooling slot (past the reorder slot if one
    // was created)
    n -= pool_reorder_dst ? 2 : 1;

    pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
                                               : pool_user_dst_memory;

    // finally create a pooling primitive
    dnnl_primitive_t pool;
    CHECK(dnnl_primitive_create(&pool, pool_pd));
    net[n] = pool;
    prepare_arg_node(&net_args[n], 3);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, lrn_dst_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_DST, pool_dst_memory);
    set_arg(&net_args[n].args[2], DNNL_ARG_WORKSPACE, pool_ws_memory);
    n++;

    // if a dst reorder was created, account for its slot in the net
    if (pool_reorder_dst) n += 1;

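    // execute the net: submit all primitives to a stream in order and wait
    // for completion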
    dnnl_stream_t stream;
    CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
    for (uint32_t i = 0; i < n; ++i) {
        CHECK(dnnl_primitive_execute(
                net[i], stream, net_args[i].nargs, net_args[i].args));
    }

    CHECK(dnnl_stream_wait(stream));

    // clean-up
    for (uint32_t i = 0; i < n; ++i)
        free_arg_node(&net_args[i]);

    CHECK(dnnl_primitive_desc_destroy(conv_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_pd));
    CHECK(dnnl_primitive_desc_destroy(pool_pd));

    dnnl_stream_destroy(stream);

    free(net_src);
    free(net_dst);

    dnnl_memory_desc_destroy(conv_src_md);
    dnnl_memory_desc_destroy(conv_weights_md);
    dnnl_memory_desc_destroy(conv_bias_md);
    dnnl_memory_desc_destroy(conv_dst_md);
    dnnl_memory_desc_destroy(pool_dst_any_md);

    dnnl_memory_destroy(conv_user_src_memory);
    dnnl_memory_destroy(conv_user_weights_memory);
    dnnl_memory_destroy(conv_user_bias_memory);
    dnnl_memory_destroy(conv_internal_src_memory);
    dnnl_memory_destroy(conv_internal_weights_memory);
    dnnl_memory_destroy(conv_internal_dst_memory);
    dnnl_primitive_destroy(conv_reorder_src);
    dnnl_primitive_destroy(conv_reorder_weights);
    dnnl_primitive_destroy(conv);

    free(conv_weights);
    free(conv_bias);

    dnnl_memory_destroy(relu_dst_memory);
    dnnl_primitive_destroy(relu);

    dnnl_memory_destroy(lrn_ws_memory);
    dnnl_memory_destroy(lrn_dst_memory);
    dnnl_primitive_destroy(lrn);

    dnnl_memory_destroy(pool_user_dst_memory);
    dnnl_memory_destroy(pool_internal_dst_memory);
    dnnl_memory_destroy(pool_ws_memory);
    dnnl_primitive_destroy(pool_reorder_dst);
    dnnl_primitive_destroy(pool);

    dnnl_engine_destroy(engine);
}

int main(int argc, char **argv) {
    dnnl_engine_kind_t engine_kind = parse_engine_kind(argc, argv);
    simple_net(engine_kind);
    printf("Example passed on %s.\n", engine_kind2str_upper(engine_kind));
    return 0;
}