1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @example cpu_cnn_training_f32.c
18/// @copybrief cpu_cnn_training_f32_c
19
20/// @page cpu_cnn_training_f32_c CNN f32 training example
21/// This C API example demonstrates how to build an AlexNet model training.
22/// The example implements a few layers from AlexNet model.
23///
24/// @include cpu_cnn_training_f32.c
25
26// Required for posix_memalign
27#define _POSIX_C_SOURCE 200112L
28
29#include <stdio.h>
30#include <stdlib.h>
31#include <string.h>
32
33#include "oneapi/dnnl/dnnl.h"
34
35#include "example_utils.h"
36
37#define BATCH 8
38#define IC 3
39#define OC 96
40#define CONV_IH 227
41#define CONV_IW 227
42#define CONV_OH 55
43#define CONV_OW 55
44#define CONV_STRIDE 4
45#define CONV_PAD 0
46#define POOL_OH 27
47#define POOL_OW 27
48#define POOL_STRIDE 2
49#define POOL_PAD 0
50
51static size_t product(dnnl_dim_t *arr, size_t size) {
52 size_t prod = 1;
53 for (size_t i = 0; i < size; ++i)
54 prod *= arr[i];
55 return prod;
56}
57
58static void init_net_data(float *data, uint32_t dim, const dnnl_dim_t *dims) {
59 if (dim == 1) {
60 for (dnnl_dim_t i = 0; i < dims[0]; ++i) {
61 data[i] = (float)(i % 1637);
62 }
63 } else if (dim == 4) {
64 for (dnnl_dim_t in = 0; in < dims[0]; ++in)
65 for (dnnl_dim_t ic = 0; ic < dims[1]; ++ic)
66 for (dnnl_dim_t ih = 0; ih < dims[2]; ++ih)
67 for (dnnl_dim_t iw = 0; iw < dims[3]; ++iw) {
68 dnnl_dim_t indx = in * dims[1] * dims[2] * dims[3]
69 + ic * dims[2] * dims[3] + ih * dims[3] + iw;
70 data[indx] = (float)(indx % 1637);
71 }
72 }
73}
74
// Execution arguments for one primitive: an array of (arg kind, memory) pairs
// passed to dnnl_primitive_execute().
typedef struct {
    int nargs; // number of valid entries in `args`
    dnnl_exec_arg_t *args; // heap-allocated by prepare_arg_node()
} args_t;
79
80static void prepare_arg_node(args_t *node, int nargs) {
81 node->args = (dnnl_exec_arg_t *)malloc(sizeof(dnnl_exec_arg_t) * nargs);
82 node->nargs = nargs;
83}
84static void free_arg_node(args_t *node) {
85 free(node->args);
86}
87
88static void set_arg(dnnl_exec_arg_t *arg, int arg_idx, dnnl_memory_t memory) {
89 arg->arg = arg_idx;
90 arg->memory = memory;
91}
92
// Creates a dnnl memory object with the user-visible layout `user_tag`
// (library-allocated buffer) and copies `data` into it.
// The memory descriptor is only needed to create the memory object, so it is
// destroyed immediately afterwards; the memory object keeps its own copy.
static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
        dnnl_format_tag_t user_tag, dnnl_engine_t engine, float *data,
        dnnl_memory_t *memory) {
    dnnl_memory_desc_t user_md;
    CHECK(dnnl_memory_desc_create_with_tag(
            &user_md, dim, dims, dnnl_f32, user_tag));
    CHECK(dnnl_memory_create(memory, user_md, engine, DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_desc_destroy(user_md));
    // write_to_dnnl_memory (example_utils.h) handles engine-specific copies.
    write_to_dnnl_memory(data, *memory);
}
103
104dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in
105 const_dnnl_memory_desc_t prim_memory_md, // in
106 dnnl_engine_t prim_engine, // in: primitive's engine
107 int dir_is_user_to_prim, // in: user -> prim or prim -> user
108 dnnl_memory_t *prim_memory, // out: primitive's memory created
109 dnnl_primitive_t *reorder, // out: reorder primitive created
110 uint32_t *net_index, // primitive index in net (inc if reorder created)
111 dnnl_primitive_t *net, args_t *net_args) { // net params
112 const_dnnl_memory_desc_t user_memory_md;
113 dnnl_memory_get_memory_desc(*user_memory, &user_memory_md);
114
115 dnnl_engine_t user_mem_engine;
116 dnnl_memory_get_engine(*user_memory, &user_mem_engine);
117
118 if (!dnnl_memory_desc_equal(user_memory_md, prim_memory_md)) {
119 CHECK(dnnl_memory_create(prim_memory, prim_memory_md, prim_engine,
120 DNNL_MEMORY_ALLOCATE));
121
122 dnnl_primitive_desc_t reorder_pd;
123 if (dir_is_user_to_prim) {
124 CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
125 user_memory_md, user_mem_engine, prim_memory_md,
126 prim_engine, NULL));
127 } else {
128 CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
129 prim_memory_md, prim_engine, user_memory_md,
130 user_mem_engine, NULL));
131 }
132 CHECK(dnnl_primitive_create(reorder, reorder_pd));
133 CHECK(dnnl_primitive_desc_destroy(reorder_pd));
134
135 net[*net_index] = *reorder;
136 prepare_arg_node(&net_args[*net_index], 2);
137 set_arg(&net_args[*net_index].args[0], DNNL_ARG_FROM,
138 dir_is_user_to_prim ? *user_memory : *prim_memory);
139 set_arg(&net_args[*net_index].args[1], DNNL_ARG_TO,
140 dir_is_user_to_prim ? *prim_memory : *user_memory);
141 (*net_index)++;
142 } else {
143 *prim_memory = NULL;
144 *reorder = NULL;
145 }
146
147 return dnnl_success;
148}
149
// Builds, executes, and destroys a small AlexNet-style training net on CPU:
// forward stream  : conv -> relu -> lrn -> max-pool (+ reorders as needed)
// backward stream : pool_bwd -> lrn_bwd -> relu_bwd -> conv_bwd_weights
// The lrn/pool forward workspaces are shared with their backward primitives.
// Weight/bias updates themselves are left to the user (see the comments in
// the training loop below).
void simple_net() {
    dnnl_engine_t engine;
    CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); // idx

    // build a simple net
    // n_fwd/n_bwd count the primitives appended so far to each net.
    uint32_t n_fwd = 0, n_bwd = 0;
    dnnl_primitive_t net_fwd[10], net_bwd[10];
    args_t net_fwd_args[10], net_bwd_args[10];

    const int ndims = 4;
    dnnl_dims_t net_src_sizes = {BATCH, IC, CONV_IH, CONV_IW};
    dnnl_dims_t net_dst_sizes = {BATCH, OC, POOL_OH, POOL_OW};

    float *net_src
            = (float *)malloc(product(net_src_sizes, ndims) * sizeof(float));
    float *net_dst
            = (float *)malloc(product(net_dst_sizes, ndims) * sizeof(float));

    init_net_data(net_src, ndims, net_src_sizes);
    memset(net_dst, 0, product(net_dst_sizes, ndims) * sizeof(float));

    //----------------------------------------------------------------------
    //----------------- Forward Stream -------------------------------------
    // AlexNet: conv
    // {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
    // {BATCH, OC, CONV_OH, CONV_OW}
    // strides: {CONV_STRIDE, CONV_STRIDE}
    dnnl_dims_t conv_user_src_sizes;
    for (int i = 0; i < ndims; i++)
        conv_user_src_sizes[i] = net_src_sizes[i];
    dnnl_dims_t conv_user_weights_sizes = {OC, IC, 11, 11};
    dnnl_dims_t conv_bias_sizes = {OC};
    dnnl_dims_t conv_user_dst_sizes = {BATCH, OC, CONV_OH, CONV_OW};
    dnnl_dims_t conv_strides = {CONV_STRIDE, CONV_STRIDE};
    dnnl_dims_t conv_dilation = {0, 0};
    dnnl_dims_t conv_padding = {CONV_PAD, CONV_PAD};

    float *conv_src = net_src;
    float *conv_weights = (float *)malloc(
            product(conv_user_weights_sizes, ndims) * sizeof(float));
    float *conv_bias
            = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));

    init_net_data(conv_weights, ndims, conv_user_weights_sizes);
    init_net_data(conv_bias, 1, conv_bias_sizes);

    // create memory for user data
    dnnl_memory_t conv_user_src_memory, conv_user_weights_memory,
            conv_user_bias_memory;
    init_data_memory(ndims, conv_user_src_sizes, dnnl_nchw, engine, conv_src,
            &conv_user_src_memory);
    init_data_memory(ndims, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_weights, &conv_user_weights_memory);
    init_data_memory(1, conv_bias_sizes, dnnl_x, engine, conv_bias,
            &conv_user_bias_memory);

    // create a convolution
    dnnl_primitive_desc_t conv_pd;

    {
        // create data descriptors for convolution w/ no specified format
        // (format_tag_any lets the library pick the fastest layout)
        dnnl_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
                conv_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&conv_src_md, ndims,
                conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_weights_md, ndims,
                conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(
                &conv_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_dst_md, ndims,
                conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));

        CHECK(dnnl_convolution_forward_primitive_desc_create(&conv_pd, engine,
                dnnl_forward, dnnl_convolution_direct, conv_src_md,
                conv_weights_md, conv_bias_md, conv_dst_md, conv_strides,
                conv_dilation, conv_padding, conv_padding, NULL));

        CHECK(dnnl_memory_desc_destroy(conv_src_md));
        CHECK(dnnl_memory_desc_destroy(conv_weights_md));
        CHECK(dnnl_memory_desc_destroy(conv_bias_md));
        CHECK(dnnl_memory_desc_destroy(conv_dst_md));
    }

    dnnl_memory_t conv_internal_src_memory, conv_internal_weights_memory,
            conv_internal_dst_memory;

    // create memory for dst data, we don't need to reorder it to user data
    const_dnnl_memory_desc_t conv_dst_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    CHECK(dnnl_memory_create(&conv_internal_dst_memory, conv_dst_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // create reorder primitives between user data and convolution srcs
    // if required
    dnnl_primitive_t conv_reorder_src, conv_reorder_weights;

    const_dnnl_memory_desc_t conv_src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_user_src_memory, conv_src_md, engine, 1,
            &conv_internal_src_memory, &conv_reorder_src, &n_fwd, net_fwd,
            net_fwd_args));

    const_dnnl_memory_desc_t conv_weights_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
    CHECK(prepare_reorder(&conv_user_weights_memory, conv_weights_md, engine, 1,
            &conv_internal_weights_memory, &conv_reorder_weights, &n_fwd,
            net_fwd, net_fwd_args));

    // prepare_reorder leaves the internal memory NULL when no reorder is
    // needed, so fall back to the user memory in that case.
    dnnl_memory_t conv_src_memory = conv_internal_src_memory
            ? conv_internal_src_memory
            : conv_user_src_memory;
    dnnl_memory_t conv_weights_memory = conv_internal_weights_memory
            ? conv_internal_weights_memory
            : conv_user_weights_memory;

    // finally create a convolution primitive
    dnnl_primitive_t conv;
    CHECK(dnnl_primitive_create(&conv, conv_pd));
    net_fwd[n_fwd] = conv;
    prepare_arg_node(&net_fwd_args[n_fwd], 4);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, conv_src_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_WEIGHTS,
            conv_weights_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_BIAS, conv_user_bias_memory);
    set_arg(&net_fwd_args[n_fwd].args[3], DNNL_ARG_DST,
            conv_internal_dst_memory);
    n_fwd++;

    // AlexNet: relu
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}

    float negative_slope = 0.0f;

    // keep memory format of source same as the format of convolution
    // output in order to avoid reorder
    const_dnnl_memory_desc_t relu_src_md = conv_dst_md;
    const_dnnl_memory_desc_t relu_dst_md = relu_src_md;

    // create a relu primitive descriptor
    dnnl_primitive_desc_t relu_pd;
    CHECK(dnnl_eltwise_forward_primitive_desc_create(&relu_pd, engine,
            dnnl_forward, dnnl_eltwise_relu, relu_src_md, relu_dst_md,
            negative_slope, 0, NULL));

    // create relu dst memory
    dnnl_memory_t relu_dst_memory;
    CHECK(dnnl_memory_create(
            &relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create a relu primitive
    dnnl_primitive_t relu;
    CHECK(dnnl_primitive_create(&relu, relu_pd));
    net_fwd[n_fwd] = relu;
    prepare_arg_node(&net_fwd_args[n_fwd], 2);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC,
            conv_internal_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, relu_dst_memory);
    n_fwd++;

    // AlexNet: lrn
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    // local size: 5
    // alpha: 0.0001
    // beta: 0.75
    // k: 1.0
    uint32_t local_size = 5;
    float alpha = 0.0001f;
    float beta = 0.75f;
    float k = 1.0f;

    // create lrn src memory descriptor using dst memory descriptor
    // from previous primitive
    const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
    const_dnnl_memory_desc_t lrn_dst_md = lrn_src_md;

    // create a lrn primitive descriptor
    dnnl_primitive_desc_t lrn_pd;
    CHECK(dnnl_lrn_forward_primitive_desc_create(&lrn_pd, engine, dnnl_forward,
            dnnl_lrn_across_channels, lrn_src_md, lrn_dst_md, local_size, alpha,
            beta, k, NULL));

    // create primitives for lrn dst and workspace memory
    dnnl_memory_t lrn_dst_memory, lrn_ws_memory;

    CHECK(dnnl_memory_create(
            &lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // create workspace only in training and only for forward primitive
    // query lrn_pd for workspace, this memory will be shared with forward lrn
    const_dnnl_memory_desc_t lrn_ws_md
            = dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create a lrn primitive
    dnnl_primitive_t lrn;
    CHECK(dnnl_primitive_create(&lrn, lrn_pd));
    net_fwd[n_fwd] = lrn;
    prepare_arg_node(&net_fwd_args[n_fwd], 3);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, lrn_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    n_fwd++;

    // AlexNet: pool
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
    // kernel: {3, 3}
    // strides: {POOL_STRIDE, POOL_STRIDE}
    // dilation: {0, 0}
    dnnl_dims_t pool_dst_sizes;
    for (int i = 0; i < ndims; i++)
        pool_dst_sizes[i] = net_dst_sizes[i];
    dnnl_dims_t pool_kernel = {3, 3};
    dnnl_dims_t pool_strides = {POOL_STRIDE, POOL_STRIDE};
    dnnl_dims_t pool_padding = {POOL_PAD, POOL_PAD};
    dnnl_dims_t pool_dilation = {0, 0};

    // create memory for user dst data
    dnnl_memory_t pool_user_dst_memory;
    init_data_memory(4, pool_dst_sizes, dnnl_nchw, engine, net_dst,
            &pool_user_dst_memory);

    // create a pooling primitive descriptor
    dnnl_primitive_desc_t pool_pd;

    {
        // create pooling src memory descriptor using dst descriptor
        // from previous primitive
        const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;

        // create descriptors for dst pooling data
        dnnl_memory_desc_t pool_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&pool_dst_md, 4, pool_dst_sizes,
                dnnl_f32, dnnl_format_tag_any));

        CHECK(dnnl_pooling_forward_primitive_desc_create(&pool_pd, engine,
                dnnl_forward, dnnl_pooling_max, pool_src_md, pool_dst_md,
                pool_strides, pool_kernel, pool_dilation, pool_padding,
                pool_padding, NULL));
        CHECK(dnnl_memory_desc_destroy(pool_dst_md));
    }

    // create memory for workspace (stores max indices for pooling backward)
    dnnl_memory_t pool_ws_memory;
    const_dnnl_memory_desc_t pool_ws_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    // create reorder primitives between pooling dsts and user format dst
    // if required
    dnnl_primitive_t pool_reorder_dst;
    dnnl_memory_t pool_internal_dst_memory;
    const_dnnl_memory_desc_t pool_dst_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
    n_fwd += 1; // tentative workaround: preserve space for pooling that should
            // happen before the reorder
    CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0,
            &pool_internal_dst_memory, &pool_reorder_dst, &n_fwd, net_fwd,
            net_fwd_args));
    // Rewind n_fwd to the reserved pooling slot: -2 when a reorder was placed
    // after the slot, -1 when no reorder was created.
    n_fwd -= pool_reorder_dst ? 2 : 1;

    dnnl_memory_t pool_dst_memory = pool_internal_dst_memory
            ? pool_internal_dst_memory
            : pool_user_dst_memory;

    // finally create a pooling primitive
    dnnl_primitive_t pool;
    CHECK(dnnl_primitive_create(&pool, pool_pd));
    net_fwd[n_fwd] = pool;
    prepare_arg_node(&net_fwd_args[n_fwd], 3);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, lrn_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, pool_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, pool_ws_memory);
    n_fwd++;

    // skip over the already-recorded dst reorder, if any
    if (pool_reorder_dst) n_fwd += 1;

    //-----------------------------------------------------------------------
    //----------------- Backward Stream -------------------------------------
    //-----------------------------------------------------------------------

    // ... user diff_data ...
    float *net_diff_dst
            = (float *)malloc(product(pool_dst_sizes, 4) * sizeof(float));

    init_net_data(net_diff_dst, 4, pool_dst_sizes);

    // create memory for user diff dst data
    dnnl_memory_t pool_user_diff_dst_memory;
    init_data_memory(4, pool_dst_sizes, dnnl_nchw, engine, net_diff_dst,
            &pool_user_diff_dst_memory);

    // Pooling Backward
    // pooling diff src memory descriptor
    const_dnnl_memory_desc_t pool_diff_src_md = lrn_dst_md;

    // pooling diff dst memory descriptor
    const_dnnl_memory_desc_t pool_diff_dst_md = pool_dst_md;

    // backward primitive descriptor needs to hint forward descriptor
    dnnl_primitive_desc_t pool_bwd_pd;
    CHECK(dnnl_pooling_backward_primitive_desc_create(&pool_bwd_pd, engine,
            dnnl_pooling_max, pool_diff_src_md, pool_diff_dst_md, pool_strides,
            pool_kernel, pool_dilation, pool_padding, pool_padding, pool_pd,
            NULL));

    // create reorder primitive between user diff dst and pool diff dst
    // if required
    dnnl_memory_t pool_diff_dst_memory, pool_internal_diff_dst_memory;
    dnnl_primitive_t pool_reorder_diff_dst;
    CHECK(prepare_reorder(&pool_user_diff_dst_memory, pool_diff_dst_md, engine,
            1, &pool_internal_diff_dst_memory, &pool_reorder_diff_dst, &n_bwd,
            net_bwd, net_bwd_args));

    pool_diff_dst_memory = pool_internal_diff_dst_memory
            ? pool_internal_diff_dst_memory
            : pool_user_diff_dst_memory;

    // create memory for pool diff src data
    dnnl_memory_t pool_diff_src_memory;
    CHECK(dnnl_memory_create(&pool_diff_src_memory, pool_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward pooling primitive
    // (reuses pool_ws_memory filled by the forward pooling pass)
    dnnl_primitive_t pool_bwd;
    CHECK(dnnl_primitive_create(&pool_bwd, pool_bwd_pd));
    net_bwd[n_bwd] = pool_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 3);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_DIFF_DST,
            pool_diff_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_WORKSPACE, pool_ws_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
            pool_diff_src_memory);
    n_bwd++;

    // Backward lrn
    const_dnnl_memory_desc_t lrn_diff_dst_md = pool_diff_src_md;
    const_dnnl_memory_desc_t lrn_diff_src_md = lrn_diff_dst_md;

    // create backward lrn descriptor
    dnnl_primitive_desc_t lrn_bwd_pd;
    CHECK(dnnl_lrn_backward_primitive_desc_create(&lrn_bwd_pd, engine,
            dnnl_lrn_across_channels, lrn_diff_src_md, lrn_diff_dst_md,
            lrn_src_md, local_size, alpha, beta, k, lrn_pd, NULL));

    // create memory for lrn diff src
    dnnl_memory_t lrn_diff_src_memory;
    CHECK(dnnl_memory_create(&lrn_diff_src_memory, lrn_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward lrn primitive
    // (reuses lrn_ws_memory filled by the forward lrn pass)
    dnnl_primitive_t lrn_bwd;
    CHECK(dnnl_primitive_create(&lrn_bwd, lrn_bwd_pd));
    net_bwd[n_bwd] = lrn_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 4);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            pool_diff_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_SRC,
            lrn_diff_src_memory);
    n_bwd++;

    // Backward relu
    const_dnnl_memory_desc_t relu_diff_src_md = lrn_diff_src_md;
    const_dnnl_memory_desc_t relu_diff_dst_md = lrn_diff_src_md;

    // create backward relu descriptor
    dnnl_primitive_desc_t relu_bwd_pd;
    CHECK(dnnl_eltwise_backward_primitive_desc_create(&relu_bwd_pd, engine,
            dnnl_eltwise_relu, relu_diff_src_md, relu_diff_dst_md, relu_src_md,
            negative_slope, 0, relu_pd, NULL));

    // create memory for relu diff src
    dnnl_memory_t relu_diff_src_memory;
    CHECK(dnnl_memory_create(&relu_diff_src_memory, relu_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward relu primitive
    dnnl_primitive_t relu_bwd;
    CHECK(dnnl_primitive_create(&relu_bwd, relu_bwd_pd));
    net_bwd[n_bwd] = relu_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 3);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
            conv_internal_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            lrn_diff_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
            relu_diff_src_memory);
    n_bwd++;

    // Backward convolution with respect to weights
    float *conv_diff_bias_buffer
            = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));
    float *conv_user_diff_weights_buffer = (float *)malloc(
            product(conv_user_weights_sizes, 4) * sizeof(float));

    // initialize memory for diff weights in user format
    dnnl_memory_t conv_user_diff_weights_memory;
    init_data_memory(4, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_user_diff_weights_buffer, &conv_user_diff_weights_memory);

    // create backward convolution primitive descriptor
    dnnl_primitive_desc_t conv_bwd_weights_pd;

    {
        // memory descriptors should be in format `any` to allow backward
        // convolution for
        // weights to chose the format it prefers for best performance
        dnnl_memory_desc_t conv_diff_src_md, conv_diff_weights_md,
                conv_diff_bias_md, conv_diff_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_src_md, 4,
                conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_weights_md, 4,
                conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(
                &conv_diff_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_dst_md, 4,
                conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));

        // create backward convolution descriptor
        CHECK(dnnl_convolution_backward_weights_primitive_desc_create(
                &conv_bwd_weights_pd, engine, dnnl_convolution_direct,
                conv_diff_src_md, conv_diff_weights_md, conv_diff_bias_md,
                conv_diff_dst_md, conv_strides, conv_dilation, conv_padding,
                conv_padding, conv_pd, NULL));

        CHECK(dnnl_memory_desc_destroy(conv_diff_src_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_weights_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_bias_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_dst_md));
    }

    // for best performance convolution backward might chose
    // different memory format for src and diff_dst
    // than the memory formats preferred by forward convolution
    // for src and dst respectively
    // create reorder primitives for src from forward convolution to the
    // format chosen by backward convolution
    dnnl_primitive_t conv_bwd_reorder_src;
    dnnl_memory_t conv_bwd_internal_src_memory;
    const_dnnl_memory_desc_t conv_diff_src_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_src_memory, conv_diff_src_md, engine, 1,
            &conv_bwd_internal_src_memory, &conv_bwd_reorder_src, &n_bwd,
            net_bwd, net_bwd_args));

    dnnl_memory_t conv_bwd_weights_src_memory = conv_bwd_internal_src_memory
            ? conv_bwd_internal_src_memory
            : conv_src_memory;

    // create reorder primitives for diff_dst between diff_src from relu_bwd
    // and format preferred by conv_diff_weights
    dnnl_primitive_t conv_reorder_diff_dst;
    dnnl_memory_t conv_internal_diff_dst_memory;
    const_dnnl_memory_desc_t conv_diff_dst_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_diff_dst_md, 0);

    CHECK(prepare_reorder(&relu_diff_src_memory, conv_diff_dst_md, engine, 1,
            &conv_internal_diff_dst_memory, &conv_reorder_diff_dst, &n_bwd,
            net_bwd, net_bwd_args));

    dnnl_memory_t conv_diff_dst_memory = conv_internal_diff_dst_memory
            ? conv_internal_diff_dst_memory
            : relu_diff_src_memory;

    // create reorder primitives for conv diff weights memory
    dnnl_primitive_t conv_reorder_diff_weights;
    dnnl_memory_t conv_internal_diff_weights_memory;
    const_dnnl_memory_desc_t conv_diff_weights_md
            = dnnl_primitive_desc_query_md(
                    conv_bwd_weights_pd, dnnl_query_diff_weights_md, 0);
    n_bwd += 1; // tentative workaround: preserve space for conv_bwd_weights
            // that should happen before the reorder

    CHECK(prepare_reorder(&conv_user_diff_weights_memory, conv_diff_weights_md,
            engine, 0, &conv_internal_diff_weights_memory,
            &conv_reorder_diff_weights, &n_bwd, net_bwd, net_bwd_args));
    // Rewind n_bwd to the reserved conv_bwd_weights slot (see pool above).
    n_bwd -= conv_reorder_diff_weights ? 2 : 1;

    dnnl_memory_t conv_diff_weights_memory = conv_internal_diff_weights_memory
            ? conv_internal_diff_weights_memory
            : conv_user_diff_weights_memory;

    // create memory for diff bias memory
    dnnl_memory_t conv_diff_bias_memory;
    const_dnnl_memory_desc_t conv_diff_bias_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_diff_weights_md, 1);
    CHECK(dnnl_memory_create(&conv_diff_bias_memory, conv_diff_bias_md, engine,
            DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_set_data_handle(
            conv_diff_bias_memory, conv_diff_bias_buffer));

    // finally created backward convolution weights primitive
    dnnl_primitive_t conv_bwd_weights;
    CHECK(dnnl_primitive_create(&conv_bwd_weights, conv_bwd_weights_pd));
    net_bwd[n_bwd] = conv_bwd_weights;
    prepare_arg_node(&net_bwd_args[n_bwd], 4);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
            conv_bwd_weights_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            conv_diff_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_WEIGHTS,
            conv_diff_weights_memory);
    set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_BIAS,
            conv_diff_bias_memory);
    n_bwd++;

    // skip over the already-recorded diff-weights reorder, if any
    if (conv_reorder_diff_weights) n_bwd += 1;

    // output from backward stream
    void *net_diff_weights = NULL;
    void *net_diff_bias = NULL;

    int n_iter = 10; // number of iterations for training.
    dnnl_stream_t stream;
    CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
    // Execute the net
    for (int i = 0; i < n_iter; i++) {
        // Forward pass (note: inner `i` intentionally shadows the iteration
        // counter; the outer value is not needed inside the loop body)
        for (uint32_t i = 0; i < n_fwd; ++i)
            CHECK(dnnl_primitive_execute(net_fwd[i], stream,
                    net_fwd_args[i].nargs, net_fwd_args[i].args));

        // Update net_diff_dst
        void *net_output = NULL; // output from forward stream:
        CHECK(dnnl_memory_get_data_handle(pool_user_dst_memory, &net_output));
        // ...user updates net_diff_dst using net_output...
        // some user defined func update_diff_dst(net_diff_dst, net_output)

        // Backward pass
        for (uint32_t i = 0; i < n_bwd; ++i)
            CHECK(dnnl_primitive_execute(net_bwd[i], stream,
                    net_bwd_args[i].nargs, net_bwd_args[i].args));

        // ... update weights ...
        CHECK(dnnl_memory_get_data_handle(
                conv_user_diff_weights_memory, &net_diff_weights));
        CHECK(dnnl_memory_get_data_handle(
                conv_diff_bias_memory, &net_diff_bias));
        // ...user updates weights and bias using diff weights and bias...
        // some user defined func update_weights(conv_user_weights_memory,
        //         conv_bias_memory,
        //         net_diff_weights, net_diff_bias);
    }
    CHECK(dnnl_stream_wait(stream));

    dnnl_stream_destroy(stream);

    // clean up nets
    for (uint32_t i = 0; i < n_fwd; ++i)
        free_arg_node(&net_fwd_args[i]);
    for (uint32_t i = 0; i < n_bwd; ++i)
        free_arg_node(&net_bwd_args[i]);

    // Cleanup forward
    CHECK(dnnl_primitive_desc_destroy(pool_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_pd));
    CHECK(dnnl_primitive_desc_destroy(conv_pd));

    free(net_src);
    free(net_dst);

    // NOTE: dnnl_memory_destroy/dnnl_primitive_destroy accept NULL, so the
    // "internal" memories and reorders that were never created are safe here.
    dnnl_memory_destroy(conv_user_src_memory);
    dnnl_memory_destroy(conv_user_weights_memory);
    dnnl_memory_destroy(conv_user_bias_memory);
    dnnl_memory_destroy(conv_internal_src_memory);
    dnnl_memory_destroy(conv_internal_weights_memory);
    dnnl_memory_destroy(conv_internal_dst_memory);
    dnnl_primitive_destroy(conv_reorder_src);
    dnnl_primitive_destroy(conv_reorder_weights);
    dnnl_primitive_destroy(conv);

    free(conv_weights);
    free(conv_bias);

    dnnl_memory_destroy(relu_dst_memory);
    dnnl_primitive_destroy(relu);

    dnnl_memory_destroy(lrn_ws_memory);
    dnnl_memory_destroy(lrn_dst_memory);
    dnnl_primitive_destroy(lrn);

    dnnl_memory_destroy(pool_user_dst_memory);
    dnnl_memory_destroy(pool_internal_dst_memory);
    dnnl_memory_destroy(pool_ws_memory);
    dnnl_primitive_destroy(pool_reorder_dst);
    dnnl_primitive_destroy(pool);

    // Cleanup backward
    CHECK(dnnl_primitive_desc_destroy(pool_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(conv_bwd_weights_pd));

    dnnl_memory_destroy(pool_user_diff_dst_memory);
    dnnl_memory_destroy(pool_diff_src_memory);
    dnnl_memory_destroy(pool_internal_diff_dst_memory);
    dnnl_primitive_destroy(pool_reorder_diff_dst);
    dnnl_primitive_destroy(pool_bwd);

    free(net_diff_dst);

    dnnl_memory_destroy(lrn_diff_src_memory);
    dnnl_primitive_destroy(lrn_bwd);

    dnnl_memory_destroy(relu_diff_src_memory);
    dnnl_primitive_destroy(relu_bwd);

    dnnl_memory_destroy(conv_user_diff_weights_memory);
    dnnl_memory_destroy(conv_diff_bias_memory);
    dnnl_memory_destroy(conv_bwd_internal_src_memory);
    dnnl_primitive_destroy(conv_bwd_reorder_src);
    dnnl_memory_destroy(conv_internal_diff_dst_memory);
    dnnl_primitive_destroy(conv_reorder_diff_dst);
    dnnl_memory_destroy(conv_internal_diff_weights_memory);
    dnnl_primitive_destroy(conv_reorder_diff_weights);
    dnnl_primitive_destroy(conv_bwd_weights);

    free(conv_diff_bias_buffer);
    free(conv_user_diff_weights_buffer);

    dnnl_engine_destroy(engine);
}
775
// Entry point: runs the training example on CPU. Any dnnl failure aborts
// inside simple_net() via the CHECK macro before the success message prints.
int main(int argc, char **argv) {
    (void)argc; // unused: the example takes no command-line options
    (void)argv;
    simple_net();
    printf("Example passed on CPU.\n");
    return 0;
}
781