1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @example rnn_training_f32.cpp |
18 | /// @copybrief rnn_training_f32_cpp |
19 | /// > Annotated version: @ref rnn_training_f32_cpp |
20 | /// |
21 | /// @page rnn_training_f32_cpp RNN f32 training example |
22 | /// This C++ API example demonstrates how to build GNMT model training. |
23 | /// |
24 | /// @include rnn_training_f32.cpp |
25 | |
26 | #include <cstring> |
27 | #include <math.h> |
28 | #include <numeric> |
29 | #include <utility> |
30 | |
31 | #include "oneapi/dnnl/dnnl.hpp" |
32 | |
33 | #include "example_utils.hpp" |
34 | |
35 | using namespace dnnl; |
36 | |
// User input is:
//     N0 sequences of length T0
// NOTE(review): rand() is never seeded here, so these "random" sizes are
// deterministic across runs; the calls execute in declaration order during
// static initialization, so the values below depend on this exact ordering.
const int N0 = 1 + rand() % 31;
//     N1 sequences of length T1
const int N1 = 1 + rand() % 31;
// Assume T0 > T1
// (guaranteed: T0 >= 33 while T1 <= 31 given rand() % 31 in [0, 30])
const int T0 = 31 + 1 + rand() % 31;
const int T1 = 1 + rand() % 31;

// Memory required to hold it: N0 * T0 + N1 * T1
// However it is possible to have these coming
// as padded chunks in larger memory:
//      e.g. (N0 + N1) * T0
// We don't need to compact the data before processing,
// we can address the chunks via sub-memory and
// process the data via two RNN primitives:
//     of time lengths T1 and T0 - T1.
// The leftmost primitive will process N0 + N1 subsequences of length T1
// The rightmost primitive will process remaining N0 subsequences
// of T0 - T1 length
const int leftmost_batch = N0 + N1;
const int rightmost_batch = N0;

// Time dimension of each chunk: both batches overlap for the first T1
// steps; only the longer N0 batch continues for the remaining T0 - T1.
const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - T1;

// Number of channels
const int common_feature_size = 1024;

// RNN primitive characteristics
const int common_n_layers = 1;
const int lstm_n_gates = 4;
69 | |
70 | void simple_net(engine::kind engine_kind) { |
71 | using tag = memory::format_tag; |
72 | using dt = memory::data_type; |
73 | |
74 | auto eng = engine(engine_kind, 0); |
75 | stream s(eng); |
76 | |
77 | bool is_training = true; |
78 | auto fwd_inf_train = is_training ? prop_kind::forward_training |
79 | : prop_kind::forward_inference; |
80 | |
81 | std::vector<primitive> fwd_net; |
82 | std::vector<primitive> bwd_net; |
83 | |
84 | // Input tensor holds two batches with different sequence lengths. |
85 | // Shorter sequences are padded |
86 | memory::dims net_src_dims = { |
87 | T0, // time, maximum sequence length |
88 | N0 + N1, // n, total batch size |
89 | common_feature_size // c, common number of channels |
90 | }; |
91 | |
92 | // Two RNN primitives for different sequence lengths, |
93 | // one unidirectional layer, LSTM-based |
94 | memory::dims leftmost_src_layer_dims = { |
95 | leftmost_seq_length, // time |
96 | leftmost_batch, // n |
97 | common_feature_size // c |
98 | }; |
99 | memory::dims rightmost_src_layer_dims = { |
100 | rightmost_seq_length, // time |
101 | rightmost_batch, // n |
102 | common_feature_size // c |
103 | }; |
104 | memory::dims common_weights_layer_dims = { |
105 | common_n_layers, // layers |
106 | 1, // directions |
107 | common_feature_size, // input feature size |
108 | lstm_n_gates, // gates number |
109 | common_feature_size // output feature size |
110 | }; |
111 | memory::dims common_weights_iter_dims = { |
112 | common_n_layers, // layers |
113 | 1, // directions |
114 | common_feature_size, // input feature size |
115 | lstm_n_gates, // gates number |
116 | common_feature_size // output feature size |
117 | }; |
118 | memory::dims common_bias_dims = { |
119 | common_n_layers, // layers |
120 | 1, // directions |
121 | lstm_n_gates, // gates number |
122 | common_feature_size // output feature size |
123 | }; |
124 | memory::dims leftmost_dst_layer_dims = { |
125 | leftmost_seq_length, // time |
126 | leftmost_batch, // n |
127 | common_feature_size // c |
128 | }; |
129 | memory::dims rightmost_dst_layer_dims = { |
130 | rightmost_seq_length, // time |
131 | rightmost_batch, // n |
132 | common_feature_size // c |
133 | }; |
134 | |
135 | // leftmost primitive passes its states to the next RNN iteration |
136 | // so it needs dst_iter parameter. |
137 | // |
138 | // rightmost primitive will consume these as src_iter and will access the |
139 | // memory via a sub-memory because it will have different batch dimension. |
140 | // We have arranged our primitives so that |
141 | // leftmost_batch >= rightmost_batch, and so the rightmost data will fit |
142 | // into the memory allocated for the leftmost. |
143 | memory::dims leftmost_dst_iter_dims = { |
144 | common_n_layers, // layers |
145 | 1, // directions |
146 | leftmost_batch, // n |
147 | common_feature_size // c |
148 | }; |
149 | memory::dims leftmost_dst_iter_c_dims = { |
150 | common_n_layers, // layers |
151 | 1, // directions |
152 | leftmost_batch, // n |
153 | common_feature_size // c |
154 | }; |
155 | memory::dims rightmost_src_iter_dims = { |
156 | common_n_layers, // layers |
157 | 1, // directions |
158 | rightmost_batch, // n |
159 | common_feature_size // c |
160 | }; |
161 | memory::dims rightmost_src_iter_c_dims = { |
162 | common_n_layers, // layers |
163 | 1, // directions |
164 | rightmost_batch, // n |
165 | common_feature_size // c |
166 | }; |
167 | |
168 | // multiplication of tensor dimensions |
169 | auto tz_volume = [=](memory::dims tz_dims) { |
170 | return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1, |
171 | std::multiplies<memory::dim>()); |
172 | }; |
173 | |
174 | // Create auxillary f32 memory descriptor |
175 | // based on user- supplied dimensions and layout. |
176 | auto formatted_md |
177 | = [=](const memory::dims &dimensions, memory::format_tag layout) { |
178 | return memory::desc {{dimensions}, dt::f32, layout}; |
179 | }; |
180 | // Create auxillary generic f32 memory descriptor |
181 | // based on supplied dimensions, with format_tag::any. |
182 | auto generic_md = [=](const memory::dims &dimensions) { |
183 | return formatted_md(dimensions, tag::any); |
184 | }; |
185 | |
186 | // |
187 | // I/O memory, coming from user |
188 | // |
189 | |
190 | // Net input |
191 | std::vector<float> net_src(tz_volume(net_src_dims), 1.0f); |
192 | // NOTE: in this example we study input sequences with variable batch |
193 | // dimension, which get processed by two separate RNN primitives, thus |
194 | // the destination memory for the two will have different shapes: batch |
195 | // is the second dimension currently: see format_tag::tnc. |
196 | // We are not copying the output to some common user provided memory as we |
197 | // suggest that the user should rather keep the two output memories separate |
198 | // throughout the whole topology and only reorder to something else as |
199 | // needed. |
200 | // So there's no common net_dst, but there are two destinations instead: |
201 | // leftmost_dst_layer_memory |
202 | // rightmost_dst_layer_memory |
203 | |
204 | // Memory for the user allocated memory |
205 | // Suppose user data is in tnc format. |
206 | auto net_src_memory |
207 | = dnnl::memory({{net_src_dims}, dt::f32, tag::tnc}, eng); |
208 | write_to_dnnl_memory(net_src.data(), net_src_memory); |
209 | // src_layer memory of the leftmost and rightmost RNN primitives |
210 | // are accessed through the respective sub-memories in larger memory. |
211 | // View primitives compute the strides to accommodate for padding. |
212 | auto user_leftmost_src_layer_md = net_src_memory.get_desc().submemory_desc( |
213 | leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets |
214 | auto user_rightmost_src_layer_md |
215 | = net_src_memory.get_desc().submemory_desc(rightmost_src_layer_dims, |
216 | {leftmost_seq_length, 0, 0}); // t, n, c offsets |
217 | auto leftmost_src_layer_memory = net_src_memory; |
218 | auto rightmost_src_layer_memory = net_src_memory; |
219 | |
220 | // Other user provided memory arrays, descriptors and primitives with the |
221 | // data layouts chosen by user. We'll have to reorder if RNN |
222 | // primitive prefers it in a different format. |
223 | std::vector<float> user_common_weights_layer( |
224 | tz_volume(common_weights_layer_dims), 1.0f); |
225 | auto user_common_weights_layer_memory = dnnl::memory( |
226 | {common_weights_layer_dims, dt::f32, tag::ldigo}, eng); |
227 | write_to_dnnl_memory( |
228 | user_common_weights_layer.data(), user_common_weights_layer_memory); |
229 | |
230 | std::vector<float> user_common_weights_iter( |
231 | tz_volume(common_weights_iter_dims), 1.0f); |
232 | auto user_common_weights_iter_memory = dnnl::memory( |
233 | {{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng); |
234 | write_to_dnnl_memory( |
235 | user_common_weights_layer.data(), user_common_weights_iter_memory); |
236 | |
237 | std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f); |
238 | auto user_common_bias_memory |
239 | = dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng); |
240 | write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory); |
241 | |
242 | std::vector<float> user_leftmost_dst_layer( |
243 | tz_volume(leftmost_dst_layer_dims), 1.0f); |
244 | auto user_leftmost_dst_layer_memory |
245 | = dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng); |
246 | write_to_dnnl_memory( |
247 | user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory); |
248 | |
249 | std::vector<float> user_rightmost_dst_layer( |
250 | tz_volume(rightmost_dst_layer_dims), 1.0f); |
251 | auto user_rightmost_dst_layer_memory = dnnl::memory( |
252 | {{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng); |
253 | write_to_dnnl_memory( |
254 | user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory); |
255 | |
256 | // Describe layer, forward pass, leftmost primitive. |
257 | // There are no primitives to the left from here, |
258 | // so src_iter_desc needs to be zero memory desc |
259 | auto leftmost_prim_desc = lstm_forward::primitive_desc(eng, // engine |
260 | fwd_inf_train, // aprop_kind |
261 | rnn_direction::unidirectional_left2right, // direction |
262 | user_leftmost_src_layer_md, // src_layer_desc |
263 | memory::desc(), // src_iter_desc |
264 | memory::desc(), // src_iter_c_desc |
265 | generic_md(common_weights_layer_dims), // weights_layer_desc |
266 | generic_md(common_weights_iter_dims), // weights_iter_desc |
267 | generic_md(common_bias_dims), // bias_desc |
268 | formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc |
269 | generic_md(leftmost_dst_iter_dims), // dst_iter_desc |
270 | generic_md(leftmost_dst_iter_c_dims) // dst_iter_c_desc |
271 | ); |
272 | |
273 | // |
274 | // Need to connect leftmost and rightmost via "iter" parameters. |
275 | // We allocate memory here based on the shapes provided by RNN primitive. |
276 | // |
277 | auto leftmost_dst_iter_memory |
278 | = dnnl::memory(leftmost_prim_desc.dst_iter_desc(), eng); |
279 | auto leftmost_dst_iter_c_memory |
280 | = dnnl::memory(leftmost_prim_desc.dst_iter_c_desc(), eng); |
281 | |
282 | // rightmost src_iter will be a sub-memory of dst_iter of leftmost |
283 | auto rightmost_src_iter_md |
284 | = leftmost_dst_iter_memory.get_desc().submemory_desc( |
285 | rightmost_src_iter_dims, |
286 | {0, 0, 0, 0}); // l, d, n, c offsets |
287 | auto rightmost_src_iter_memory = leftmost_dst_iter_memory; |
288 | |
289 | auto rightmost_src_iter_c_md |
290 | = leftmost_dst_iter_c_memory.get_desc().submemory_desc( |
291 | rightmost_src_iter_c_dims, |
292 | {0, 0, 0, 0}); // l, d, n, c offsets |
293 | auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory; |
294 | |
295 | // Now rightmost primitive |
296 | // There are no primitives to the right from here, |
297 | // so dst_iter_desc is explicit zero memory desc |
298 | auto rightmost_prim_desc = lstm_forward::primitive_desc(eng, // engine |
299 | fwd_inf_train, // aprop_kind |
300 | rnn_direction::unidirectional_left2right, // direction |
301 | user_rightmost_src_layer_md, // src_layer_desc |
302 | rightmost_src_iter_md, // src_iter_desc |
303 | rightmost_src_iter_c_md, // src_iter_c_desc |
304 | generic_md(common_weights_layer_dims), // weights_layer_desc |
305 | generic_md(common_weights_iter_dims), // weights_iter_desc |
306 | generic_md(common_bias_dims), // bias_desc |
307 | formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc |
308 | memory::desc(), // dst_iter_desc |
309 | memory::desc() // dst_iter_c_desc |
310 | ); |
311 | |
312 | // |
313 | // Weights and biases, layer memory |
314 | // Same layout should work across the layer, no reorders |
315 | // needed between leftmost and rigthmost, only reordering |
316 | // user memory to the RNN-friendly shapes. |
317 | // |
318 | |
319 | auto common_weights_layer_memory = user_common_weights_layer_memory; |
320 | if (leftmost_prim_desc.weights_layer_desc() |
321 | != common_weights_layer_memory.get_desc()) { |
322 | common_weights_layer_memory |
323 | = dnnl::memory(leftmost_prim_desc.weights_layer_desc(), eng); |
324 | reorder(user_common_weights_layer_memory, common_weights_layer_memory) |
325 | .execute(s, user_common_weights_layer_memory, |
326 | common_weights_layer_memory); |
327 | } |
328 | |
329 | auto common_weights_iter_memory = user_common_weights_iter_memory; |
330 | if (leftmost_prim_desc.weights_iter_desc() |
331 | != common_weights_iter_memory.get_desc()) { |
332 | common_weights_iter_memory |
333 | = dnnl::memory(leftmost_prim_desc.weights_iter_desc(), eng); |
334 | reorder(user_common_weights_iter_memory, common_weights_iter_memory) |
335 | .execute(s, user_common_weights_iter_memory, |
336 | common_weights_iter_memory); |
337 | } |
338 | |
339 | auto common_bias_memory = user_common_bias_memory; |
340 | if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) { |
341 | common_bias_memory = dnnl::memory(leftmost_prim_desc.bias_desc(), eng); |
342 | reorder(user_common_bias_memory, common_bias_memory) |
343 | .execute(s, user_common_bias_memory, common_bias_memory); |
344 | } |
345 | |
346 | // |
347 | // Destination layer memory |
348 | // |
349 | |
350 | auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory; |
351 | if (leftmost_prim_desc.dst_layer_desc() |
352 | != leftmost_dst_layer_memory.get_desc()) { |
353 | leftmost_dst_layer_memory |
354 | = dnnl::memory(leftmost_prim_desc.dst_layer_desc(), eng); |
355 | reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory) |
356 | .execute(s, user_leftmost_dst_layer_memory, |
357 | leftmost_dst_layer_memory); |
358 | } |
359 | |
360 | auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory; |
361 | if (rightmost_prim_desc.dst_layer_desc() |
362 | != rightmost_dst_layer_memory.get_desc()) { |
363 | rightmost_dst_layer_memory |
364 | = dnnl::memory(rightmost_prim_desc.dst_layer_desc(), eng); |
365 | reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory) |
366 | .execute(s, user_rightmost_dst_layer_memory, |
367 | rightmost_dst_layer_memory); |
368 | } |
369 | |
370 | // We also create workspace memory based on the information from |
371 | // the workspace_primitive_desc(). This is needed for internal |
372 | // communication between forward and backward primitives during |
373 | // training. |
374 | auto create_ws = [=](dnnl::lstm_forward::primitive_desc &pd) { |
375 | return dnnl::memory(pd.workspace_desc(), eng); |
376 | }; |
377 | auto leftmost_workspace_memory = create_ws(leftmost_prim_desc); |
378 | auto rightmost_workspace_memory = create_ws(rightmost_prim_desc); |
379 | |
380 | // Construct the RNN primitive objects |
381 | lstm_forward leftmost_layer(leftmost_prim_desc); |
382 | leftmost_layer.execute(s, |
383 | {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_memory}, |
384 | {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory}, |
385 | {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory}, |
386 | {DNNL_ARG_BIAS, common_bias_memory}, |
387 | {DNNL_ARG_DST_LAYER, leftmost_dst_layer_memory}, |
388 | {DNNL_ARG_DST_ITER, leftmost_dst_iter_memory}, |
389 | {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_memory}, |
390 | {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}}); |
391 | |
392 | lstm_forward rightmost_layer(rightmost_prim_desc); |
393 | rightmost_layer.execute(s, |
394 | {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_memory}, |
395 | {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory}, |
396 | {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory}, |
397 | {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory}, |
398 | {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory}, |
399 | {DNNL_ARG_BIAS, common_bias_memory}, |
400 | {DNNL_ARG_DST_LAYER, rightmost_dst_layer_memory}, |
401 | {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}}); |
402 | |
403 | // No backward pass for inference |
404 | if (!is_training) return; |
405 | |
406 | // |
407 | // Backward primitives will reuse memory from forward |
408 | // and allocate/describe specifics here. Only relevant for training. |
409 | // |
410 | |
411 | // User-provided memory for backward by data output |
412 | std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f); |
413 | auto net_diff_src_memory |
414 | = dnnl::memory(formatted_md(net_src_dims, tag::tnc), eng); |
415 | write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory); |
416 | |
417 | // diff_src follows the same layout we have for net_src |
418 | auto user_leftmost_diff_src_layer_md |
419 | = net_diff_src_memory.get_desc().submemory_desc( |
420 | leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets |
421 | auto user_rightmost_diff_src_layer_md |
422 | = net_diff_src_memory.get_desc().submemory_desc( |
423 | rightmost_src_layer_dims, |
424 | {leftmost_seq_length, 0, 0}); // t, n, c offsets |
425 | auto leftmost_diff_src_layer_memory = net_diff_src_memory; |
426 | auto rightmost_diff_src_layer_memory = net_diff_src_memory; |
427 | |
428 | // User-provided memory for backpropagation by weights |
429 | std::vector<float> user_common_diff_weights_layer( |
430 | tz_volume(common_weights_layer_dims), 1.0f); |
431 | auto user_common_diff_weights_layer_memory = dnnl::memory( |
432 | formatted_md(common_weights_layer_dims, tag::ldigo), eng); |
433 | write_to_dnnl_memory(user_common_diff_weights_layer.data(), |
434 | user_common_diff_weights_layer_memory); |
435 | |
436 | std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f); |
437 | auto user_common_diff_bias_memory |
438 | = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng); |
439 | write_to_dnnl_memory( |
440 | user_common_diff_bias.data(), user_common_diff_bias_memory); |
441 | |
442 | // User-provided input to the backward primitive. |
443 | // To be updated by the user after forward pass using some cost function. |
444 | memory::dims net_diff_dst_dims = { |
445 | T0, // time |
446 | N0 + N1, // n |
447 | common_feature_size // c |
448 | }; |
449 | // Suppose user data is in tnc format. |
450 | std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f); |
451 | auto net_diff_dst_memory |
452 | = dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng); |
453 | write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory); |
454 | // diff_dst_layer memory of the leftmost and rightmost RNN primitives |
455 | // are accessed through the respective sub-memory in larger memory. |
456 | // View primitives compute the strides to accommodate for padding. |
457 | auto user_leftmost_diff_dst_layer_md |
458 | = net_diff_dst_memory.get_desc().submemory_desc( |
459 | leftmost_dst_layer_dims, {0, 0, 0}); // t, n, c offsets |
460 | auto user_rightmost_diff_dst_layer_md |
461 | = net_diff_dst_memory.get_desc().submemory_desc( |
462 | rightmost_dst_layer_dims, |
463 | {leftmost_seq_length, 0, 0}); // t, n, c offsets |
464 | auto leftmost_diff_dst_layer_memory = net_diff_dst_memory; |
465 | auto rightmost_diff_dst_layer_memory = net_diff_dst_memory; |
466 | |
467 | // Backward leftmost primitive descriptor |
468 | auto leftmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine |
469 | prop_kind::backward, // aprop_kind |
470 | rnn_direction::unidirectional_left2right, // direction |
471 | user_leftmost_src_layer_md, // src_layer_desc |
472 | memory::desc(), // src_iter_desc |
473 | memory::desc(), // src_iter_c_desc |
474 | generic_md(common_weights_layer_dims), // weights_layer_desc |
475 | generic_md(common_weights_iter_dims), // weights_iter_desc |
476 | generic_md(common_bias_dims), // bias_desc |
477 | formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc |
478 | generic_md(leftmost_dst_iter_dims), // dst_iter_desc |
479 | generic_md(leftmost_dst_iter_c_dims), // dst_iter_c_desc |
480 | user_leftmost_diff_src_layer_md, // diff_src_layer_desc |
481 | memory::desc(), // diff_src_iter_desc |
482 | memory::desc(), // diff_src_iter_c_desc |
483 | generic_md(common_weights_layer_dims), // diff_weights_layer_desc |
484 | generic_md(common_weights_iter_dims), // diff_weights_iter_desc |
485 | generic_md(common_bias_dims), // diff_bias_desc |
486 | user_leftmost_diff_dst_layer_md, // diff_dst_layer_desc |
487 | generic_md(leftmost_dst_iter_dims), // diff_dst_iter_desc |
488 | generic_md(leftmost_dst_iter_c_dims), // diff_dst_iter_c_desc |
489 | leftmost_prim_desc // hint from forward pass |
490 | ); |
491 | |
492 | // As the batch dimensions are different between leftmost and rightmost |
493 | // we need to use a sub-memory. rightmost needs less memory, so it will |
494 | // be a sub-memory of leftmost. |
495 | auto leftmost_diff_dst_iter_memory |
496 | = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_desc(), eng); |
497 | auto leftmost_diff_dst_iter_c_memory |
498 | = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), eng); |
499 | |
500 | auto rightmost_diff_src_iter_md |
501 | = leftmost_diff_dst_iter_memory.get_desc().submemory_desc( |
502 | rightmost_src_iter_dims, |
503 | {0, 0, 0, 0}); // l, d, n, c offsets |
504 | auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory; |
505 | |
506 | auto rightmost_diff_src_iter_c_md |
507 | = leftmost_diff_dst_iter_c_memory.get_desc().submemory_desc( |
508 | rightmost_src_iter_c_dims, |
509 | {0, 0, 0, 0}); // l, d, n, c offsets |
510 | auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory; |
511 | |
512 | // Backward rightmost primitive descriptor |
513 | auto rightmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine |
514 | prop_kind::backward, // aprop_kind |
515 | rnn_direction::unidirectional_left2right, // direction |
516 | user_rightmost_src_layer_md, // src_layer_desc |
517 | generic_md(rightmost_src_iter_dims), // src_iter_desc |
518 | generic_md(rightmost_src_iter_c_dims), // src_iter_c_desc |
519 | generic_md(common_weights_layer_dims), // weights_layer_desc |
520 | generic_md(common_weights_iter_dims), // weights_iter_desc |
521 | generic_md(common_bias_dims), // bias_desc |
522 | formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc |
523 | memory::desc(), // dst_iter_desc |
524 | memory::desc(), // dst_iter_c_desc |
525 | user_rightmost_diff_src_layer_md, // diff_src_layer_desc |
526 | rightmost_diff_src_iter_md, // diff_src_iter_desc |
527 | rightmost_diff_src_iter_c_md, // diff_src_iter_c_desc |
528 | generic_md(common_weights_layer_dims), // diff_weights_layer_desc |
529 | generic_md(common_weights_iter_dims), // diff_weights_iter_desc |
530 | generic_md(common_bias_dims), // diff_bias_desc |
531 | user_rightmost_diff_dst_layer_md, // diff_dst_layer_desc |
532 | memory::desc(), // diff_dst_iter_desc |
533 | memory::desc(), // diff_dst_iter_c_desc |
534 | rightmost_prim_desc // hint from forward pass |
535 | ); |
536 | |
537 | // |
538 | // Memory for backward pass |
539 | // |
540 | |
541 | // src layer uses the same memory as forward |
542 | auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory; |
543 | auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory; |
544 | |
545 | // Memory for weights and biases for backward pass |
546 | // Try to use the same memory between forward and backward, but |
547 | // sometimes reorders are needed. |
548 | auto common_weights_layer_bwd_memory = common_weights_layer_memory; |
549 | if (leftmost_bwd_prim_desc.weights_layer_desc() |
550 | != leftmost_prim_desc.weights_layer_desc()) { |
551 | common_weights_layer_bwd_memory |
552 | = memory(leftmost_bwd_prim_desc.weights_layer_desc(), eng); |
553 | reorder(common_weights_layer_memory, common_weights_layer_bwd_memory) |
554 | .execute(s, common_weights_layer_memory, |
555 | common_weights_layer_bwd_memory); |
556 | } |
557 | |
558 | auto common_weights_iter_bwd_memory = common_weights_iter_memory; |
559 | if (leftmost_bwd_prim_desc.weights_iter_desc() |
560 | != leftmost_prim_desc.weights_iter_desc()) { |
561 | common_weights_iter_bwd_memory |
562 | = memory(leftmost_bwd_prim_desc.weights_iter_desc(), eng); |
563 | reorder(common_weights_iter_memory, common_weights_iter_bwd_memory) |
564 | .execute(s, common_weights_iter_memory, |
565 | common_weights_iter_bwd_memory); |
566 | } |
567 | |
568 | auto common_bias_bwd_memory = common_bias_memory; |
569 | if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) { |
570 | common_bias_bwd_memory |
571 | = dnnl::memory(leftmost_bwd_prim_desc.bias_desc(), eng); |
572 | reorder(common_bias_memory, common_bias_bwd_memory) |
573 | .execute(s, common_bias_memory, common_bias_bwd_memory); |
574 | } |
575 | |
576 | // diff_weights and biases |
577 | auto common_diff_weights_layer_memory |
578 | = user_common_diff_weights_layer_memory; |
579 | auto reorder_common_diff_weights_layer = false; |
580 | if (leftmost_bwd_prim_desc.diff_weights_layer_desc() |
581 | != common_diff_weights_layer_memory.get_desc()) { |
582 | common_diff_weights_layer_memory = dnnl::memory( |
583 | leftmost_bwd_prim_desc.diff_weights_layer_desc(), eng); |
584 | reorder_common_diff_weights_layer = true; |
585 | } |
586 | |
587 | auto common_diff_bias_memory = user_common_diff_bias_memory; |
588 | auto reorder_common_diff_bias = false; |
589 | if (leftmost_bwd_prim_desc.diff_bias_desc() |
590 | != common_diff_bias_memory.get_desc()) { |
591 | common_diff_bias_memory |
592 | = dnnl::memory(leftmost_bwd_prim_desc.diff_bias_desc(), eng); |
593 | reorder_common_diff_bias = true; |
594 | } |
595 | |
596 | // dst_layer memory for backward pass |
597 | auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory; |
598 | if (leftmost_bwd_prim_desc.dst_layer_desc() |
599 | != leftmost_dst_layer_bwd_memory.get_desc()) { |
600 | leftmost_dst_layer_bwd_memory |
601 | = dnnl::memory(leftmost_bwd_prim_desc.dst_layer_desc(), eng); |
602 | reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory) |
603 | .execute(s, leftmost_dst_layer_memory, |
604 | leftmost_dst_layer_bwd_memory); |
605 | } |
606 | |
607 | auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory; |
608 | if (rightmost_bwd_prim_desc.dst_layer_desc() |
609 | != rightmost_dst_layer_bwd_memory.get_desc()) { |
610 | rightmost_dst_layer_bwd_memory |
611 | = dnnl::memory(rightmost_bwd_prim_desc.dst_layer_desc(), eng); |
612 | reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory) |
613 | .execute(s, rightmost_dst_layer_memory, |
614 | rightmost_dst_layer_bwd_memory); |
615 | } |
616 | |
617 | // Similar to forward, the backward primitives are connected |
618 | // via "iter" parameters. |
619 | auto common_diff_weights_iter_memory = dnnl::memory( |
620 | leftmost_bwd_prim_desc.diff_weights_iter_desc(), eng); |
621 | |
622 | auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory; |
623 | if (leftmost_bwd_prim_desc.dst_iter_desc() |
624 | != leftmost_dst_iter_bwd_memory.get_desc()) { |
625 | leftmost_dst_iter_bwd_memory |
626 | = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_desc(), eng); |
627 | reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory) |
628 | .execute(s, leftmost_dst_iter_memory, |
629 | leftmost_dst_iter_bwd_memory); |
630 | } |
631 | |
632 | auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory; |
633 | if (leftmost_bwd_prim_desc.dst_iter_c_desc() |
634 | != leftmost_dst_iter_c_bwd_memory.get_desc()) { |
635 | leftmost_dst_iter_c_bwd_memory |
636 | = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_c_desc(), eng); |
637 | reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory) |
638 | .execute(s, leftmost_dst_iter_c_memory, |
639 | leftmost_dst_iter_c_bwd_memory); |
640 | } |
641 | |
642 | // Construct the RNN primitive objects for backward |
643 | lstm_backward rightmost_layer_bwd(rightmost_bwd_prim_desc); |
644 | rightmost_layer_bwd.execute(s, |
645 | {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_bwd_memory}, |
646 | {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory}, |
647 | {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory}, |
648 | {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory}, |
649 | {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory}, |
650 | {DNNL_ARG_BIAS, common_bias_bwd_memory}, |
651 | {DNNL_ARG_DST_LAYER, rightmost_dst_layer_bwd_memory}, |
652 | {DNNL_ARG_DIFF_SRC_LAYER, rightmost_diff_src_layer_memory}, |
653 | {DNNL_ARG_DIFF_SRC_ITER, rightmost_diff_src_iter_memory}, |
654 | {DNNL_ARG_DIFF_SRC_ITER_C, |
655 | rightmost_diff_src_iter_c_memory}, |
656 | {DNNL_ARG_DIFF_WEIGHTS_LAYER, |
657 | common_diff_weights_layer_memory}, |
658 | {DNNL_ARG_DIFF_WEIGHTS_ITER, |
659 | common_diff_weights_iter_memory}, |
660 | {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory}, |
661 | {DNNL_ARG_DIFF_DST_LAYER, rightmost_diff_dst_layer_memory}, |
662 | {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}}); |
663 | |
664 | lstm_backward leftmost_layer_bwd(leftmost_bwd_prim_desc); |
665 | leftmost_layer_bwd.execute(s, |
666 | {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_bwd_memory}, |
667 | {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory}, |
668 | {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory}, |
669 | {DNNL_ARG_BIAS, common_bias_bwd_memory}, |
670 | {DNNL_ARG_DST_LAYER, leftmost_dst_layer_bwd_memory}, |
671 | {DNNL_ARG_DST_ITER, leftmost_dst_iter_bwd_memory}, |
672 | {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_bwd_memory}, |
673 | {DNNL_ARG_DIFF_SRC_LAYER, leftmost_diff_src_layer_memory}, |
674 | {DNNL_ARG_DIFF_WEIGHTS_LAYER, |
675 | common_diff_weights_layer_memory}, |
676 | {DNNL_ARG_DIFF_WEIGHTS_ITER, |
677 | common_diff_weights_iter_memory}, |
678 | {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory}, |
679 | {DNNL_ARG_DIFF_DST_LAYER, leftmost_diff_dst_layer_memory}, |
680 | {DNNL_ARG_DIFF_DST_ITER, leftmost_diff_dst_iter_memory}, |
681 | {DNNL_ARG_DIFF_DST_ITER_C, leftmost_diff_dst_iter_c_memory}, |
682 | {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}}); |
683 | if (reorder_common_diff_weights_layer) { |
684 | reorder(common_diff_weights_layer_memory, |
685 | user_common_diff_weights_layer_memory) |
686 | .execute(s, common_diff_weights_layer_memory, |
687 | user_common_diff_weights_layer_memory); |
688 | } |
689 | |
690 | if (reorder_common_diff_bias) { |
691 | reorder(common_diff_bias_memory, user_common_diff_bias_memory) |
692 | .execute(s, common_diff_bias_memory, |
693 | user_common_diff_bias_memory); |
694 | } |
695 | |
696 | // |
697 | // User updates weights and bias using diffs |
698 | // |
699 | |
700 | s.wait(); |
701 | } |
702 | |
703 | int main(int argc, char **argv) { |
704 | return handle_example_errors(simple_net, parse_engine_kind(argc, argv)); |
705 | } |
706 | |