1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @example cnn_training_f32.cpp
18/// @copybrief cnn_training_f32_cpp
19/// > Annotated version: @ref cnn_training_f32_cpp
20///
21/// @page cnn_training_f32_cpp CNN f32 training example
/// This C++ API example demonstrates how to build a training pipeline for an
/// AlexNet model. The example implements only a few layers of the AlexNet model.
24///
25/// @include cnn_training_f32.cpp
26
27#include <assert.h>
28
29#include <math.h>
30
31#include "oneapi/dnnl/dnnl.hpp"
32
33#include "example_utils.hpp"
34
35using namespace dnnl;
36
37void simple_net(engine::kind engine_kind) {
38 using tag = memory::format_tag;
39 using dt = memory::data_type;
40
41 auto eng = engine(engine_kind, 0);
42 stream s(eng);
43
44 // Vector of primitives and their execute arguments
45 std::vector<primitive> net_fwd, net_bwd;
46 std::vector<std::unordered_map<int, memory>> net_fwd_args, net_bwd_args;
47
48 const int batch = 32;
49
50 std::vector<float> net_src(batch * 3 * 227 * 227);
51 std::vector<float> net_dst(batch * 96 * 27 * 27);
52
53 // initializing non-zero values for src
54 for (size_t i = 0; i < net_src.size(); ++i)
55 net_src[i] = sinf((float)i);
56
57 // AlexNet: conv
58 // {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
59 // strides: {4, 4}
60
61 memory::dims conv_src_tz = {batch, 3, 227, 227};
62 memory::dims conv_weights_tz = {96, 3, 11, 11};
63 memory::dims conv_bias_tz = {96};
64 memory::dims conv_dst_tz = {batch, 96, 55, 55};
65 memory::dims conv_strides = {4, 4};
66 memory::dims conv_padding = {0, 0};
67
68 std::vector<float> conv_weights(product(conv_weights_tz));
69 std::vector<float> conv_bias(product(conv_bias_tz));
70
71 // initializing non-zero values for weights and bias
72 for (size_t i = 0; i < conv_weights.size(); ++i)
73 conv_weights[i] = sinf((float)i);
74 for (size_t i = 0; i < conv_bias.size(); ++i)
75 conv_bias[i] = sinf((float)i);
76
77 // create memory for user data
78 auto conv_user_src_memory
79 = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng);
80 write_to_dnnl_memory(net_src.data(), conv_user_src_memory);
81 auto conv_user_weights_memory
82 = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng);
83 write_to_dnnl_memory((void *)conv_weights.data(), conv_user_weights_memory);
84 auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
85 write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory);
86
87 // create memory descriptors for convolution data w/ no specified
88 // format tag(`any`)
89 // tag `any` lets a primitive(convolution in this case)
90 // chose the memory format preferred for best performance.
91 auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
92 auto conv_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
93 auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any);
94 auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);
95
96 // create a convolution primitive descriptor
97 auto conv_pd = convolution_forward::primitive_desc(eng, prop_kind::forward,
98 algorithm::convolution_direct, conv_src_md, conv_weights_md,
99 conv_bias_md, conv_dst_md, conv_strides, conv_padding,
100 conv_padding);
101
102 // create reorder primitives between user input and conv src if needed
103 auto conv_src_memory = conv_user_src_memory;
104 if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
105 conv_src_memory = memory(conv_pd.src_desc(), eng);
106 net_fwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
107 net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_src_memory},
108 {DNNL_ARG_TO, conv_src_memory}});
109 }
110
111 auto conv_weights_memory = conv_user_weights_memory;
112 if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) {
113 conv_weights_memory = memory(conv_pd.weights_desc(), eng);
114 net_fwd.push_back(
115 reorder(conv_user_weights_memory, conv_weights_memory));
116 net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_weights_memory},
117 {DNNL_ARG_TO, conv_weights_memory}});
118 }
119
120 // create memory for conv dst
121 auto conv_dst_memory = memory(conv_pd.dst_desc(), eng);
122
123 // finally create a convolution primitive
124 net_fwd.push_back(convolution_forward(conv_pd));
125 net_fwd_args.push_back({{DNNL_ARG_SRC, conv_src_memory},
126 {DNNL_ARG_WEIGHTS, conv_weights_memory},
127 {DNNL_ARG_BIAS, conv_user_bias_memory},
128 {DNNL_ARG_DST, conv_dst_memory}});
129
130 // AlexNet: relu
131 // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
132 memory::dims relu_data_tz = {batch, 96, 55, 55};
133 const float negative_slope = 0.0f;
134
135 // create relu primitive desc
136 // keep memory format tag of source same as the format tag of convolution
137 // output in order to avoid reorder
138 auto relu_pd = eltwise_forward::primitive_desc(eng, prop_kind::forward,
139 algorithm::eltwise_relu, conv_pd.dst_desc(), conv_pd.dst_desc(),
140 negative_slope);
141
142 // create relu dst memory
143 auto relu_dst_memory = memory(relu_pd.dst_desc(), eng);
144
145 // finally create a relu primitive
146 net_fwd.push_back(eltwise_forward(relu_pd));
147 net_fwd_args.push_back(
148 {{DNNL_ARG_SRC, conv_dst_memory}, {DNNL_ARG_DST, relu_dst_memory}});
149
150 // AlexNet: lrn
151 // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
152 // local size: 5
153 // alpha: 0.0001
154 // beta: 0.75
155 // k: 1.0
156 memory::dims lrn_data_tz = {batch, 96, 55, 55};
157 const uint32_t local_size = 5;
158 const float alpha = 0.0001f;
159 const float beta = 0.75f;
160 const float k = 1.0f;
161
162 // create a lrn primitive descriptor
163 auto lrn_pd = lrn_forward::primitive_desc(eng, prop_kind::forward,
164 algorithm::lrn_across_channels, relu_pd.dst_desc(),
165 relu_pd.dst_desc(), local_size, alpha, beta, k);
166
167 // create lrn dst memory
168 auto lrn_dst_memory = memory(lrn_pd.dst_desc(), eng);
169
170 // create workspace only in training and only for forward primitive
171 // query lrn_pd for workspace, this memory will be shared with forward lrn
172 auto lrn_workspace_memory = memory(lrn_pd.workspace_desc(), eng);
173
174 // finally create a lrn primitive
175 net_fwd.push_back(lrn_forward(lrn_pd));
176 net_fwd_args.push_back(
177 {{DNNL_ARG_SRC, relu_dst_memory}, {DNNL_ARG_DST, lrn_dst_memory},
178 {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});
179
180 // AlexNet: pool
181 // {batch, 96, 55, 55} -> {batch, 96, 27, 27}
182 // kernel: {3, 3}
183 // strides: {2, 2}
184
185 memory::dims pool_dst_tz = {batch, 96, 27, 27};
186 memory::dims pool_kernel = {3, 3};
187 memory::dims pool_strides = {2, 2};
188 memory::dims pool_dilation = {0, 0};
189 memory::dims pool_padding = {0, 0};
190
191 // create memory for pool dst data in user format
192 auto pool_user_dst_memory
193 = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
194 write_to_dnnl_memory(net_dst.data(), pool_user_dst_memory);
195
196 // create pool dst memory descriptor in format any
197 auto pool_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any);
198
199 // create a pooling primitive descriptor
200 auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward,
201 algorithm::pooling_max, lrn_dst_memory.get_desc(), pool_dst_md,
202 pool_strides, pool_kernel, pool_dilation, pool_padding,
203 pool_padding);
204
205 // create pooling workspace memory if training
206 auto pool_workspace_memory = memory(pool_pd.workspace_desc(), eng);
207
208 // create a pooling primitive
209 net_fwd.push_back(pooling_forward(pool_pd));
210 // leave DST unknown for now (see the next reorder)
211 net_fwd_args.push_back({{DNNL_ARG_SRC, lrn_dst_memory},
212 // delay putting DST until reorder (if needed)
213 {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
214
215 // create reorder primitive between pool dst and user dst format
216 // if needed
217 auto pool_dst_memory = pool_user_dst_memory;
218 if (pool_pd.dst_desc() != pool_user_dst_memory.get_desc()) {
219 pool_dst_memory = memory(pool_pd.dst_desc(), eng);
220 net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});
221 net_fwd.push_back(reorder(pool_dst_memory, pool_user_dst_memory));
222 net_fwd_args.push_back({{DNNL_ARG_FROM, pool_dst_memory},
223 {DNNL_ARG_TO, pool_user_dst_memory}});
224 } else {
225 net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});
226 }
227
228 //-----------------------------------------------------------------------
229 //----------------- Backward Stream -------------------------------------
230 // ... user diff_data ...
231 std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
232 for (size_t i = 0; i < net_diff_dst.size(); ++i)
233 net_diff_dst[i] = sinf((float)i);
234
235 // create memory for user diff dst data
236 auto pool_user_diff_dst_memory
237 = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
238 write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory);
239
240 // Backward pooling
241 // create memory descriptors for pooling
242 auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::f32, tag::any);
243 auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any);
244
245 // backward primitive descriptor needs to hint forward descriptor
246 auto pool_bwd_pd = pooling_backward::primitive_desc(eng,
247 algorithm::pooling_max, pool_diff_src_md, pool_diff_dst_md,
248 pool_strides, pool_kernel, pool_dilation, pool_padding,
249 pool_padding, pool_pd);
250
251 // create reorder primitive between user diff dst and pool diff dst
252 // if required
253 auto pool_diff_dst_memory = pool_user_diff_dst_memory;
254 if (pool_dst_memory.get_desc() != pool_user_diff_dst_memory.get_desc()) {
255 pool_diff_dst_memory = memory(pool_dst_memory.get_desc(), eng);
256 net_bwd.push_back(
257 reorder(pool_user_diff_dst_memory, pool_diff_dst_memory));
258 net_bwd_args.push_back({{DNNL_ARG_FROM, pool_user_diff_dst_memory},
259 {DNNL_ARG_TO, pool_diff_dst_memory}});
260 }
261
262 // create memory for pool diff src
263 auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_desc(), eng);
264
265 // finally create backward pooling primitive
266 net_bwd.push_back(pooling_backward(pool_bwd_pd));
267 net_bwd_args.push_back({{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
268 {DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
269 {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
270
271 // Backward lrn
272 auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::f32, tag::any);
273 const auto &lrn_diff_src_md = lrn_diff_dst_md;
274
275 // create backward lrn primitive descriptor
276 auto lrn_bwd_pd = lrn_backward::primitive_desc(eng,
277 algorithm::lrn_across_channels, lrn_diff_src_md, lrn_diff_dst_md,
278 lrn_pd.src_desc(), local_size, alpha, beta, k, lrn_pd);
279
280 // create reorder primitive between pool diff src and lrn diff dst
281 // if required
282 auto lrn_diff_dst_memory = pool_diff_src_memory;
283 if (lrn_diff_dst_memory.get_desc() != lrn_bwd_pd.diff_dst_desc()) {
284 lrn_diff_dst_memory = memory(lrn_bwd_pd.diff_dst_desc(), eng);
285 net_bwd.push_back(reorder(pool_diff_src_memory, lrn_diff_dst_memory));
286 net_bwd_args.push_back({{DNNL_ARG_FROM, pool_diff_src_memory},
287 {DNNL_ARG_TO, lrn_diff_dst_memory}});
288 }
289
290 // create memory for lrn diff src
291 auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_desc(), eng);
292
293 // finally create a lrn backward primitive
294 // backward lrn needs src: relu dst in this topology
295 net_bwd.push_back(lrn_backward(lrn_bwd_pd));
296 net_bwd_args.push_back({{DNNL_ARG_SRC, relu_dst_memory},
297 {DNNL_ARG_DIFF_DST, lrn_diff_dst_memory},
298 {DNNL_ARG_DIFF_SRC, lrn_diff_src_memory},
299 {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});
300
301 // Backward relu
302 auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::f32, tag::any);
303 auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::f32, tag::any);
304 auto relu_src_md = conv_pd.dst_desc();
305
306 // create backward relu primitive_descriptor
307 auto relu_bwd_pd = eltwise_backward::primitive_desc(eng,
308 algorithm::eltwise_relu, relu_diff_src_md, relu_diff_dst_md,
309 relu_src_md, negative_slope, relu_pd);
310
311 // create reorder primitive between lrn diff src and relu diff dst
312 // if required
313 auto relu_diff_dst_memory = lrn_diff_src_memory;
314 if (relu_diff_dst_memory.get_desc() != relu_bwd_pd.diff_dst_desc()) {
315 relu_diff_dst_memory = memory(relu_bwd_pd.diff_dst_desc(), eng);
316 net_bwd.push_back(reorder(lrn_diff_src_memory, relu_diff_dst_memory));
317 net_bwd_args.push_back({{DNNL_ARG_FROM, lrn_diff_src_memory},
318 {DNNL_ARG_TO, relu_diff_dst_memory}});
319 }
320
321 // create memory for relu diff src
322 auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_desc(), eng);
323
324 // finally create a backward relu primitive
325 net_bwd.push_back(eltwise_backward(relu_bwd_pd));
326 net_bwd_args.push_back({{DNNL_ARG_SRC, conv_dst_memory},
327 {DNNL_ARG_DIFF_DST, relu_diff_dst_memory},
328 {DNNL_ARG_DIFF_SRC, relu_diff_src_memory}});
329
330 // Backward convolution with respect to weights
331 // create user format diff weights and diff bias memory
332 std::vector<float> conv_user_diff_weights_buffer(product(conv_weights_tz));
333 std::vector<float> conv_diff_bias_buffer(product(conv_bias_tz));
334
335 auto conv_user_diff_weights_memory
336 = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng);
337 write_to_dnnl_memory(conv_user_diff_weights_buffer.data(),
338 conv_user_diff_weights_memory);
339 auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
340 write_to_dnnl_memory(conv_diff_bias_buffer.data(), conv_diff_bias_memory);
341
342 // create memory descriptors
343 auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
344 auto conv_diff_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
345 auto conv_diff_weights_md
346 = memory::desc({conv_weights_tz}, dt::f32, tag::any);
347 auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);
348
349 // create backward convolution primitive descriptor
350 auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(eng,
351 algorithm::convolution_direct, conv_bwd_src_md,
352 conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md,
353 conv_strides, conv_padding, conv_padding, conv_pd);
354
355 // for best performance convolution backward might chose
356 // different memory format for src and diff_dst
357 // than the memory formats preferred by forward convolution
358 // for src and dst respectively
359 // create reorder primitives for src from forward convolution to the
360 // format chosen by backward convolution
361 auto conv_bwd_src_memory = conv_src_memory;
362 if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
363 conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
364 net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
365 net_bwd_args.push_back({{DNNL_ARG_FROM, conv_src_memory},
366 {DNNL_ARG_TO, conv_bwd_src_memory}});
367 }
368
369 // create reorder primitives for diff_dst between diff_src from relu_bwd
370 // and format preferred by conv_diff_weights
371 auto conv_diff_dst_memory = relu_diff_src_memory;
372 if (conv_bwd_weights_pd.diff_dst_desc()
373 != relu_diff_src_memory.get_desc()) {
374 conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
375 net_bwd.push_back(reorder(relu_diff_src_memory, conv_diff_dst_memory));
376 net_bwd_args.push_back({{DNNL_ARG_FROM, relu_diff_src_memory},
377 {DNNL_ARG_TO, conv_diff_dst_memory}});
378 }
379
380 // create backward convolution primitive
381 net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
382 net_bwd_args.push_back({{DNNL_ARG_SRC, conv_bwd_src_memory},
383 {DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
384 // delay putting DIFF_WEIGHTS until reorder (if needed)
385 {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
386
387 // create reorder primitives between conv diff weights and user diff weights
388 // if needed
389 auto conv_diff_weights_memory = conv_user_diff_weights_memory;
390 if (conv_bwd_weights_pd.diff_weights_desc()
391 != conv_user_diff_weights_memory.get_desc()) {
392 conv_diff_weights_memory
393 = memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
394 net_bwd_args.back().insert(
395 {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
396
397 net_bwd.push_back(reorder(
398 conv_diff_weights_memory, conv_user_diff_weights_memory));
399 net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_weights_memory},
400 {DNNL_ARG_TO, conv_user_diff_weights_memory}});
401 } else {
402 net_bwd_args.back().insert(
403 {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
404 }
405
406 // didn't we forget anything?
407 assert(net_fwd.size() == net_fwd_args.size() && "something is missing");
408 assert(net_bwd.size() == net_bwd_args.size() && "something is missing");
409
410 int n_iter = 1; // number of iterations for training
411 // execute
412 while (n_iter) {
413 // forward
414 for (size_t i = 0; i < net_fwd.size(); ++i)
415 net_fwd.at(i).execute(s, net_fwd_args.at(i));
416
417 // update net_diff_dst
418 // auto net_output = pool_user_dst_memory.get_data_handle();
419 // ..user updates net_diff_dst using net_output...
420 // some user defined func update_diff_dst(net_diff_dst.data(),
421 // net_output)
422
423 for (size_t i = 0; i < net_bwd.size(); ++i)
424 net_bwd.at(i).execute(s, net_bwd_args.at(i));
425 // update weights and bias using diff weights and bias
426 //
427 // auto net_diff_weights
428 // = conv_user_diff_weights_memory.get_data_handle();
429 // auto net_diff_bias = conv_diff_bias_memory.get_data_handle();
430 //
431 // ...user updates weights and bias using diff weights and bias...
432 //
433 // some user defined func update_weights(conv_weights.data(),
434 // conv_bias.data(), net_diff_weights, net_diff_bias);
435
436 --n_iter;
437 }
438
439 s.wait();
440}
441
442int main(int argc, char **argv) {
443 return handle_example_errors(simple_net, parse_engine_kind(argc, argv));
444}
445