/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example rnn_training_f32.cpp
/// @copybrief rnn_training_f32_cpp
/// > Annotated version: @ref rnn_training_f32_cpp
///
/// @page rnn_training_f32_cpp RNN f32 training example
/// This C++ API example demonstrates how to build training for a GNMT-style
/// model.
///
/// @include rnn_training_f32.cpp

#include <cstring>
#include <math.h>
#include <numeric>
#include <utility>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

// User input is:
// N0 sequences of length T0
const int N0 = 1 + rand() % 31;
// N1 sequences of length T1
const int N1 = 1 + rand() % 31;
// Assume T0 > T1
const int T0 = 31 + 1 + rand() % 31;
const int T1 = 1 + rand() % 31;
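// Note: the assumption T0 > T1 holds by construction here, since T0 is
// drawn from [32, 62] while T1 is drawn from [1, 31].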

// Memory required to hold the data: N0 * T0 + N1 * T1
// However, it is possible to have it arrive
// as padded chunks in a larger memory:
// e.g. (N0 + N1) * T0
// We don't need to compact the data before processing;
// we can address the chunks via sub-memories and
// process the data via two RNN primitives
// of time lengths T1 and T0 - T1.
// The leftmost primitive will process N0 + N1 subsequences of length T1.
// The rightmost primitive will process the remaining N0 subsequences
// of length T0 - T1.
const int leftmost_batch = N0 + N1;
const int rightmost_batch = N0;

const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - T1;

// Number of channels
const int common_feature_size = 1024;

// RNN primitive characteristics
const int common_n_layers = 1;
const int lstm_n_gates = 4;
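// (An LSTM cell has four gates: input, forget, candidate, and output,
// hence lstm_n_gates = 4.)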

void simple_net(engine::kind engine_kind) {
    using tag = memory::format_tag;
    using dt = memory::data_type;

    auto eng = engine(engine_kind, 0);
    stream s(eng);

    bool is_training = true;
    auto fwd_inf_train = is_training ? prop_kind::forward_training
                                     : prop_kind::forward_inference;

    std::vector<primitive> fwd_net;
    std::vector<primitive> bwd_net;

    // Input tensor holds two batches with different sequence lengths.
    // Shorter sequences are padded.
    memory::dims net_src_dims = {
            T0, // time, maximum sequence length
            N0 + N1, // n, total batch size
            common_feature_size // c, common number of channels
    };

    // Two RNN primitives for different sequence lengths,
    // one unidirectional layer, LSTM-based
    memory::dims leftmost_src_layer_dims = {
            leftmost_seq_length, // time
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_layer_dims = {
            rightmost_seq_length, // time
            rightmost_batch, // n
            common_feature_size // c
    };
    memory::dims common_weights_layer_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_weights_iter_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_bias_dims = {
            common_n_layers, // layers
            1, // directions
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims leftmost_dst_layer_dims = {
            leftmost_seq_length, // time
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_dst_layer_dims = {
            rightmost_seq_length, // time
            rightmost_batch, // n
            common_feature_size // c
    };

    // The leftmost primitive passes its states to the next RNN iteration,
    // so it needs the dst_iter parameters.
    //
    // The rightmost primitive will consume these as src_iter and will access
    // the memory via a sub-memory because it will have a different batch
    // dimension. We have arranged our primitives so that
    // leftmost_batch >= rightmost_batch, and so the rightmost data will fit
    // into the memory allocated for the leftmost.
    memory::dims leftmost_dst_iter_dims = {
            common_n_layers, // layers
            1, // directions
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims leftmost_dst_iter_c_dims = {
            common_n_layers, // layers
            1, // directions
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_iter_dims = {
            common_n_layers, // layers
            1, // directions
            rightmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_iter_c_dims = {
            common_n_layers, // layers
            1, // directions
            rightmost_batch, // n
            common_feature_size // c
    };

    // Product of tensor dimensions, i.e. the number of elements
    auto tz_volume = [=](memory::dims tz_dims) {
        return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
                std::multiplies<memory::dim>());
    };

    // Create an auxiliary f32 memory descriptor
    // based on user-supplied dimensions and layout.
    auto formatted_md
            = [=](const memory::dims &dimensions, memory::format_tag layout) {
                  return memory::desc {{dimensions}, dt::f32, layout};
              };
    // Create an auxiliary generic f32 memory descriptor
    // based on supplied dimensions, with format_tag::any.
    auto generic_md = [=](const memory::dims &dimensions) {
        return formatted_md(dimensions, tag::any);
    };
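    // For illustration (not used below): formatted_md pins a concrete layout,
    // while generic_md lets the implementation pick one, e.g.
    //   auto user_md = formatted_md(net_src_dims, tag::tnc); // fixed tnc
    //   auto any_md = generic_md(net_src_dims); // implementation's choice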

    //
    // I/O memory, coming from user
    //

    // Net input
    std::vector<float> net_src(tz_volume(net_src_dims), 1.0f);
    // NOTE: in this example we study input sequences with variable batch
    // dimension, which get processed by two separate RNN primitives, thus
    // the destination memory for the two will have different shapes: batch
    // is the second dimension currently: see format_tag::tnc.
    // We are not copying the output to some common user-provided memory as
    // we suggest that the user should rather keep the two output memories
    // separate throughout the whole topology and only reorder to something
    // else as needed.
    // So there's no common net_dst, but there are two destinations instead:
    // leftmost_dst_layer_memory
    // rightmost_dst_layer_memory

    // Memory for the user-allocated data.
    // Suppose user data is in tnc format.
    auto net_src_memory
            = dnnl::memory({{net_src_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(net_src.data(), net_src_memory);
    // src_layer memory of the leftmost and rightmost RNN primitives
    // is accessed through the respective sub-memories of the larger memory.
    // Sub-memory descriptors provide the strides that account for the padding.
    auto user_leftmost_src_layer_md = net_src_memory.get_desc().submemory_desc(
            leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_src_layer_md
            = net_src_memory.get_desc().submemory_desc(rightmost_src_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    auto leftmost_src_layer_memory = net_src_memory;
    auto rightmost_src_layer_memory = net_src_memory;
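    // To make the offsets concrete: element (t, n, c) of the rightmost
    // sub-memory maps to element (t + leftmost_seq_length, n, c) of the
    // underlying net_src buffer; both views share that buffer, so no data
    // is copied.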

    // Other user-provided memory arrays, descriptors, and primitives with the
    // data layouts chosen by the user. We'll have to reorder if the RNN
    // primitive prefers a different format.
    std::vector<float> user_common_weights_layer(
            tz_volume(common_weights_layer_dims), 1.0f);
    auto user_common_weights_layer_memory = dnnl::memory(
            {common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
    write_to_dnnl_memory(
            user_common_weights_layer.data(), user_common_weights_layer_memory);

    std::vector<float> user_common_weights_iter(
            tz_volume(common_weights_iter_dims), 1.0f);
    auto user_common_weights_iter_memory = dnnl::memory(
            {{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
    write_to_dnnl_memory(
            user_common_weights_iter.data(), user_common_weights_iter_memory);

    std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_bias_memory
            = dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
    write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);

    std::vector<float> user_leftmost_dst_layer(
            tz_volume(leftmost_dst_layer_dims), 1.0f);
    auto user_leftmost_dst_layer_memory
            = dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(
            user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory);

    std::vector<float> user_rightmost_dst_layer(
            tz_volume(rightmost_dst_layer_dims), 1.0f);
    auto user_rightmost_dst_layer_memory = dnnl::memory(
            {{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(
            user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory);

    // Describe layer, forward pass, leftmost primitive.
    // There are no primitives to the left of this one,
    // so src_iter_desc needs to be a zero memory descriptor.
    auto leftmost_prim_desc = lstm_forward::primitive_desc(eng, // engine
            fwd_inf_train, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_leftmost_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims) // dst_iter_c_desc
    );
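    // The primitive descriptor has now resolved every format_tag::any above
    // to a concrete layout; those can be queried, e.g. via
    // leftmost_prim_desc.weights_layer_desc(), and compared with the user
    // layouts to decide whether a reorder is required (done further below).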

    //
    // Need to connect leftmost and rightmost via "iter" parameters.
    // We allocate memory here based on the shapes provided by the RNN
    // primitive.
    //
    auto leftmost_dst_iter_memory
            = dnnl::memory(leftmost_prim_desc.dst_iter_desc(), eng);
    auto leftmost_dst_iter_c_memory
            = dnnl::memory(leftmost_prim_desc.dst_iter_c_desc(), eng);

    // rightmost src_iter will be a sub-memory of dst_iter of leftmost
    auto rightmost_src_iter_md
            = leftmost_dst_iter_memory.get_desc().submemory_desc(
                    rightmost_src_iter_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_src_iter_memory = leftmost_dst_iter_memory;

    auto rightmost_src_iter_c_md
            = leftmost_dst_iter_c_memory.get_desc().submemory_desc(
                    rightmost_src_iter_c_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory;

    // Now the rightmost primitive.
    // There are no primitives to the right of this one,
    // so dst_iter_desc is an explicit zero memory descriptor.
    auto rightmost_prim_desc = lstm_forward::primitive_desc(eng, // engine
            fwd_inf_train, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_rightmost_src_layer_md, // src_layer_desc
            rightmost_src_iter_md, // src_iter_desc
            rightmost_src_iter_c_md, // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            memory::desc(), // dst_iter_desc
            memory::desc() // dst_iter_c_desc
    );

    //
    // Weights and biases, layer memory.
    // The same layout should work across the layer; no reorders are
    // needed between leftmost and rightmost, only reordering of the
    // user memory to the RNN-friendly layouts.
    //

    auto common_weights_layer_memory = user_common_weights_layer_memory;
    if (leftmost_prim_desc.weights_layer_desc()
            != common_weights_layer_memory.get_desc()) {
        common_weights_layer_memory
                = dnnl::memory(leftmost_prim_desc.weights_layer_desc(), eng);
        reorder(user_common_weights_layer_memory, common_weights_layer_memory)
                .execute(s, user_common_weights_layer_memory,
                        common_weights_layer_memory);
    }

    auto common_weights_iter_memory = user_common_weights_iter_memory;
    if (leftmost_prim_desc.weights_iter_desc()
            != common_weights_iter_memory.get_desc()) {
        common_weights_iter_memory
                = dnnl::memory(leftmost_prim_desc.weights_iter_desc(), eng);
        reorder(user_common_weights_iter_memory, common_weights_iter_memory)
                .execute(s, user_common_weights_iter_memory,
                        common_weights_iter_memory);
    }

    auto common_bias_memory = user_common_bias_memory;
    if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_memory = dnnl::memory(leftmost_prim_desc.bias_desc(), eng);
        reorder(user_common_bias_memory, common_bias_memory)
                .execute(s, user_common_bias_memory, common_bias_memory);
    }
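    // The three blocks above repeat one conditional-reorder pattern; a
    // hedged helper sketch (hypothetical, kept out of the example proper so
    // each step stays explicit):
    //   auto reorder_if_needed = [&](const memory::desc &want, memory mem) {
    //       if (want == mem.get_desc()) return mem;
    //       auto fit = dnnl::memory(want, eng);
    //       reorder(mem, fit).execute(s, mem, fit);
    //       return fit;
    //   };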

    //
    // Destination layer memory
    //

    auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
    if (leftmost_prim_desc.dst_layer_desc()
            != leftmost_dst_layer_memory.get_desc()) {
        leftmost_dst_layer_memory
                = dnnl::memory(leftmost_prim_desc.dst_layer_desc(), eng);
        reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory)
                .execute(s, user_leftmost_dst_layer_memory,
                        leftmost_dst_layer_memory);
    }

    auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
    if (rightmost_prim_desc.dst_layer_desc()
            != rightmost_dst_layer_memory.get_desc()) {
        rightmost_dst_layer_memory
                = dnnl::memory(rightmost_prim_desc.dst_layer_desc(), eng);
        reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory)
                .execute(s, user_rightmost_dst_layer_memory,
                        rightmost_dst_layer_memory);
    }

    // We also create workspace memory based on the information from
    // the workspace_desc() query. This is needed for internal
    // communication between forward and backward primitives during
    // training.
    auto create_ws = [=](dnnl::lstm_forward::primitive_desc &pd) {
        return dnnl::memory(pd.workspace_desc(), eng);
    };
    auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
    auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
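    // (Illustrative: the byte size the implementation requested can be
    // inspected via leftmost_workspace_memory.get_desc().get_size().)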

    // Construct the RNN primitive objects for the forward pass
    lstm_forward leftmost_layer(leftmost_prim_desc);
    leftmost_layer.execute(s,
            {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
                    {DNNL_ARG_BIAS, common_bias_memory},
                    {DNNL_ARG_DST_LAYER, leftmost_dst_layer_memory},
                    {DNNL_ARG_DST_ITER, leftmost_dst_iter_memory},
                    {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_memory},
                    {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});

    lstm_forward rightmost_layer(rightmost_prim_desc);
    rightmost_layer.execute(s,
            {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_memory},
                    {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
                    {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
                    {DNNL_ARG_BIAS, common_bias_memory},
                    {DNNL_ARG_DST_LAYER, rightmost_dst_layer_memory},
                    {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});

    // No backward pass for inference
    if (!is_training) return;

    //
    // Backward primitives will reuse memory from the forward pass,
    // and we allocate/describe the backward specifics here.
    // This is only relevant for training.
    //

    // User-provided memory for the backward-by-data output
    std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f);
    auto net_diff_src_memory
            = dnnl::memory(formatted_md(net_src_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);

    // diff_src follows the same layout we have for net_src
    auto user_leftmost_diff_src_layer_md
            = net_diff_src_memory.get_desc().submemory_desc(
                    leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_diff_src_layer_md
            = net_diff_src_memory.get_desc().submemory_desc(
                    rightmost_src_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    auto leftmost_diff_src_layer_memory = net_diff_src_memory;
    auto rightmost_diff_src_layer_memory = net_diff_src_memory;

    // User-provided memory for the backward-by-weights output
    std::vector<float> user_common_diff_weights_layer(
            tz_volume(common_weights_layer_dims), 1.0f);
    auto user_common_diff_weights_layer_memory = dnnl::memory(
            formatted_md(common_weights_layer_dims, tag::ldigo), eng);
    write_to_dnnl_memory(user_common_diff_weights_layer.data(),
            user_common_diff_weights_layer_memory);

    std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_diff_bias_memory
            = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
    write_to_dnnl_memory(
            user_common_diff_bias.data(), user_common_diff_bias_memory);

    // User-provided input to the backward primitives.
    // To be updated by the user after the forward pass using some cost
    // function.
    memory::dims net_diff_dst_dims = {
            T0, // time
            N0 + N1, // n
            common_feature_size // c
    };
    // Suppose user data is in tnc format.
    std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
    auto net_diff_dst_memory
            = dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);
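    // A hedged illustration (hypothetical cost function): for a mean-squared
    // error against some user-held target buffer, one would fill net_diff_dst
    // with the loss gradient before running the backward primitives, e.g.
    //   net_diff_dst[i] = 2.f * (output[i] - target[i]) / n_elements;
    // where output is the forward dst_layer data read back to the host.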
    // diff_dst_layer memory of the leftmost and rightmost RNN primitives
    // is accessed through the respective sub-memories of the larger memory.
    // Sub-memory descriptors provide the strides that account for the padding.
    auto user_leftmost_diff_dst_layer_md
            = net_diff_dst_memory.get_desc().submemory_desc(
                    leftmost_dst_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_diff_dst_layer_md
            = net_diff_dst_memory.get_desc().submemory_desc(
                    rightmost_dst_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
    auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;

    // Backward leftmost primitive descriptor
    auto leftmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine
            prop_kind::backward, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_leftmost_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims), // dst_iter_c_desc
            user_leftmost_diff_src_layer_md, // diff_src_layer_desc
            memory::desc(), // diff_src_iter_desc
            memory::desc(), // diff_src_iter_c_desc
            generic_md(common_weights_layer_dims), // diff_weights_layer_desc
            generic_md(common_weights_iter_dims), // diff_weights_iter_desc
            generic_md(common_bias_dims), // diff_bias_desc
            user_leftmost_diff_dst_layer_md, // diff_dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // diff_dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims), // diff_dst_iter_c_desc
            leftmost_prim_desc // hint from forward pass
    );

    // As the batch dimensions are different between leftmost and rightmost,
    // we need to use sub-memories. The rightmost needs less memory, so it
    // will be a sub-memory of the leftmost.
    auto leftmost_diff_dst_iter_memory
            = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_desc(), eng);
    auto leftmost_diff_dst_iter_c_memory
            = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), eng);

    auto rightmost_diff_src_iter_md
            = leftmost_diff_dst_iter_memory.get_desc().submemory_desc(
                    rightmost_src_iter_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;

    auto rightmost_diff_src_iter_c_md
            = leftmost_diff_dst_iter_c_memory.get_desc().submemory_desc(
                    rightmost_src_iter_c_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory;

    // Backward rightmost primitive descriptor
    auto rightmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine
            prop_kind::backward, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_rightmost_src_layer_md, // src_layer_desc
            generic_md(rightmost_src_iter_dims), // src_iter_desc
            generic_md(rightmost_src_iter_c_dims), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            memory::desc(), // dst_iter_desc
            memory::desc(), // dst_iter_c_desc
            user_rightmost_diff_src_layer_md, // diff_src_layer_desc
            rightmost_diff_src_iter_md, // diff_src_iter_desc
            rightmost_diff_src_iter_c_md, // diff_src_iter_c_desc
            generic_md(common_weights_layer_dims), // diff_weights_layer_desc
            generic_md(common_weights_iter_dims), // diff_weights_iter_desc
            generic_md(common_bias_dims), // diff_bias_desc
            user_rightmost_diff_dst_layer_md, // diff_dst_layer_desc
            memory::desc(), // diff_dst_iter_desc
            memory::desc(), // diff_dst_iter_c_desc
            rightmost_prim_desc // hint from forward pass
    );

    //
    // Memory for the backward pass
    //

    // src_layer uses the same memory as the forward pass
    auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
    auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;

    // Memory for weights and biases for the backward pass.
    // Try to use the same memory between forward and backward, but
    // sometimes reorders are needed.
    auto common_weights_layer_bwd_memory = common_weights_layer_memory;
    if (leftmost_bwd_prim_desc.weights_layer_desc()
            != leftmost_prim_desc.weights_layer_desc()) {
        common_weights_layer_bwd_memory
                = memory(leftmost_bwd_prim_desc.weights_layer_desc(), eng);
        reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
                .execute(s, common_weights_layer_memory,
                        common_weights_layer_bwd_memory);
    }

    auto common_weights_iter_bwd_memory = common_weights_iter_memory;
    if (leftmost_bwd_prim_desc.weights_iter_desc()
            != leftmost_prim_desc.weights_iter_desc()) {
        common_weights_iter_bwd_memory
                = memory(leftmost_bwd_prim_desc.weights_iter_desc(), eng);
        reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
                .execute(s, common_weights_iter_memory,
                        common_weights_iter_bwd_memory);
    }

    auto common_bias_bwd_memory = common_bias_memory;
    if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.bias_desc(), eng);
        reorder(common_bias_memory, common_bias_bwd_memory)
                .execute(s, common_bias_memory, common_bias_bwd_memory);
    }

    // diff_weights and biases
    auto common_diff_weights_layer_memory
            = user_common_diff_weights_layer_memory;
    auto reorder_common_diff_weights_layer = false;
    if (leftmost_bwd_prim_desc.diff_weights_layer_desc()
            != common_diff_weights_layer_memory.get_desc()) {
        common_diff_weights_layer_memory = dnnl::memory(
                leftmost_bwd_prim_desc.diff_weights_layer_desc(), eng);
        reorder_common_diff_weights_layer = true;
    }

    auto common_diff_bias_memory = user_common_diff_bias_memory;
    auto reorder_common_diff_bias = false;
    if (leftmost_bwd_prim_desc.diff_bias_desc()
            != common_diff_bias_memory.get_desc()) {
        common_diff_bias_memory
                = dnnl::memory(leftmost_bwd_prim_desc.diff_bias_desc(), eng);
        reorder_common_diff_bias = true;
    }
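    // Note: unlike the weight reorders above, these are primitive outputs,
    // so if the layouts differ the reorder back to the user-provided memory
    // has to run after the backward primitives execute; the two flags above
    // defer it to the end of this function.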

    // dst_layer memory for the backward pass
    auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
    if (leftmost_bwd_prim_desc.dst_layer_desc()
            != leftmost_dst_layer_bwd_memory.get_desc()) {
        leftmost_dst_layer_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_layer_desc(), eng);
        reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory)
                .execute(s, leftmost_dst_layer_memory,
                        leftmost_dst_layer_bwd_memory);
    }

    auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
    if (rightmost_bwd_prim_desc.dst_layer_desc()
            != rightmost_dst_layer_bwd_memory.get_desc()) {
        rightmost_dst_layer_bwd_memory
                = dnnl::memory(rightmost_bwd_prim_desc.dst_layer_desc(), eng);
        reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory)
                .execute(s, rightmost_dst_layer_memory,
                        rightmost_dst_layer_bwd_memory);
    }

    // Similar to forward, the backward primitives are connected
    // via "iter" parameters.
    auto common_diff_weights_iter_memory = dnnl::memory(
            leftmost_bwd_prim_desc.diff_weights_iter_desc(), eng);

    auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
    if (leftmost_bwd_prim_desc.dst_iter_desc()
            != leftmost_dst_iter_bwd_memory.get_desc()) {
        leftmost_dst_iter_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_desc(), eng);
        reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory)
                .execute(s, leftmost_dst_iter_memory,
                        leftmost_dst_iter_bwd_memory);
    }

    auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory;
    if (leftmost_bwd_prim_desc.dst_iter_c_desc()
            != leftmost_dst_iter_c_bwd_memory.get_desc()) {
        leftmost_dst_iter_c_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_c_desc(), eng);
        reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory)
                .execute(s, leftmost_dst_iter_c_memory,
                        leftmost_dst_iter_c_bwd_memory);
    }

    // Construct the RNN primitive objects for backward.
    // Note the order: the rightmost primitive runs first, since its
    // diff_src_iter output is the memory that the leftmost primitive
    // then consumes as diff_dst_iter.
    lstm_backward rightmost_layer_bwd(rightmost_bwd_prim_desc);
    rightmost_layer_bwd.execute(s,
            {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_bwd_memory},
                    {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
                    {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
                    {DNNL_ARG_BIAS, common_bias_bwd_memory},
                    {DNNL_ARG_DST_LAYER, rightmost_dst_layer_bwd_memory},
                    {DNNL_ARG_DIFF_SRC_LAYER, rightmost_diff_src_layer_memory},
                    {DNNL_ARG_DIFF_SRC_ITER, rightmost_diff_src_iter_memory},
                    {DNNL_ARG_DIFF_SRC_ITER_C,
                            rightmost_diff_src_iter_c_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_LAYER,
                            common_diff_weights_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_ITER,
                            common_diff_weights_iter_memory},
                    {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
                    {DNNL_ARG_DIFF_DST_LAYER, rightmost_diff_dst_layer_memory},
                    {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});

    lstm_backward leftmost_layer_bwd(leftmost_bwd_prim_desc);
    leftmost_layer_bwd.execute(s,
            {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
                    {DNNL_ARG_BIAS, common_bias_bwd_memory},
                    {DNNL_ARG_DST_LAYER, leftmost_dst_layer_bwd_memory},
                    {DNNL_ARG_DST_ITER, leftmost_dst_iter_bwd_memory},
                    {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_bwd_memory},
                    {DNNL_ARG_DIFF_SRC_LAYER, leftmost_diff_src_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_LAYER,
                            common_diff_weights_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_ITER,
                            common_diff_weights_iter_memory},
                    {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
                    {DNNL_ARG_DIFF_DST_LAYER, leftmost_diff_dst_layer_memory},
                    {DNNL_ARG_DIFF_DST_ITER, leftmost_diff_dst_iter_memory},
                    {DNNL_ARG_DIFF_DST_ITER_C, leftmost_diff_dst_iter_c_memory},
                    {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});

    // Reorder the diffs back to the user-provided layouts if needed
    if (reorder_common_diff_weights_layer) {
        reorder(common_diff_weights_layer_memory,
                user_common_diff_weights_layer_memory)
                .execute(s, common_diff_weights_layer_memory,
                        user_common_diff_weights_layer_memory);
    }

    if (reorder_common_diff_bias) {
        reorder(common_diff_bias_memory, user_common_diff_bias_memory)
                .execute(s, common_diff_bias_memory,
                        user_common_diff_bias_memory);
    }

    //
    // The user updates weights and biases using the diffs
    //
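    // A minimal SGD sketch (hypothetical; learning_rate is an assumption,
    // and s.wait() below must complete before the reads):
    //   std::vector<float> w(tz_volume(common_weights_layer_dims));
    //   std::vector<float> dw(w.size());
    //   read_from_dnnl_memory(w.data(), user_common_weights_layer_memory);
    //   read_from_dnnl_memory(dw.data(), user_common_diff_weights_layer_memory);
    //   for (size_t i = 0; i < w.size(); ++i) w[i] -= learning_rate * dw[i];
    //   write_to_dnnl_memory(w.data(), user_common_weights_layer_memory);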

    s.wait();
}

int main(int argc, char **argv) {
    return handle_example_errors(simple_net, parse_engine_kind(argc, argv));
}