1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @example cnn_training_bf16.cpp
18/// @copybrief cnn_training_bf16_cpp
19/// > Annotated version: @ref cnn_training_bf16_cpp
20///
21/// @page cnn_training_bf16_cpp CNN bf16 training example
/// This C++ API example demonstrates how to set up training of an AlexNet
/// model using the bfloat16 data type.
///
/// The example implements a few layers from the AlexNet model.
26///
27/// @include cnn_training_bf16.cpp
28
29#include <cassert>
30#include <cmath>
31#include <iostream>
32#include <stdexcept>
33
34#include "oneapi/dnnl/dnnl.hpp"
35
36#include "example_utils.hpp"
37
38using namespace dnnl;
39
40void simple_net(engine::kind engine_kind) {
41 using tag = memory::format_tag;
42 using dt = memory::data_type;
43
44 auto eng = engine(engine_kind, 0);
45 stream s(eng);
46
47 // Vector of primitives and their execute arguments
48 std::vector<primitive> net_fwd, net_bwd;
49 std::vector<std::unordered_map<int, memory>> net_fwd_args, net_bwd_args;
50
51 const int batch = 32;
52
53 // float data type is used for user data
54 std::vector<float> net_src(batch * 3 * 227 * 227);
55
56 // initializing non-zero values for src
57 for (size_t i = 0; i < net_src.size(); ++i)
58 net_src[i] = sinf((float)i);
59
60 // AlexNet: conv
61 // {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
62 // strides: {4, 4}
63
64 memory::dims conv_src_tz = {batch, 3, 227, 227};
65 memory::dims conv_weights_tz = {96, 3, 11, 11};
66 memory::dims conv_bias_tz = {96};
67 memory::dims conv_dst_tz = {batch, 96, 55, 55};
68 memory::dims conv_strides = {4, 4};
69 memory::dims conv_padding = {0, 0};
70
71 // float data type is used for user data
72 std::vector<float> conv_weights(product(conv_weights_tz));
73 std::vector<float> conv_bias(product(conv_bias_tz));
74
75 // initializing non-zero values for weights and bias
76 for (size_t i = 0; i < conv_weights.size(); ++i)
77 conv_weights[i] = sinf((float)i);
78 for (size_t i = 0; i < conv_bias.size(); ++i)
79 conv_bias[i] = sinf((float)i);
80
81 // create memory for user data
82 auto conv_user_src_memory
83 = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng);
84 write_to_dnnl_memory(net_src.data(), conv_user_src_memory);
85
86 auto conv_user_weights_memory
87 = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng);
88 write_to_dnnl_memory(conv_weights.data(), conv_user_weights_memory);
89
90 auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
91 write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory);
92
93 // create memory descriptors for bfloat16 convolution data w/ no specified
94 // format tag(`any`)
95 // tag `any` lets a primitive(convolution in this case)
96 // chose the memory format preferred for best performance.
97 auto conv_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any);
98 auto conv_weights_md = memory::desc({conv_weights_tz}, dt::bf16, tag::any);
99 auto conv_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any);
100 // here bias data type is set to bf16.
101 // additionally, f32 data type is supported for bf16 convolution.
102 auto conv_bias_md = memory::desc({conv_bias_tz}, dt::bf16, tag::any);
103
104 // create a convolution primitive descriptor
105
106 // check if bf16 convolution is supported
107 try {
108 convolution_forward::primitive_desc(eng, prop_kind::forward,
109 algorithm::convolution_direct, conv_src_md, conv_weights_md,
110 conv_bias_md, conv_dst_md, conv_strides, conv_padding,
111 conv_padding);
112 } catch (error &e) {
113 if (e.status == dnnl_unimplemented)
114 throw example_allows_unimplemented {
115 "No bf16 convolution implementation is available for this "
116 "platform.\n"
117 "Please refer to the developer guide for details."};
118
119 // on any other error just re-throw
120 throw;
121 }
122
123 auto conv_pd = convolution_forward::primitive_desc(eng, prop_kind::forward,
124 algorithm::convolution_direct, conv_src_md, conv_weights_md,
125 conv_bias_md, conv_dst_md, conv_strides, conv_padding,
126 conv_padding);
127
128 // create reorder primitives between user input and conv src if needed
129 auto conv_src_memory = conv_user_src_memory;
130 if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
131 conv_src_memory = memory(conv_pd.src_desc(), eng);
132 net_fwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
133 net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_src_memory},
134 {DNNL_ARG_TO, conv_src_memory}});
135 }
136
137 auto conv_weights_memory = conv_user_weights_memory;
138 if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) {
139 conv_weights_memory = memory(conv_pd.weights_desc(), eng);
140 net_fwd.push_back(
141 reorder(conv_user_weights_memory, conv_weights_memory));
142 net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_weights_memory},
143 {DNNL_ARG_TO, conv_weights_memory}});
144 }
145
146 // convert bias from f32 to bf16 as convolution descriptor is created with
147 // bias data type as bf16.
148 auto conv_bias_memory = conv_user_bias_memory;
149 if (conv_pd.bias_desc() != conv_user_bias_memory.get_desc()) {
150 conv_bias_memory = memory(conv_pd.bias_desc(), eng);
151 net_fwd.push_back(reorder(conv_user_bias_memory, conv_bias_memory));
152 net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_bias_memory},
153 {DNNL_ARG_TO, conv_bias_memory}});
154 }
155
156 // create memory for conv dst
157 auto conv_dst_memory = memory(conv_pd.dst_desc(), eng);
158
159 // finally create a convolution primitive
160 net_fwd.push_back(convolution_forward(conv_pd));
161 net_fwd_args.push_back({{DNNL_ARG_SRC, conv_src_memory},
162 {DNNL_ARG_WEIGHTS, conv_weights_memory},
163 {DNNL_ARG_BIAS, conv_bias_memory},
164 {DNNL_ARG_DST, conv_dst_memory}});
165
166 // AlexNet: relu
167 // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
168 memory::dims relu_data_tz = {batch, 96, 55, 55};
169 const float negative_slope = 0.0f;
170
171 // create relu primitive desc
172 // keep memory format tag of source same as the format tag of convolution
173 // output in order to avoid reorder
174 auto relu_pd = eltwise_forward::primitive_desc(eng, prop_kind::forward,
175 algorithm::eltwise_relu, conv_pd.dst_desc(), conv_pd.dst_desc(),
176 negative_slope);
177
178 // create relu dst memory
179 auto relu_dst_memory = memory(relu_pd.dst_desc(), eng);
180
181 // finally create a relu primitive
182 net_fwd.push_back(eltwise_forward(relu_pd));
183 net_fwd_args.push_back(
184 {{DNNL_ARG_SRC, conv_dst_memory}, {DNNL_ARG_DST, relu_dst_memory}});
185
186 // AlexNet: lrn
187 // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
188 // local size: 5
189 // alpha: 0.0001
190 // beta: 0.75
191 // k: 1.0
192 memory::dims lrn_data_tz = {batch, 96, 55, 55};
193 const uint32_t local_size = 5;
194 const float alpha = 0.0001f;
195 const float beta = 0.75f;
196 const float k = 1.0f;
197
198 // create a lrn primitive descriptor
199 auto lrn_pd = lrn_forward::primitive_desc(eng, prop_kind::forward,
200 algorithm::lrn_across_channels, relu_pd.dst_desc(),
201 relu_pd.dst_desc(), local_size, alpha, beta, k);
202
203 // create lrn dst memory
204 auto lrn_dst_memory = memory(lrn_pd.dst_desc(), eng);
205
206 // create workspace only in training and only for forward primitive
207 // query lrn_pd for workspace, this memory will be shared with forward lrn
208 auto lrn_workspace_memory = memory(lrn_pd.workspace_desc(), eng);
209
210 // finally create a lrn primitive
211 net_fwd.push_back(lrn_forward(lrn_pd));
212 net_fwd_args.push_back(
213 {{DNNL_ARG_SRC, relu_dst_memory}, {DNNL_ARG_DST, lrn_dst_memory},
214 {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});
215
216 // AlexNet: pool
217 // {batch, 96, 55, 55} -> {batch, 96, 27, 27}
218 // kernel: {3, 3}
219 // strides: {2, 2}
220
221 memory::dims pool_dst_tz = {batch, 96, 27, 27};
222 memory::dims pool_kernel = {3, 3};
223 memory::dims pool_strides = {2, 2};
224 memory::dims pool_dilation = {0, 0};
225 memory::dims pool_padding = {0, 0};
226
227 // create memory for pool dst data in user format
228 auto pool_user_dst_memory
229 = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
230
231 // create pool dst memory descriptor in format any for bfloat16 data type
232 auto pool_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any);
233
234 // create a pooling primitive descriptor
235 auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward,
236 algorithm::pooling_max, lrn_dst_memory.get_desc(), pool_dst_md,
237 pool_strides, pool_kernel, pool_dilation, pool_padding,
238 pool_padding);
239
240 // create pooling workspace memory if training
241 auto pool_workspace_memory = memory(pool_pd.workspace_desc(), eng);
242
243 // create a pooling primitive
244 net_fwd.push_back(pooling_forward(pool_pd));
245 // leave DST unknown for now (see the next reorder)
246 net_fwd_args.push_back({{DNNL_ARG_SRC, lrn_dst_memory},
247 // delay putting DST until reorder (if needed)
248 {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
249
250 // create reorder primitive between pool dst and user dst format
251 // if needed
252 auto pool_dst_memory = pool_user_dst_memory;
253 if (pool_pd.dst_desc() != pool_user_dst_memory.get_desc()) {
254 pool_dst_memory = memory(pool_pd.dst_desc(), eng);
255 net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});
256
257 net_fwd.push_back(reorder(pool_dst_memory, pool_user_dst_memory));
258 net_fwd_args.push_back({{DNNL_ARG_FROM, pool_dst_memory},
259 {DNNL_ARG_TO, pool_user_dst_memory}});
260 } else {
261 net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});
262 }
263
264 //-----------------------------------------------------------------------
265 //----------------- Backward Stream -------------------------------------
266 // ... user diff_data in float data type ...
267 std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
268 for (size_t i = 0; i < net_diff_dst.size(); ++i)
269 net_diff_dst[i] = sinf((float)i);
270
271 // create memory for user diff dst data stored in float data type
272 auto pool_user_diff_dst_memory
273 = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
274 write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory);
275
276 // Backward pooling
277 // create memory descriptors for pooling
278 auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any);
279 auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any);
280
281 // backward primitive descriptor needs to hint forward descriptor
282 auto pool_bwd_pd = pooling_backward::primitive_desc(eng,
283 algorithm::pooling_max, pool_diff_src_md, pool_diff_dst_md,
284 pool_strides, pool_kernel, pool_dilation, pool_padding,
285 pool_padding, pool_pd);
286
287 // create reorder primitive between user diff dst and pool diff dst
288 // if required
289 auto pool_diff_dst_memory = pool_user_diff_dst_memory;
290 if (pool_dst_memory.get_desc() != pool_user_diff_dst_memory.get_desc()) {
291 pool_diff_dst_memory = memory(pool_dst_memory.get_desc(), eng);
292 net_bwd.push_back(
293 reorder(pool_user_diff_dst_memory, pool_diff_dst_memory));
294 net_bwd_args.push_back({{DNNL_ARG_FROM, pool_user_diff_dst_memory},
295 {DNNL_ARG_TO, pool_diff_dst_memory}});
296 }
297
298 // create memory for pool diff src
299 auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_desc(), eng);
300
301 // finally create backward pooling primitive
302 net_bwd.push_back(pooling_backward(pool_bwd_pd));
303 net_bwd_args.push_back({{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
304 {DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
305 {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
306
307 // Backward lrn
308 auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any);
309 const auto &lrn_diff_src_md = lrn_diff_dst_md;
310
311 // create backward lrn primitive descriptor
312 auto lrn_bwd_pd = lrn_backward::primitive_desc(eng,
313 algorithm::lrn_across_channels, lrn_diff_src_md, lrn_diff_dst_md,
314 lrn_pd.src_desc(), local_size, alpha, beta, k, lrn_pd);
315
316 // create reorder primitive between pool diff src and lrn diff dst
317 // if required
318 auto lrn_diff_dst_memory = pool_diff_src_memory;
319 if (lrn_diff_dst_memory.get_desc() != lrn_bwd_pd.diff_dst_desc()) {
320 lrn_diff_dst_memory = memory(lrn_bwd_pd.diff_dst_desc(), eng);
321 net_bwd.push_back(reorder(pool_diff_src_memory, lrn_diff_dst_memory));
322 net_bwd_args.push_back({{DNNL_ARG_FROM, pool_diff_src_memory},
323 {DNNL_ARG_TO, lrn_diff_dst_memory}});
324 }
325
326 // create memory for lrn diff src
327 auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_desc(), eng);
328
329 // finally create a lrn backward primitive
330 // backward lrn needs src: relu dst in this topology
331 net_bwd.push_back(lrn_backward(lrn_bwd_pd));
332 net_bwd_args.push_back({{DNNL_ARG_SRC, relu_dst_memory},
333 {DNNL_ARG_DIFF_DST, lrn_diff_dst_memory},
334 {DNNL_ARG_DIFF_SRC, lrn_diff_src_memory},
335 {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});
336
337 // Backward relu
338 auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::bf16, tag::any);
339 auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::bf16, tag::any);
340 auto relu_src_md = conv_pd.dst_desc();
341
342 // create backward relu primitive_descriptor
343 auto relu_bwd_pd = eltwise_backward::primitive_desc(eng,
344 algorithm::eltwise_relu, relu_diff_src_md, relu_diff_dst_md,
345 relu_src_md, negative_slope, relu_pd);
346
347 // create reorder primitive between lrn diff src and relu diff dst
348 // if required
349 auto relu_diff_dst_memory = lrn_diff_src_memory;
350 if (relu_diff_dst_memory.get_desc() != relu_bwd_pd.diff_dst_desc()) {
351 relu_diff_dst_memory = memory(relu_bwd_pd.diff_dst_desc(), eng);
352 net_bwd.push_back(reorder(lrn_diff_src_memory, relu_diff_dst_memory));
353 net_bwd_args.push_back({{DNNL_ARG_FROM, lrn_diff_src_memory},
354 {DNNL_ARG_TO, relu_diff_dst_memory}});
355 }
356
357 // create memory for relu diff src
358 auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_desc(), eng);
359
360 // finally create a backward relu primitive
361 net_bwd.push_back(eltwise_backward(relu_bwd_pd));
362 net_bwd_args.push_back({{DNNL_ARG_SRC, conv_dst_memory},
363 {DNNL_ARG_DIFF_DST, relu_diff_dst_memory},
364 {DNNL_ARG_DIFF_SRC, relu_diff_src_memory}});
365
366 // Backward convolution with respect to weights
367 // create user format diff weights and diff bias memory for float data type
368
369 auto conv_user_diff_weights_memory
370 = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng);
371 auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
372
373 // create memory descriptors for bfloat16 convolution data
374 auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any);
375 auto conv_diff_weights_md
376 = memory::desc({conv_weights_tz}, dt::bf16, tag::any);
377 auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any);
378
379 // use diff bias provided by the user
380 auto conv_diff_bias_md = conv_diff_bias_memory.get_desc();
381
382 // create backward convolution primitive descriptor
383 auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(eng,
384 algorithm::convolution_direct, conv_bwd_src_md,
385 conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md,
386 conv_strides, conv_padding, conv_padding, conv_pd);
387
388 // for best performance convolution backward might chose
389 // different memory format for src and diff_dst
390 // than the memory formats preferred by forward convolution
391 // for src and dst respectively
392 // create reorder primitives for src from forward convolution to the
393 // format chosen by backward convolution
394 auto conv_bwd_src_memory = conv_src_memory;
395 if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
396 conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
397 net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
398 net_bwd_args.push_back({{DNNL_ARG_FROM, conv_src_memory},
399 {DNNL_ARG_TO, conv_bwd_src_memory}});
400 }
401
402 // create reorder primitives for diff_dst between diff_src from relu_bwd
403 // and format preferred by conv_diff_weights
404 auto conv_diff_dst_memory = relu_diff_src_memory;
405 if (conv_bwd_weights_pd.diff_dst_desc()
406 != relu_diff_src_memory.get_desc()) {
407 conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
408 net_bwd.push_back(reorder(relu_diff_src_memory, conv_diff_dst_memory));
409 net_bwd_args.push_back({{DNNL_ARG_FROM, relu_diff_src_memory},
410 {DNNL_ARG_TO, conv_diff_dst_memory}});
411 }
412
413 // create backward convolution primitive
414 net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
415 net_bwd_args.push_back({{DNNL_ARG_SRC, conv_bwd_src_memory},
416 {DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
417 // delay putting DIFF_WEIGHTS until reorder (if needed)
418 {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
419
420 // create reorder primitives between conv diff weights and user diff weights
421 // if needed
422 auto conv_diff_weights_memory = conv_user_diff_weights_memory;
423 if (conv_bwd_weights_pd.diff_weights_desc()
424 != conv_user_diff_weights_memory.get_desc()) {
425 conv_diff_weights_memory
426 = memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
427 net_bwd_args.back().insert(
428 {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
429
430 net_bwd.push_back(reorder(
431 conv_diff_weights_memory, conv_user_diff_weights_memory));
432 net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_weights_memory},
433 {DNNL_ARG_TO, conv_user_diff_weights_memory}});
434 } else {
435 net_bwd_args.back().insert(
436 {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
437 }
438
439 // didn't we forget anything?
440 assert(net_fwd.size() == net_fwd_args.size() && "something is missing");
441 assert(net_bwd.size() == net_bwd_args.size() && "something is missing");
442
443 int n_iter = 1; // number of iterations for training
444 // execute
445 while (n_iter) {
446 // forward
447 for (size_t i = 0; i < net_fwd.size(); ++i)
448 net_fwd.at(i).execute(s, net_fwd_args.at(i));
449
450 // update net_diff_dst
451 // auto net_output = pool_user_dst_memory.get_data_handle();
452 // ..user updates net_diff_dst using net_output...
453 // some user defined func update_diff_dst(net_diff_dst.data(),
454 // net_output)
455
456 for (size_t i = 0; i < net_bwd.size(); ++i)
457 net_bwd.at(i).execute(s, net_bwd_args.at(i));
458 // update weights and bias using diff weights and bias
459 //
460 // auto net_diff_weights
461 // = conv_user_diff_weights_memory.get_data_handle();
462 // auto net_diff_bias = conv_diff_bias_memory.get_data_handle();
463 //
464 // ...user updates weights and bias using diff weights and bias...
465 //
466 // some user defined func update_weights(conv_weights.data(),
467 // conv_bias.data(), net_diff_weights, net_diff_bias);
468
469 --n_iter;
470 }
471
472 s.wait();
473}
474
475int main(int argc, char **argv) {
476 return handle_example_errors(simple_net, parse_engine_kind(argc, argv));
477}
478