/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example cpu_rnn_inference_f32.cpp
/// @copybrief cpu_rnn_inference_f32_cpp
/// > Annotated version: @ref cpu_rnn_inference_f32_cpp

/// @page cpu_rnn_inference_f32_cpp RNN f32 inference example
/// This C++ API example demonstrates how to build GNMT model inference.
///
/// > Example code: @ref cpu_rnn_inference_f32.cpp
///
/// For the encoder we use:
///  - one primitive for the bidirectional layer of the encoder
///  - one primitive for all remaining unidirectional layers in the encoder
/// For the decoder we use:
///  - one primitive for the first iteration
///  - one primitive for all subsequent iterations in the decoder. Note that
///    in this example, this primitive computes the states in place.
///  - the attention mechanism is implemented separately, as there is no
///    support for context vectors in oneDNN yet
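// Shorthand used by the memory format tags throughout this example
// (see dnnl::memory::format_tag in the oneDNN documentation):
//   tnc   : {time steps, batch (n), channels} for RNN layer data
//   ldigo : {layers, directions, input channels, gates, output channels}
//           for layer and iteration weights
//   ldgo  : {layers, directions, gates, channels} for biases
//   ldnc  : {layers, directions, batch (n), channels} for RNN states
// An LSTM cell has lstm_n_gates = 4 gates.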

#include <assert.h>

#include <cstring>
#include <iostream>
#include <math.h>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

using dim_t = dnnl::memory::dim;

const dim_t batch = 32;
const dim_t src_seq_length_max = 10;
const dim_t tgt_seq_length_max = 10;

const dim_t feature_size = 256;

const dim_t enc_bidir_n_layers = 1;
const dim_t enc_unidir_n_layers = 3;
const dim_t dec_n_layers = 4;

const int lstm_n_gates = 4;
std::vector<float> weighted_src_layer(batch * feature_size, 1.0f);
std::vector<float> alignment_model(
        src_seq_length_max * batch * feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);
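// These buffers are global scratch space for the attention computation.
// They are allocated once and reused across decoder time steps, so
// compute_attention() does no per-step allocation.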

void compute_weighted_annotations(float *weighted_annotations,
        dim_t src_seq_length_max, dim_t batch, dim_t feature_size,
        float *weights_annot, float *annotations) {
    // annotations (aka enc_dst_layer) is (t, n, c)
    // weights_annot is (c, c)

    // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
    dim_t num_weighted_annotations = src_seq_length_max * batch;
    dnnl_sgemm('N', 'N', num_weighted_annotations, feature_size, feature_size,
            1.f, annotations, feature_size, weights_annot, feature_size, 0.f,
            weighted_annotations, feature_size);
}
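// For reference: dnnl_sgemm computes C = alpha * op(A) * op(B) + beta * C
// over row-major matrices. Viewing annotations as a (t * n, c) row-major
// matrix A and weights_annot as a (c, c) matrix B, the call above is roughly
//   (t * n, c) = 1.0f * (t * n, c) x (c, c) + 0.0f * (t * n, c)
// i.e. every (n, c) time step is multiplied by the same weights.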

void compute_attention(float *context_vectors, dim_t src_seq_length_max,
        dim_t batch, dim_t feature_size, float *weights_src_layer,
        float *dec_src_layer, float *annotations, float *weighted_annotations,
        float *weights_alignments) {
    // dst_iter : (n, c) matrix
    // src_layer: (n, c) matrix
    // weighted_annotations (t, n, c)

    // weights_yi (weights_src_layer) is (c, c)
    // weights_ai (weights_alignments) is (c, 1)
    // tmp[i] is (n, c)
    // a[i] is (n, 1)
    // p is (n, 1)

    // first we precompute the weighted_dec_src_layer
    dnnl_sgemm('N', 'N', batch, feature_size, feature_size, 1.f, dec_src_layer,
            feature_size, weights_src_layer, feature_size, 0.f,
            weighted_src_layer.data(), feature_size);

    // then we compute the alignment model
    float *alignment_model_ptr = alignment_model.data();

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < src_seq_length_max; i++) {
        for (dim_t j = 0; j < batch * feature_size; j++)
            alignment_model_ptr[i * batch * feature_size + j] = tanhf(
                    weighted_src_layer[j]
                    + weighted_annotations[i * batch * feature_size + j]);
    }

    // gemv with the alignment weights. the resulting alignments are in alignments
    dim_t num_weighted_annotations = src_seq_length_max * batch;
    dnnl_sgemm('N', 'N', num_weighted_annotations, 1, feature_size, 1.f,
            alignment_model_ptr, feature_size, weights_alignments, 1, 0.f,
            alignments.data(), 1);

    // softmax on alignments. the resulting context weights are in alignments
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1)
    for (dim_t i = 0; i < batch; i++)
        exp_sums[i] = 0.0f;

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1)
    for (dim_t j = 0; j < batch; j++) {
        for (dim_t i = 0; i < src_seq_length_max; i++) {
            alignments[i * batch + j] = expf(alignments[i * batch + j]);
            exp_sums[j] += alignments[i * batch + j];
        }
    }

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < src_seq_length_max; i++)
        for (dim_t j = 0; j < batch; j++)
            alignments[i * batch + j] /= exp_sums[j];

    // then we compute the context vectors
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < batch; i++)
        for (dim_t j = 0; j < feature_size; j++)
            context_vectors[i * (feature_size + feature_size) + feature_size
                    + j]
                    = 0.0f;

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < batch; i++)
        for (dim_t j = 0; j < feature_size; j++)
            for (dim_t k = 0; k < src_seq_length_max; k++)
                context_vectors[i * (feature_size + feature_size) + feature_size
                        + j]
                        += alignments[k * batch + i]
                        * annotations[j + feature_size * (i + batch * k)];
}
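// In summary, compute_attention implements additive (Bahdanau-style)
// attention in five steps:
//   1. weighted_src_layer = dec_src_layer x weights_src_layer   (GEMM)
//   2. alignment_model    = tanh(weighted_src_layer + weighted_annotations)
//   3. alignments         = alignment_model x weights_alignments (GEMV)
//   4. alignments         = softmax(alignments) over the time dimension
//   5. context_vectors    = sum over t of alignments[t] * annotations[t]
// Note that the softmax here does not subtract the per-column maximum first;
// that is fine for this synthetic data but can overflow for large inputs.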

void copy_context(
        float *src_iter, dim_t n_layers, dim_t batch, dim_t feature_size) {
    // we copy the context from the first layer to all other layers
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(3)
    for (dim_t k = 1; k < n_layers; k++)
        for (dim_t j = 0; j < batch; j++)
            for (dim_t i = 0; i < feature_size; i++)
                src_iter[(k * batch + j) * (feature_size + feature_size)
                        + feature_size + i]
                        = src_iter[j * (feature_size + feature_size)
                                + feature_size + i];
}
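// src_iter is laid out as (layers, directions, batch, 2 * feature_size): the
// first feature_size channels hold the hidden state and the second
// feature_size channels hold the attention context. compute_attention only
// fills the context half for layer 0, so this copy broadcasts it to the
// remaining layers.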

void simple_net() {
    ///
    /// Initialize a CPU engine and stream. The last parameter in the call represents
    /// the index of the engine.
    /// @snippet cpu_rnn_inference_f32.cpp Initialize engine and stream
    ///
    //[Initialize engine and stream]
    auto cpu_engine = engine(engine::kind::cpu, 0);
    stream s(cpu_engine);
    //[Initialize engine and stream]
    ///
    /// Declare encoder net and decoder net
    /// @snippet cpu_rnn_inference_f32.cpp declare net
    ///
    //[declare net]
    std::vector<primitive> encoder_net, decoder_net;
    std::vector<std::unordered_map<int, memory>> encoder_net_args,
            decoder_net_args;

    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);
    //[declare net]
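    // Each primitive appended to encoder_net/decoder_net is paired with an
    // argument map (DNNL_ARG_* -> memory) at the same index in
    // encoder_net_args/decoder_net_args; both nets are executed in order in
    // the execute() lambda below.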
    ///
    /// **Encoder**
    ///
    ///
    /// Initialize Encoder Memory
    /// @snippet cpu_rnn_inference_f32.cpp Initialize encoder memory
    ///
    //[Initialize encoder memory]
    memory::dims enc_bidir_src_layer_tz
            = {src_seq_length_max, batch, feature_size};
    memory::dims enc_bidir_weights_layer_tz
            = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
    memory::dims enc_bidir_weights_iter_tz
            = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
    memory::dims enc_bidir_bias_tz
            = {enc_bidir_n_layers, 2, lstm_n_gates, feature_size};
    memory::dims enc_bidir_dst_layer_tz
            = {src_seq_length_max, batch, 2 * feature_size};
    //[Initialize encoder memory]
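    // The weights dims follow the ldigo convention noted at the top of the
    // file: {layers, directions, input channels, gates, output channels}.
    // The bidirectional layer has 2 directions, and because the direction
    // outputs are concatenated, its dst_layer channel count is
    // 2 * feature_size.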

    ///
    ///
    /// Encoder: 1 bidirectional layer and 3 unidirectional layers
    ///

    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);

    ///
    /// Create the memory for user data
    /// @snippet cpu_rnn_inference_f32.cpp data memory creation
    ///
    //[data memory creation]
    auto user_enc_bidir_src_layer_md = dnnl::memory::desc(
            {enc_bidir_src_layer_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::tnc);

    auto user_enc_bidir_wei_layer_md = dnnl::memory::desc(
            {enc_bidir_weights_layer_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);

    auto user_enc_bidir_wei_iter_md = dnnl::memory::desc(
            {enc_bidir_weights_iter_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);

    auto user_enc_bidir_bias_md = dnnl::memory::desc({enc_bidir_bias_tz},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);

    auto user_enc_bidir_src_layer_memory = dnnl::memory(
            user_enc_bidir_src_layer_md, cpu_engine, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = dnnl::memory(user_enc_bidir_wei_layer_md, cpu_engine,
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = dnnl::memory(user_enc_bidir_wei_iter_md, cpu_engine,
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = dnnl::memory(
            user_enc_bidir_bias_md, cpu_engine, user_enc_bidir_bias.data());

    //[data memory creation]
    ///
    /// Create memory descriptors for RNN data without a specified layout
    /// @snippet cpu_rnn_inference_f32.cpp memory desc for RNN data
    ///
    //[memory desc for RNN data]
    auto enc_bidir_wei_layer_md = memory::desc({enc_bidir_weights_layer_tz},
            memory::data_type::f32, memory::format_tag::any);

    auto enc_bidir_wei_iter_md = memory::desc({enc_bidir_weights_iter_tz},
            memory::data_type::f32, memory::format_tag::any);

    auto enc_bidir_dst_layer_md = memory::desc({enc_bidir_dst_layer_tz},
            memory::data_type::f32, memory::format_tag::any);

    //[memory desc for RNN data]
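    // format_tag::any lets the LSTM primitive choose whatever layout is
    // fastest on this CPU; the reorders below convert the user's ldigo
    // weights into the layout the primitive picked. The destination memory
    // is likewise allocated with the layout the primitive reports.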
    ///
    /// Create bidirectional RNN
    /// @snippet cpu_rnn_inference_f32.cpp create rnn
    ///
    //[create rnn]

    auto enc_bidir_prim_desc = lstm_forward::primitive_desc(cpu_engine,
            prop_kind::forward_inference, rnn_direction::bidirectional_concat,
            user_enc_bidir_src_layer_md, memory::desc(), memory::desc(),
            enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, memory::desc(),
            memory::desc());
    //[create rnn]
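    // The empty memory::desc() arguments stand for src_iter, src_iter_c,
    // dst_iter, and dst_iter_c: omitting them tells the primitive to use
    // zero initial hidden/cell states and to skip producing the final
    // states.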

    ///
    /// Create memory for input data and use reorders to convert user data
    /// to the internal representation
    /// @snippet cpu_rnn_inference_f32.cpp reorder input data
    ///
    //[reorder input data]
    auto enc_bidir_wei_layer_memory
            = memory(enc_bidir_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory);
    reorder(enc_bidir_wei_layer_reorder_pd)
            .execute(s, user_enc_bidir_wei_layer_memory,
                    enc_bidir_wei_layer_memory);
    //[reorder input data]
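    // The weight reorders are executed once, eagerly, on the setup stream
    // rather than being appended to encoder_net: the weights do not change
    // between inference calls, so there is no need to redo the conversion.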

    auto enc_bidir_wei_iter_memory
            = memory(enc_bidir_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory);
    reorder(enc_bidir_wei_iter_reorder_pd)
            .execute(s, user_enc_bidir_wei_iter_memory,
                    enc_bidir_wei_iter_memory);

    auto enc_bidir_dst_layer_memory
            = dnnl::memory(enc_bidir_prim_desc.dst_layer_desc(), cpu_engine);

    ///
    /// Encoder : add the bidirectional rnn primitive with related arguments into encoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push bi rnn to encoder net
    ///
    //[push bi rnn to encoder net]
    encoder_net.push_back(lstm_forward(enc_bidir_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, user_enc_bidir_src_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_bidir_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_bidir_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_bidir_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_bidir_dst_layer_memory}});
    //[push bi rnn to encoder net]

    ///
    /// Encoder: unidirectional layers
    ///
    ///
    /// The first unidirectional layer maps the 2 * feature_size output of the
    /// bidirectional layer down to a feature_size output
    /// @snippet cpu_rnn_inference_f32.cpp first uni layer
    ///
    //[first uni layer]
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);
    //[first uni layer]
    memory::dims user_enc_uni_first_wei_layer_dims
            = {1, 1, 2 * feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_first_wei_iter_dims
            = {1, 1, feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_first_bias_dims
            = {1, 1, lstm_n_gates, feature_size};
    memory::dims enc_uni_first_dst_layer_dims
            = {src_seq_length_max, batch, feature_size};
    auto user_enc_uni_first_wei_layer_md = dnnl::memory::desc(
            {user_enc_uni_first_wei_layer_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_first_wei_iter_md = dnnl::memory::desc(
            {user_enc_uni_first_wei_iter_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_first_bias_md = dnnl::memory::desc(
            {user_enc_uni_first_bias_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldgo);
    auto user_enc_uni_first_wei_layer_memory
            = dnnl::memory(user_enc_uni_first_wei_layer_md, cpu_engine,
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = dnnl::memory(user_enc_uni_first_wei_iter_md, cpu_engine,
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = dnnl::memory(user_enc_uni_first_bias_md, cpu_engine,
                    user_enc_uni_first_bias.data());

    auto enc_uni_first_wei_layer_md
            = memory::desc({user_enc_uni_first_wei_layer_dims},
                    memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_first_wei_iter_md
            = memory::desc({user_enc_uni_first_wei_iter_dims},
                    memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_first_dst_layer_md
            = memory::desc({enc_uni_first_dst_layer_dims},
                    memory::data_type::f32, memory::format_tag::any);

    // TODO: add support for residual connections
    // should it be a set residual in pd or a field to set manually?
    // should be an integer to specify at which layer to start
    ///
    /// Encoder : Create the unidirectional RNN for the first layer
    /// @snippet cpu_rnn_inference_f32.cpp create uni first
    ///
    //[create uni first]
    auto enc_uni_first_prim_desc = lstm_forward::primitive_desc(cpu_engine,
            prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right, enc_bidir_dst_layer_md,
            memory::desc(), memory::desc(), enc_uni_first_wei_layer_md,
            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, memory::desc(), memory::desc());

    //[create uni first]
    auto enc_uni_first_wei_layer_memory
            = memory(enc_uni_first_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_uni_first_wei_layer_reorder_pd
            = reorder::primitive_desc(user_enc_uni_first_wei_layer_memory,
                    enc_uni_first_wei_layer_memory);
    reorder(enc_uni_first_wei_layer_reorder_pd)
            .execute(s, user_enc_uni_first_wei_layer_memory,
                    enc_uni_first_wei_layer_memory);

    auto enc_uni_first_wei_iter_memory
            = memory(enc_uni_first_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory);
    reorder(enc_uni_first_wei_iter_reorder_pd)
            .execute(s, user_enc_uni_first_wei_iter_memory,
                    enc_uni_first_wei_iter_memory);

    auto enc_uni_first_dst_layer_memory = dnnl::memory(
            enc_uni_first_prim_desc.dst_layer_desc(), cpu_engine);

    /// Encoder : add the first unidirectional rnn primitive with related
    /// arguments into encoder_net
    ///
    /// @snippet cpu_rnn_inference_f32.cpp push first uni rnn to encoder net
    ///
    //[push first uni rnn to encoder net]
    // TODO: add a reorder when it becomes available
    encoder_net.push_back(lstm_forward(enc_uni_first_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, enc_bidir_dst_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_uni_first_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_uni_first_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_uni_first_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_uni_first_dst_layer_memory}});
    //[push first uni rnn to encoder net]

    ///
    /// Encoder : Remaining unidirectional layers
    /// @snippet cpu_rnn_inference_f32.cpp remaining uni layers
    ///
    //[remaining uni layers]
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
    //[remaining uni layers]
    memory::dims user_enc_uni_wei_layer_dims = {(enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_wei_iter_dims = {(enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_bias_dims
            = {(enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size};
    memory::dims enc_dst_layer_dims = {src_seq_length_max, batch, feature_size};
    auto user_enc_uni_wei_layer_md = dnnl::memory::desc(
            {user_enc_uni_wei_layer_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_wei_iter_md = dnnl::memory::desc(
            {user_enc_uni_wei_iter_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_bias_md = dnnl::memory::desc({user_enc_uni_bias_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);
    auto user_enc_uni_wei_layer_memory = dnnl::memory(user_enc_uni_wei_layer_md,
            cpu_engine, user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory = dnnl::memory(
            user_enc_uni_wei_iter_md, cpu_engine, user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = dnnl::memory(
            user_enc_uni_bias_md, cpu_engine, user_enc_uni_bias.data());

    auto enc_uni_wei_layer_md = memory::desc({user_enc_uni_wei_layer_dims},
            memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_wei_iter_md = memory::desc({user_enc_uni_wei_iter_dims},
            memory::data_type::f32, memory::format_tag::any);
    auto enc_dst_layer_md = memory::desc({enc_dst_layer_dims},
            memory::data_type::f32, memory::format_tag::any);

    // TODO: add support for residual connections
    // should it be a set residual in pd or a field to set manually?
    // should be an integer to specify at which layer to start
    ///
    /// Encoder : Create the unidirectional RNN for the remaining layers
    /// @snippet cpu_rnn_inference_f32.cpp create uni rnn
    ///
    //[create uni rnn]
    auto enc_uni_prim_desc = lstm_forward::primitive_desc(cpu_engine,
            prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, memory::desc(), memory::desc(),
            enc_uni_wei_layer_md, enc_uni_wei_iter_md, user_enc_uni_bias_md,
            enc_dst_layer_md, memory::desc(), memory::desc());
    //[create uni rnn]

    auto enc_uni_wei_layer_memory
            = memory(enc_uni_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory);
    reorder(enc_uni_wei_layer_reorder_pd)
            .execute(
                    s, user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory);

    auto enc_uni_wei_iter_memory
            = memory(enc_uni_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory);
    reorder(enc_uni_wei_iter_reorder_pd)
            .execute(s, user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory);

    auto enc_dst_layer_memory
            = dnnl::memory(enc_uni_prim_desc.dst_layer_desc(), cpu_engine);

    // TODO: add a reorder when it becomes available
    ///
    /// Encoder : add the unidirectional rnn primitive with related arguments into encoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push uni rnn to encoder net
    ///
    //[push uni rnn to encoder net]
    encoder_net.push_back(lstm_forward(enc_uni_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, enc_uni_first_dst_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_uni_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_uni_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_uni_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_dst_layer_memory}});
    //[push uni rnn to encoder net]
    ///
    /// **Decoder with attention mechanism**
    ///
    ///
    /// Decoder : declare memory dimensions
    /// @snippet cpu_rnn_inference_f32.cpp dec mem dim
    ///
    //[dec mem dim]
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
                    * feature_size,
            1.0f);
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_dec_dst(
            tgt_seq_length_max * batch * feature_size, 1.0f);
    std::vector<float> user_weights_attention_src_layer(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);

    memory::dims user_dec_wei_layer_dims
            = {dec_n_layers, 1, feature_size, lstm_n_gates, feature_size};
    memory::dims user_dec_wei_iter_dims = {dec_n_layers, 1,
            feature_size + feature_size, lstm_n_gates, feature_size};
    memory::dims user_dec_bias_dims
            = {dec_n_layers, 1, lstm_n_gates, feature_size};

    memory::dims dec_src_layer_dims = {1, batch, feature_size};
    memory::dims dec_dst_layer_dims = {1, batch, feature_size};
    memory::dims dec_dst_iter_c_dims = {dec_n_layers, 1, batch, feature_size};
    //[dec mem dim]

    /// We use the same memory for dec_src_iter and dec_dst_iter.
    /// However, dec_src_iter includes the context vector, while
    /// dec_dst_iter does not.
    /// To resolve this, we create one memory that holds the
    /// context vector as well as both the hidden and cell states.
    /// The dst_iter is a sub-memory of this memory.
    /// Note that the cell state is padded by
    /// feature_size values. However, we do not compute or
    /// access those.
    /// @snippet cpu_rnn_inference_f32.cpp noctx mem dim
    //[noctx mem dim]
    memory::dims dec_dst_iter_dims
            = {dec_n_layers, 1, batch, feature_size + feature_size};
    memory::dims dec_dst_iter_noctx_dims
            = {dec_n_layers, 1, batch, feature_size};
    //[noctx mem dim]
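    // Channel layout of one (layer, batch) row of dec_dst_iter, conceptually:
    //
    //     hidden state     attention context
    //     [0 .. c-1]       [c .. 2c-1]          with c = feature_size
    //
    // The LSTM reads all 2c channels as src_iter, but writes only the first
    // c channels (the view created below) as dst_iter, leaving the context
    // half for the attention code to refresh at every time step.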

    ///
    /// Decoder : create memory descriptors
    /// @snippet cpu_rnn_inference_f32.cpp dec mem desc
    ///
    //[dec mem desc]
    auto user_dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
    auto user_dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
    auto user_dec_bias_md = dnnl::memory::desc({user_dec_bias_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);
    auto dec_dst_layer_md = dnnl::memory::desc({dec_dst_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
    auto dec_src_layer_md = dnnl::memory::desc({dec_src_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
    auto dec_dst_iter_md = dnnl::memory::desc({dec_dst_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);
    auto dec_dst_iter_c_md = dnnl::memory::desc({dec_dst_iter_c_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);
    //[dec mem desc]
    ///
    /// Decoder : Create memory
    /// @snippet cpu_rnn_inference_f32.cpp create dec memory
    ///
    //[create dec memory]
    auto user_dec_wei_layer_memory = dnnl::memory(
            user_dec_wei_layer_md, cpu_engine, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = dnnl::memory(
            user_dec_wei_iter_md, cpu_engine, user_dec_wei_iter.data());
    auto user_dec_bias_memory
            = dnnl::memory(user_dec_bias_md, cpu_engine, user_dec_bias.data());
    auto user_dec_dst_layer_memory
            = dnnl::memory(dec_dst_layer_md, cpu_engine, user_dec_dst.data());
    auto dec_src_layer_memory = dnnl::memory(dec_src_layer_md, cpu_engine);
    auto dec_dst_iter_c_memory = dnnl::memory(dec_dst_iter_c_md, cpu_engine);
    //[create dec memory]

    auto dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
    auto dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);

    // As mentioned above, we create a view without context out of the
    // memory with context.
    ///
    /// Decoder : As mentioned above, we create a view without context out of the memory with context.
    /// @snippet cpu_rnn_inference_f32.cpp create noctx mem
    ///
    //[create noctx mem]
    auto dec_dst_iter_memory = dnnl::memory(dec_dst_iter_md, cpu_engine);
    auto dec_dst_iter_noctx_md = dec_dst_iter_md.submemory_desc(
            dec_dst_iter_noctx_dims, {0, 0, 0, 0});
    //[create noctx mem]
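    // submemory_desc does not allocate anything: dec_dst_iter_noctx_md
    // describes the same underlying buffer as dec_dst_iter_memory, with the
    // channel dimension restricted to the first feature_size entries. The
    // offsets are one zero per dimension of the 4D ldnc descriptor.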

    // TODO: add support for residual connections
    // should it be a set residual in pd or a field to set manually?
    // should be an integer to specify at which layer to start
    ///
    /// Decoder : Create the RNN decoder cell
    /// @snippet cpu_rnn_inference_f32.cpp create dec rnn
    ///
    //[create dec rnn]
    auto dec_ctx_prim_desc = lstm_forward::primitive_desc(cpu_engine,
            prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, dec_dst_iter_c_md, dec_wei_layer_md,
            dec_wei_iter_md, user_dec_bias_md, dec_dst_layer_md,
            dec_dst_iter_noctx_md, dec_dst_iter_c_md);
    //[create dec rnn]

    ///
    /// Decoder : reorder weight memory
    /// @snippet cpu_rnn_inference_f32.cpp reorder weight memory
    ///
    //[reorder weight memory]
    auto dec_wei_layer_memory
            = memory(dec_ctx_prim_desc.weights_layer_desc(), cpu_engine);
    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
            user_dec_wei_layer_memory, dec_wei_layer_memory);
    reorder(dec_wei_layer_reorder_pd)
            .execute(s, user_dec_wei_layer_memory, dec_wei_layer_memory);

    auto dec_wei_iter_memory
            = memory(dec_ctx_prim_desc.weights_iter_desc(), cpu_engine);
    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
            user_dec_wei_iter_memory, dec_wei_iter_memory);
    reorder(dec_wei_iter_reorder_pd)
            .execute(s, user_dec_wei_iter_memory, dec_wei_iter_memory);
    //[reorder weight memory]

    ///
    /// Decoder : add the rnn primitive with related arguments into decoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push rnn to decoder net
    ///
    //[push rnn to decoder net]
    // TODO: add a reorder when it becomes available
    decoder_net.push_back(lstm_forward(dec_ctx_prim_desc));
    decoder_net_args.push_back({{DNNL_ARG_SRC_LAYER, dec_src_layer_memory},
            {DNNL_ARG_SRC_ITER, dec_dst_iter_memory},
            {DNNL_ARG_SRC_ITER_C, dec_dst_iter_c_memory},
            {DNNL_ARG_WEIGHTS_LAYER, dec_wei_layer_memory},
            {DNNL_ARG_WEIGHTS_ITER, dec_wei_iter_memory},
            {DNNL_ARG_BIAS, user_dec_bias_memory},
            {DNNL_ARG_DST_LAYER, user_dec_dst_layer_memory},
            {DNNL_ARG_DST_ITER, dec_dst_iter_memory},
            {DNNL_ARG_DST_ITER_C, dec_dst_iter_c_memory}});
    //[push rnn to decoder net]
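    // Note that SRC_ITER and DST_ITER both map to dec_dst_iter_memory, and
    // SRC_ITER_C and DST_ITER_C both map to dec_dst_iter_c_memory: the
    // decoder updates its recurrent state in place, so this single primitive
    // can be re-executed for every time step.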
    // allocate a temporary buffer for the attention mechanism
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);

    ///
    /// **Execution**
    ///
    auto execute = [&]() {
        assert(encoder_net.size() == encoder_net_args.size()
                && "something is missing");
        ///
        /// run encoder (1 stream)
        /// @snippet cpu_rnn_inference_f32.cpp run enc
        ///
        //[run enc]
        for (size_t p = 0; p < encoder_net.size(); ++p)
            encoder_net.at(p).execute(s, encoder_net_args.at(p));
        //[run enc]

        ///
        /// we compute the weighted annotations once before the decoder
        /// @snippet cpu_rnn_inference_f32.cpp weight ano
        ///
        //[weight ano]
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());
        //[weight ano]

        ///
        /// We initialize src_layer to the embedding of the end-of-sequence
        /// character, which is assumed to be 0 here
        /// @snippet cpu_rnn_inference_f32.cpp init src_layer
        ///
        //[init src_layer]
        memset(dec_src_layer_memory.get_data_handle(), 0,
                dec_src_layer_memory.get_desc().get_size());
        //[init src_layer]
        ///
        /// From now on, src points to the output of the last iteration
        ///
        for (dim_t i = 0; i < tgt_seq_length_max; i++) {
            float *src_att_layer_handle
                    = (float *)dec_src_layer_memory.get_data_handle();
            float *src_att_iter_handle
                    = (float *)dec_dst_iter_memory.get_data_handle();

            ///
            /// Compute the attention context vector into the first layer's src_iter
            /// @snippet cpu_rnn_inference_f32.cpp att ctx
            ///
            //[att ctx]
            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    src_att_layer_handle,
                    (float *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());
            //[att ctx]

            ///
            /// copy the context vectors to all layers of src_iter
            /// @snippet cpu_rnn_inference_f32.cpp cp ctx
            ///
            //[cp ctx]
            copy_context(
                    src_att_iter_handle, dec_n_layers, batch, feature_size);
            //[cp ctx]

            assert(decoder_net.size() == decoder_net_args.size()
                    && "something is missing");
            ///
            /// run the decoder iteration
            /// @snippet cpu_rnn_inference_f32.cpp run dec iter
            ///
            //[run dec iter]
            for (size_t p = 0; p < decoder_net.size(); ++p)
                decoder_net.at(p).execute(s, decoder_net_args.at(p));
            //[run dec iter]

            ///
            /// Advance the src and dst layer handles to the next iteration:
            /// the decoder output of this step becomes the decoder input of
            /// the next step
            /// @snippet cpu_rnn_inference_f32.cpp set handle
            ///
            //[set handle]
            auto dst_layer_handle
                    = (float *)user_dec_dst_layer_memory.get_data_handle();
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            user_dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
            //[set handle]
        }
    };
    /// @page cpu_rnn_inference_f32_cpp
    ///
    std::cout << "Parameters:" << std::endl
              << " batch = " << batch << std::endl
              << " feature size = " << feature_size << std::endl
              << " maximum source sequence length = " << src_seq_length_max
              << std::endl
              << " maximum target sequence length = " << tgt_seq_length_max
              << std::endl
              << " number of layers of the bidirectional encoder = "
              << enc_bidir_n_layers << std::endl
              << " number of layers of the unidirectional encoder = "
              << enc_unidir_n_layers << std::endl
              << " number of layers of the decoder = " << dec_n_layers
              << std::endl;

    execute();
    s.wait();
}

int main(int argc, char **argv) {
    return handle_example_errors({engine::kind::cpu}, simple_net);
}