1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @example cpu_cnn_training_f32.c |
18 | /// @copybrief cpu_cnn_training_f32_c |
19 | |
20 | /// @page cpu_cnn_training_f32_c CNN f32 training example |
21 | /// This C API example demonstrates how to build an AlexNet model training. |
22 | /// The example implements a few layers from AlexNet model. |
23 | /// |
24 | /// @include cpu_cnn_training_f32.c |
25 | |
26 | // Required for posix_memalign |
27 | #define _POSIX_C_SOURCE 200112L |
28 | |
29 | #include <stdio.h> |
30 | #include <stdlib.h> |
31 | #include <string.h> |
32 | |
33 | #include "oneapi/dnnl/dnnl.h" |
34 | |
35 | #include "example_utils.h" |
36 | |
37 | #define BATCH 8 |
38 | #define IC 3 |
39 | #define OC 96 |
40 | #define CONV_IH 227 |
41 | #define CONV_IW 227 |
42 | #define CONV_OH 55 |
43 | #define CONV_OW 55 |
44 | #define CONV_STRIDE 4 |
45 | #define CONV_PAD 0 |
46 | #define POOL_OH 27 |
47 | #define POOL_OW 27 |
48 | #define POOL_STRIDE 2 |
49 | #define POOL_PAD 0 |
50 | |
51 | static size_t product(dnnl_dim_t *arr, size_t size) { |
52 | size_t prod = 1; |
53 | for (size_t i = 0; i < size; ++i) |
54 | prod *= arr[i]; |
55 | return prod; |
56 | } |
57 | |
58 | static void init_net_data(float *data, uint32_t dim, const dnnl_dim_t *dims) { |
59 | if (dim == 1) { |
60 | for (dnnl_dim_t i = 0; i < dims[0]; ++i) { |
61 | data[i] = (float)(i % 1637); |
62 | } |
63 | } else if (dim == 4) { |
64 | for (dnnl_dim_t in = 0; in < dims[0]; ++in) |
65 | for (dnnl_dim_t ic = 0; ic < dims[1]; ++ic) |
66 | for (dnnl_dim_t ih = 0; ih < dims[2]; ++ih) |
67 | for (dnnl_dim_t iw = 0; iw < dims[3]; ++iw) { |
68 | dnnl_dim_t indx = in * dims[1] * dims[2] * dims[3] |
69 | + ic * dims[2] * dims[3] + ih * dims[3] + iw; |
70 | data[indx] = (float)(indx % 1637); |
71 | } |
72 | } |
73 | } |
74 | |
// Execution arguments for one primitive in a net: a heap-allocated array
// of (argument id, memory) pairs passed to dnnl_primitive_execute().
typedef struct {
    int nargs; // number of valid entries in `args`
    dnnl_exec_arg_t *args; // owned by this node; freed via free_arg_node()
} args_t;
79 | |
80 | static void prepare_arg_node(args_t *node, int nargs) { |
81 | node->args = (dnnl_exec_arg_t *)malloc(sizeof(dnnl_exec_arg_t) * nargs); |
82 | node->nargs = nargs; |
83 | } |
84 | static void free_arg_node(args_t *node) { |
85 | free(node->args); |
86 | } |
87 | |
88 | static void set_arg(dnnl_exec_arg_t *arg, int arg_idx, dnnl_memory_t memory) { |
89 | arg->arg = arg_idx; |
90 | arg->memory = memory; |
91 | } |
92 | |
// Creates a memory object on `engine` with the user-specified plain layout
// (`user_tag`) and uploads `data` into it.
//
// The library owns the underlying buffer (DNNL_MEMORY_ALLOCATE); the
// temporary memory descriptor can be destroyed right away because the
// memory object keeps its own copy of the description.
static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
        dnnl_format_tag_t user_tag, dnnl_engine_t engine, float *data,
        dnnl_memory_t *memory) {
    dnnl_memory_desc_t user_md;
    CHECK(dnnl_memory_desc_create_with_tag(
            &user_md, dim, dims, dnnl_f32, user_tag));
    CHECK(dnnl_memory_create(memory, user_md, engine, DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_desc_destroy(user_md));
    // Copy the caller's initial values into the freshly allocated buffer.
    write_to_dnnl_memory(data, *memory);
}
103 | |
104 | dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in |
105 | const_dnnl_memory_desc_t prim_memory_md, // in |
106 | dnnl_engine_t prim_engine, // in: primitive's engine |
107 | int dir_is_user_to_prim, // in: user -> prim or prim -> user |
108 | dnnl_memory_t *prim_memory, // out: primitive's memory created |
109 | dnnl_primitive_t *reorder, // out: reorder primitive created |
110 | uint32_t *net_index, // primitive index in net (inc if reorder created) |
111 | dnnl_primitive_t *net, args_t *net_args) { // net params |
112 | const_dnnl_memory_desc_t user_memory_md; |
113 | dnnl_memory_get_memory_desc(*user_memory, &user_memory_md); |
114 | |
115 | dnnl_engine_t user_mem_engine; |
116 | dnnl_memory_get_engine(*user_memory, &user_mem_engine); |
117 | |
118 | if (!dnnl_memory_desc_equal(user_memory_md, prim_memory_md)) { |
119 | CHECK(dnnl_memory_create(prim_memory, prim_memory_md, prim_engine, |
120 | DNNL_MEMORY_ALLOCATE)); |
121 | |
122 | dnnl_primitive_desc_t reorder_pd; |
123 | if (dir_is_user_to_prim) { |
124 | CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd, |
125 | user_memory_md, user_mem_engine, prim_memory_md, |
126 | prim_engine, NULL)); |
127 | } else { |
128 | CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd, |
129 | prim_memory_md, prim_engine, user_memory_md, |
130 | user_mem_engine, NULL)); |
131 | } |
132 | CHECK(dnnl_primitive_create(reorder, reorder_pd)); |
133 | CHECK(dnnl_primitive_desc_destroy(reorder_pd)); |
134 | |
135 | net[*net_index] = *reorder; |
136 | prepare_arg_node(&net_args[*net_index], 2); |
137 | set_arg(&net_args[*net_index].args[0], DNNL_ARG_FROM, |
138 | dir_is_user_to_prim ? *user_memory : *prim_memory); |
139 | set_arg(&net_args[*net_index].args[1], DNNL_ARG_TO, |
140 | dir_is_user_to_prim ? *prim_memory : *user_memory); |
141 | (*net_index)++; |
142 | } else { |
143 | *prim_memory = NULL; |
144 | *reorder = NULL; |
145 | } |
146 | |
147 | return dnnl_success; |
148 | } |
149 | |
// Builds and runs an AlexNet-like training fragment on the CPU engine:
//   forward:  conv -> relu -> lrn -> max-pool (plus layout reorders)
//   backward: pool_bwd -> lrn_bwd -> relu_bwd -> conv backward w.r.t.
//             weights (plus layout reorders)
// Primitives and their execution arguments are collected into the
// net_fwd/net_bwd arrays, executed for a few iterations, and then all
// resources are destroyed.
void simple_net() {
    dnnl_engine_t engine;
    CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); // idx

    // build a simple net
    // n_fwd/n_bwd count primitives appended to each net; the fixed-size
    // arrays leave room for the optional reorders created below.
    uint32_t n_fwd = 0, n_bwd = 0;
    dnnl_primitive_t net_fwd[10], net_bwd[10];
    args_t net_fwd_args[10], net_bwd_args[10];

    const int ndims = 4;
    dnnl_dims_t net_src_sizes = {BATCH, IC, CONV_IH, CONV_IW};
    dnnl_dims_t net_dst_sizes = {BATCH, OC, POOL_OH, POOL_OW};

    float *net_src
            = (float *)malloc(product(net_src_sizes, ndims) * sizeof(float));
    float *net_dst
            = (float *)malloc(product(net_dst_sizes, ndims) * sizeof(float));

    init_net_data(net_src, ndims, net_src_sizes);
    memset(net_dst, 0, product(net_dst_sizes, ndims) * sizeof(float));

    //----------------------------------------------------------------------
    //----------------- Forward Stream -------------------------------------
    // AlexNet: conv
    // {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
    // {BATCH, OC, CONV_OH, CONV_OW}
    // strides: {CONV_STRIDE, CONV_STRIDE}
    dnnl_dims_t conv_user_src_sizes;
    for (int i = 0; i < ndims; i++)
        conv_user_src_sizes[i] = net_src_sizes[i];
    dnnl_dims_t conv_user_weights_sizes = {OC, IC, 11, 11};
    dnnl_dims_t conv_bias_sizes = {OC};
    dnnl_dims_t conv_user_dst_sizes = {BATCH, OC, CONV_OH, CONV_OW};
    dnnl_dims_t conv_strides = {CONV_STRIDE, CONV_STRIDE};
    dnnl_dims_t conv_dilation = {0, 0};
    dnnl_dims_t conv_padding = {CONV_PAD, CONV_PAD};

    float *conv_src = net_src;
    float *conv_weights = (float *)malloc(
            product(conv_user_weights_sizes, ndims) * sizeof(float));
    float *conv_bias
            = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));

    init_net_data(conv_weights, ndims, conv_user_weights_sizes);
    init_net_data(conv_bias, 1, conv_bias_sizes);

    // create memory for user data
    dnnl_memory_t conv_user_src_memory, conv_user_weights_memory,
            conv_user_bias_memory;
    init_data_memory(ndims, conv_user_src_sizes, dnnl_nchw, engine, conv_src,
            &conv_user_src_memory);
    init_data_memory(ndims, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_weights, &conv_user_weights_memory);
    init_data_memory(1, conv_bias_sizes, dnnl_x, engine, conv_bias,
            &conv_user_bias_memory);

    // create a convolution
    dnnl_primitive_desc_t conv_pd;

    {
        // create data descriptors for convolution w/ no specified format
        // (dnnl_format_tag_any lets the library pick its preferred layout)
        dnnl_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
                conv_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&conv_src_md, ndims,
                conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_weights_md, ndims,
                conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(
                &conv_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_dst_md, ndims,
                conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));

        CHECK(dnnl_convolution_forward_primitive_desc_create(&conv_pd, engine,
                dnnl_forward, dnnl_convolution_direct, conv_src_md,
                conv_weights_md, conv_bias_md, conv_dst_md, conv_strides,
                conv_dilation, conv_padding, conv_padding, NULL));

        CHECK(dnnl_memory_desc_destroy(conv_src_md));
        CHECK(dnnl_memory_desc_destroy(conv_weights_md));
        CHECK(dnnl_memory_desc_destroy(conv_bias_md));
        CHECK(dnnl_memory_desc_destroy(conv_dst_md));
    }

    dnnl_memory_t conv_internal_src_memory, conv_internal_weights_memory,
            conv_internal_dst_memory;

    // create memory for dst data, we don't need to reorder it to user data
    const_dnnl_memory_desc_t conv_dst_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    CHECK(dnnl_memory_create(&conv_internal_dst_memory, conv_dst_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // create reorder primitives between user data and convolution srcs
    // if required
    dnnl_primitive_t conv_reorder_src, conv_reorder_weights;

    const_dnnl_memory_desc_t conv_src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_user_src_memory, conv_src_md, engine, 1,
            &conv_internal_src_memory, &conv_reorder_src, &n_fwd, net_fwd,
            net_fwd_args));

    const_dnnl_memory_desc_t conv_weights_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
    CHECK(prepare_reorder(&conv_user_weights_memory, conv_weights_md, engine, 1,
            &conv_internal_weights_memory, &conv_reorder_weights, &n_fwd,
            net_fwd, net_fwd_args));

    // internal memory is NULL when no reorder was needed; fall back to the
    // user memory in that case
    dnnl_memory_t conv_src_memory = conv_internal_src_memory
            ? conv_internal_src_memory
            : conv_user_src_memory;
    dnnl_memory_t conv_weights_memory = conv_internal_weights_memory
            ? conv_internal_weights_memory
            : conv_user_weights_memory;

    // finally create a convolution primitive
    dnnl_primitive_t conv;
    CHECK(dnnl_primitive_create(&conv, conv_pd));
    net_fwd[n_fwd] = conv;
    prepare_arg_node(&net_fwd_args[n_fwd], 4);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, conv_src_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_WEIGHTS,
            conv_weights_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_BIAS, conv_user_bias_memory);
    set_arg(&net_fwd_args[n_fwd].args[3], DNNL_ARG_DST,
            conv_internal_dst_memory);
    n_fwd++;

    // AlexNet: relu
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}

    float negative_slope = 0.0f;

    // keep memory format of source same as the format of convolution
    // output in order to avoid reorder
    const_dnnl_memory_desc_t relu_src_md = conv_dst_md;
    const_dnnl_memory_desc_t relu_dst_md = relu_src_md;

    // create a relu primitive descriptor
    dnnl_primitive_desc_t relu_pd;
    CHECK(dnnl_eltwise_forward_primitive_desc_create(&relu_pd, engine,
            dnnl_forward, dnnl_eltwise_relu, relu_src_md, relu_dst_md,
            negative_slope, 0, NULL));

    // create relu dst memory
    dnnl_memory_t relu_dst_memory;
    CHECK(dnnl_memory_create(
            &relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create a relu primitive
    dnnl_primitive_t relu;
    CHECK(dnnl_primitive_create(&relu, relu_pd));
    net_fwd[n_fwd] = relu;
    prepare_arg_node(&net_fwd_args[n_fwd], 2);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC,
            conv_internal_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, relu_dst_memory);
    n_fwd++;

    // AlexNet: lrn
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    // local size: 5
    // alpha: 0.0001
    // beta: 0.75
    // k: 1.0
    uint32_t local_size = 5;
    float alpha = 0.0001f;
    float beta = 0.75f;
    float k = 1.0f;

    // create lrn src memory descriptor using dst memory descriptor
    // from previous primitive
    const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
    const_dnnl_memory_desc_t lrn_dst_md = lrn_src_md;

    // create a lrn primitive descriptor
    dnnl_primitive_desc_t lrn_pd;
    CHECK(dnnl_lrn_forward_primitive_desc_create(&lrn_pd, engine, dnnl_forward,
            dnnl_lrn_across_channels, lrn_src_md, lrn_dst_md, local_size, alpha,
            beta, k, NULL));

    // create primitives for lrn dst and workspace memory
    dnnl_memory_t lrn_dst_memory, lrn_ws_memory;

    CHECK(dnnl_memory_create(
            &lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));

    // create workspace only in training and only for forward primitive
    // query lrn_pd for workspace, this memory will be shared with forward lrn
    const_dnnl_memory_desc_t lrn_ws_md
            = dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    // finally create a lrn primitive
    dnnl_primitive_t lrn;
    CHECK(dnnl_primitive_create(&lrn, lrn_pd));
    net_fwd[n_fwd] = lrn;
    prepare_arg_node(&net_fwd_args[n_fwd], 3);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, lrn_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    n_fwd++;

    // AlexNet: pool
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
    // kernel: {3, 3}
    // strides: {POOL_STRIDE, POOL_STRIDE}
    // dilation: {0, 0}
    dnnl_dims_t pool_dst_sizes;
    for (int i = 0; i < ndims; i++)
        pool_dst_sizes[i] = net_dst_sizes[i];
    dnnl_dims_t pool_kernel = {3, 3};
    dnnl_dims_t pool_strides = {POOL_STRIDE, POOL_STRIDE};
    dnnl_dims_t pool_padding = {POOL_PAD, POOL_PAD};
    dnnl_dims_t pool_dilation = {0, 0};

    // create memory for user dst data
    dnnl_memory_t pool_user_dst_memory;
    init_data_memory(4, pool_dst_sizes, dnnl_nchw, engine, net_dst,
            &pool_user_dst_memory);

    // create a pooling primitive descriptor
    dnnl_primitive_desc_t pool_pd;

    {
        // create pooling src memory descriptor using dst descriptor
        // from previous primitive
        const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;

        // create descriptors for dst pooling data
        dnnl_memory_desc_t pool_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&pool_dst_md, 4, pool_dst_sizes,
                dnnl_f32, dnnl_format_tag_any));

        CHECK(dnnl_pooling_forward_primitive_desc_create(&pool_pd, engine,
                dnnl_forward, dnnl_pooling_max, pool_src_md, pool_dst_md,
                pool_strides, pool_kernel, pool_dilation, pool_padding,
                pool_padding, NULL));
        CHECK(dnnl_memory_desc_destroy(pool_dst_md));
    }

    // create memory for workspace
    dnnl_memory_t pool_ws_memory;
    const_dnnl_memory_desc_t pool_ws_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));

    // create reorder primitives between pooling dsts and user format dst
    // if required
    dnnl_primitive_t pool_reorder_dst;
    dnnl_memory_t pool_internal_dst_memory;
    const_dnnl_memory_desc_t pool_dst_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
    n_fwd += 1; // tentative workaround: preserve space for pooling that should
    // happen before the reorder
    CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0,
            &pool_internal_dst_memory, &pool_reorder_dst, &n_fwd, net_fwd,
            net_fwd_args));
    // rewind n_fwd back to the reserved pooling slot: -2 when a reorder
    // was appended (it sits after the pooling slot), -1 otherwise
    n_fwd -= pool_reorder_dst ? 2 : 1;

    dnnl_memory_t pool_dst_memory = pool_internal_dst_memory
            ? pool_internal_dst_memory
            : pool_user_dst_memory;

    // finally create a pooling primitive
    dnnl_primitive_t pool;
    CHECK(dnnl_primitive_create(&pool, pool_pd));
    net_fwd[n_fwd] = pool;
    prepare_arg_node(&net_fwd_args[n_fwd], 3);
    set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, lrn_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, pool_dst_memory);
    set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, pool_ws_memory);
    n_fwd++;

    // skip past the reorder slot appended after pooling, if one was created
    if (pool_reorder_dst) n_fwd += 1;

    //-----------------------------------------------------------------------
    //----------------- Backward Stream -------------------------------------
    //-----------------------------------------------------------------------

    // ... user diff_data ...
    float *net_diff_dst
            = (float *)malloc(product(pool_dst_sizes, 4) * sizeof(float));

    init_net_data(net_diff_dst, 4, pool_dst_sizes);

    // create memory for user diff dst data
    dnnl_memory_t pool_user_diff_dst_memory;
    init_data_memory(4, pool_dst_sizes, dnnl_nchw, engine, net_diff_dst,
            &pool_user_diff_dst_memory);

    // Pooling Backward
    // pooling diff src memory descriptor
    const_dnnl_memory_desc_t pool_diff_src_md = lrn_dst_md;

    // pooling diff dst memory descriptor
    const_dnnl_memory_desc_t pool_diff_dst_md = pool_dst_md;

    // backward primitive descriptor needs to hint forward descriptor
    dnnl_primitive_desc_t pool_bwd_pd;
    CHECK(dnnl_pooling_backward_primitive_desc_create(&pool_bwd_pd, engine,
            dnnl_pooling_max, pool_diff_src_md, pool_diff_dst_md, pool_strides,
            pool_kernel, pool_dilation, pool_padding, pool_padding, pool_pd,
            NULL));

    // create reorder primitive between user diff dst and pool diff dst
    // if required
    dnnl_memory_t pool_diff_dst_memory, pool_internal_diff_dst_memory;
    dnnl_primitive_t pool_reorder_diff_dst;
    CHECK(prepare_reorder(&pool_user_diff_dst_memory, pool_diff_dst_md, engine,
            1, &pool_internal_diff_dst_memory, &pool_reorder_diff_dst, &n_bwd,
            net_bwd, net_bwd_args));

    pool_diff_dst_memory = pool_internal_diff_dst_memory
            ? pool_internal_diff_dst_memory
            : pool_user_diff_dst_memory;

    // create memory for pool diff src data
    dnnl_memory_t pool_diff_src_memory;
    CHECK(dnnl_memory_create(&pool_diff_src_memory, pool_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward pooling primitive
    dnnl_primitive_t pool_bwd;
    CHECK(dnnl_primitive_create(&pool_bwd, pool_bwd_pd));
    net_bwd[n_bwd] = pool_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 3);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_DIFF_DST,
            pool_diff_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_WORKSPACE, pool_ws_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
            pool_diff_src_memory);
    n_bwd++;

    // Backward lrn
    const_dnnl_memory_desc_t lrn_diff_dst_md = pool_diff_src_md;
    const_dnnl_memory_desc_t lrn_diff_src_md = lrn_diff_dst_md;

    // create backward lrn descriptor
    dnnl_primitive_desc_t lrn_bwd_pd;
    CHECK(dnnl_lrn_backward_primitive_desc_create(&lrn_bwd_pd, engine,
            dnnl_lrn_across_channels, lrn_diff_src_md, lrn_diff_dst_md,
            lrn_src_md, local_size, alpha, beta, k, lrn_pd, NULL));

    // create memory for lrn diff src
    dnnl_memory_t lrn_diff_src_memory;
    CHECK(dnnl_memory_create(&lrn_diff_src_memory, lrn_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward lrn primitive
    dnnl_primitive_t lrn_bwd;
    CHECK(dnnl_primitive_create(&lrn_bwd, lrn_bwd_pd));
    net_bwd[n_bwd] = lrn_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 4);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            pool_diff_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_SRC,
            lrn_diff_src_memory);
    n_bwd++;

    // Backward relu
    const_dnnl_memory_desc_t relu_diff_src_md = lrn_diff_src_md;
    const_dnnl_memory_desc_t relu_diff_dst_md = lrn_diff_src_md;

    // create backward relu descriptor
    dnnl_primitive_desc_t relu_bwd_pd;
    CHECK(dnnl_eltwise_backward_primitive_desc_create(&relu_bwd_pd, engine,
            dnnl_eltwise_relu, relu_diff_src_md, relu_diff_dst_md, relu_src_md,
            negative_slope, 0, relu_pd, NULL));

    // create memory for relu diff src
    dnnl_memory_t relu_diff_src_memory;
    CHECK(dnnl_memory_create(&relu_diff_src_memory, relu_diff_src_md, engine,
            DNNL_MEMORY_ALLOCATE));

    // finally create backward relu primitive
    dnnl_primitive_t relu_bwd;
    CHECK(dnnl_primitive_create(&relu_bwd, relu_bwd_pd));
    net_bwd[n_bwd] = relu_bwd;
    prepare_arg_node(&net_bwd_args[n_bwd], 3);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
            conv_internal_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            lrn_diff_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
            relu_diff_src_memory);
    n_bwd++;

    // Backward convolution with respect to weights
    float *conv_diff_bias_buffer
            = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));
    float *conv_user_diff_weights_buffer = (float *)malloc(
            product(conv_user_weights_sizes, 4) * sizeof(float));

    // initialize memory for diff weights in user format
    dnnl_memory_t conv_user_diff_weights_memory;
    init_data_memory(4, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_user_diff_weights_buffer, &conv_user_diff_weights_memory);

    // create backward convolution primitive descriptor
    dnnl_primitive_desc_t conv_bwd_weights_pd;

    {
        // memory descriptors should be in format `any` to allow backward
        // convolution for
        // weights to chose the format it prefers for best performance
        dnnl_memory_desc_t conv_diff_src_md, conv_diff_weights_md,
                conv_diff_bias_md, conv_diff_dst_md;
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_src_md, 4,
                conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_weights_md, 4,
                conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
        CHECK(dnnl_memory_desc_create_with_tag(
                &conv_diff_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
        CHECK(dnnl_memory_desc_create_with_tag(&conv_diff_dst_md, 4,
                conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));

        // create backward convolution descriptor
        CHECK(dnnl_convolution_backward_weights_primitive_desc_create(
                &conv_bwd_weights_pd, engine, dnnl_convolution_direct,
                conv_diff_src_md, conv_diff_weights_md, conv_diff_bias_md,
                conv_diff_dst_md, conv_strides, conv_dilation, conv_padding,
                conv_padding, conv_pd, NULL));

        CHECK(dnnl_memory_desc_destroy(conv_diff_src_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_weights_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_bias_md));
        CHECK(dnnl_memory_desc_destroy(conv_diff_dst_md));
    }

    // for best performance convolution backward might chose
    // different memory format for src and diff_dst
    // than the memory formats preferred by forward convolution
    // for src and dst respectively
    // create reorder primitives for src from forward convolution to the
    // format chosen by backward convolution
    dnnl_primitive_t conv_bwd_reorder_src;
    dnnl_memory_t conv_bwd_internal_src_memory;
    const_dnnl_memory_desc_t conv_diff_src_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_src_memory, conv_diff_src_md, engine, 1,
            &conv_bwd_internal_src_memory, &conv_bwd_reorder_src, &n_bwd,
            net_bwd, net_bwd_args));

    dnnl_memory_t conv_bwd_weights_src_memory = conv_bwd_internal_src_memory
            ? conv_bwd_internal_src_memory
            : conv_src_memory;

    // create reorder primitives for diff_dst between diff_src from relu_bwd
    // and format preferred by conv_diff_weights
    dnnl_primitive_t conv_reorder_diff_dst;
    dnnl_memory_t conv_internal_diff_dst_memory;
    const_dnnl_memory_desc_t conv_diff_dst_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_diff_dst_md, 0);

    CHECK(prepare_reorder(&relu_diff_src_memory, conv_diff_dst_md, engine, 1,
            &conv_internal_diff_dst_memory, &conv_reorder_diff_dst, &n_bwd,
            net_bwd, net_bwd_args));

    dnnl_memory_t conv_diff_dst_memory = conv_internal_diff_dst_memory
            ? conv_internal_diff_dst_memory
            : relu_diff_src_memory;

    // create reorder primitives for conv diff weights memory
    dnnl_primitive_t conv_reorder_diff_weights;
    dnnl_memory_t conv_internal_diff_weights_memory;
    const_dnnl_memory_desc_t conv_diff_weights_md
            = dnnl_primitive_desc_query_md(
                    conv_bwd_weights_pd, dnnl_query_diff_weights_md, 0);
    n_bwd += 1; // tentative workaround: preserve space for conv_bwd_weights
    // that should happen before the reorder

    CHECK(prepare_reorder(&conv_user_diff_weights_memory, conv_diff_weights_md,
            engine, 0, &conv_internal_diff_weights_memory,
            &conv_reorder_diff_weights, &n_bwd, net_bwd, net_bwd_args));
    // rewind n_bwd back to the reserved conv_bwd_weights slot: -2 when a
    // reorder was appended after it, -1 otherwise
    n_bwd -= conv_reorder_diff_weights ? 2 : 1;

    dnnl_memory_t conv_diff_weights_memory = conv_internal_diff_weights_memory
            ? conv_internal_diff_weights_memory
            : conv_user_diff_weights_memory;

    // create memory for diff bias memory
    dnnl_memory_t conv_diff_bias_memory;
    const_dnnl_memory_desc_t conv_diff_bias_md = dnnl_primitive_desc_query_md(
            conv_bwd_weights_pd, dnnl_query_diff_weights_md, 1);
    CHECK(dnnl_memory_create(&conv_diff_bias_memory, conv_diff_bias_md, engine,
            DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_set_data_handle(
            conv_diff_bias_memory, conv_diff_bias_buffer));

    // finally created backward convolution weights primitive
    dnnl_primitive_t conv_bwd_weights;
    CHECK(dnnl_primitive_create(&conv_bwd_weights, conv_bwd_weights_pd));
    net_bwd[n_bwd] = conv_bwd_weights;
    prepare_arg_node(&net_bwd_args[n_bwd], 4);
    set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
            conv_bwd_weights_src_memory);
    set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
            conv_diff_dst_memory);
    set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_WEIGHTS,
            conv_diff_weights_memory);
    set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_BIAS,
            conv_diff_bias_memory);
    n_bwd++;

    // skip past the reorder slot appended after conv_bwd_weights, if any
    if (conv_reorder_diff_weights) n_bwd += 1;

    // output from backward stream
    void *net_diff_weights = NULL;
    void *net_diff_bias = NULL;

    int n_iter = 10; // number of iterations for training.
    dnnl_stream_t stream;
    CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
    // Execute the net
    // NOTE(review): the inner loops reuse the name `i`, shadowing the
    // outer iteration counter; behavior is correct but the name collision
    // is easy to misread.
    for (int i = 0; i < n_iter; i++) {
        for (uint32_t i = 0; i < n_fwd; ++i)
            CHECK(dnnl_primitive_execute(net_fwd[i], stream,
                    net_fwd_args[i].nargs, net_fwd_args[i].args));

        // Update net_diff_dst
        void *net_output = NULL; // output from forward stream:
        CHECK(dnnl_memory_get_data_handle(pool_user_dst_memory, &net_output));
        // ...user updates net_diff_dst using net_output...
        // some user defined func update_diff_dst(net_diff_dst, net_output)

        // Backward pass
        for (uint32_t i = 0; i < n_bwd; ++i)
            CHECK(dnnl_primitive_execute(net_bwd[i], stream,
                    net_bwd_args[i].nargs, net_bwd_args[i].args));

        // ... update weights ...
        CHECK(dnnl_memory_get_data_handle(
                conv_user_diff_weights_memory, &net_diff_weights));
        CHECK(dnnl_memory_get_data_handle(
                conv_diff_bias_memory, &net_diff_bias));
        // ...user updates weights and bias using diff weights and bias...
        // some user defined func update_weights(conv_user_weights_memory,
        //                                       conv_bias_memory,
        //                                       net_diff_weights, net_diff_bias);
    }
    CHECK(dnnl_stream_wait(stream));

    dnnl_stream_destroy(stream);

    // clean up nets
    for (uint32_t i = 0; i < n_fwd; ++i)
        free_arg_node(&net_fwd_args[i]);
    for (uint32_t i = 0; i < n_bwd; ++i)
        free_arg_node(&net_bwd_args[i]);

    // Cleanup forward
    // (destroy statuses below are deliberately ignored during teardown)
    CHECK(dnnl_primitive_desc_destroy(pool_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_pd));
    CHECK(dnnl_primitive_desc_destroy(conv_pd));

    free(net_src);
    free(net_dst);

    dnnl_memory_destroy(conv_user_src_memory);
    dnnl_memory_destroy(conv_user_weights_memory);
    dnnl_memory_destroy(conv_user_bias_memory);
    dnnl_memory_destroy(conv_internal_src_memory);
    dnnl_memory_destroy(conv_internal_weights_memory);
    dnnl_memory_destroy(conv_internal_dst_memory);
    dnnl_primitive_destroy(conv_reorder_src);
    dnnl_primitive_destroy(conv_reorder_weights);
    dnnl_primitive_destroy(conv);

    free(conv_weights);
    free(conv_bias);

    dnnl_memory_destroy(relu_dst_memory);
    dnnl_primitive_destroy(relu);

    dnnl_memory_destroy(lrn_ws_memory);
    dnnl_memory_destroy(lrn_dst_memory);
    dnnl_primitive_destroy(lrn);

    dnnl_memory_destroy(pool_user_dst_memory);
    dnnl_memory_destroy(pool_internal_dst_memory);
    dnnl_memory_destroy(pool_ws_memory);
    dnnl_primitive_destroy(pool_reorder_dst);
    dnnl_primitive_destroy(pool);

    // Cleanup backward
    CHECK(dnnl_primitive_desc_destroy(pool_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_bwd_pd));
    CHECK(dnnl_primitive_desc_destroy(conv_bwd_weights_pd));

    dnnl_memory_destroy(pool_user_diff_dst_memory);
    dnnl_memory_destroy(pool_diff_src_memory);
    dnnl_memory_destroy(pool_internal_diff_dst_memory);
    dnnl_primitive_destroy(pool_reorder_diff_dst);
    dnnl_primitive_destroy(pool_bwd);

    free(net_diff_dst);

    dnnl_memory_destroy(lrn_diff_src_memory);
    dnnl_primitive_destroy(lrn_bwd);

    dnnl_memory_destroy(relu_diff_src_memory);
    dnnl_primitive_destroy(relu_bwd);

    dnnl_memory_destroy(conv_user_diff_weights_memory);
    dnnl_memory_destroy(conv_diff_bias_memory);
    dnnl_memory_destroy(conv_bwd_internal_src_memory);
    dnnl_primitive_destroy(conv_bwd_reorder_src);
    dnnl_memory_destroy(conv_internal_diff_dst_memory);
    dnnl_primitive_destroy(conv_reorder_diff_dst);
    dnnl_memory_destroy(conv_internal_diff_weights_memory);
    dnnl_primitive_destroy(conv_reorder_diff_weights);
    dnnl_primitive_destroy(conv_bwd_weights);

    free(conv_diff_bias_buffer);
    free(conv_user_diff_weights_buffer);

    dnnl_engine_destroy(engine);
}
775 | |
// Entry point: runs the training example and reports success.
// Command-line arguments are accepted for a conventional signature but
// not used; the casts below silence -Wunused-parameter.
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;
    simple_net();
    printf("Example passed on CPU.\n");
    return 0;
}
781 | |