1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @example cpu_rnn_inference_f32.cpp |
18 | /// @copybrief cpu_rnn_inference_f32_cpp |
19 | /// > Annotated version: @ref cpu_rnn_inference_f32_cpp |
20 | |
21 | /// @page cpu_rnn_inference_f32_cpp RNN f32 inference example |
22 | /// This C++ API example demonstrates how to build GNMT model inference. |
23 | /// |
24 | /// > Example code: @ref cpu_rnn_inference_f32.cpp |
25 | /// |
26 | /// For the encoder we use: |
27 | /// - one primitive for the bidirectional layer of the encoder |
28 | /// - one primitive for all remaining unidirectional layers in the encoder |
29 | /// For the decoder we use: |
30 | /// - one primitive for the first iteration |
31 | /// - one primitive for all subsequent iterations in the decoder. Note that |
32 | /// in this example, this primitive computes the states in place. |
33 | /// - the attention mechanism is implemented separately as there is no support |
34 | /// for the context vectors in oneDNN yet |
35 | |
36 | #include <assert.h> |
37 | |
38 | #include <cstring> |
39 | #include <iostream> |
40 | #include <math.h> |
41 | #include <numeric> |
42 | #include <string> |
43 | |
44 | #include "oneapi/dnnl/dnnl.hpp" |
45 | |
46 | #include "example_utils.hpp" |
47 | |
48 | using namespace dnnl; |
49 | |
50 | using dim_t = dnnl::memory::dim; |
51 | |
52 | const dim_t batch = 32; |
53 | const dim_t src_seq_length_max = 10; |
54 | const dim_t tgt_seq_length_max = 10; |
55 | |
56 | const dim_t feature_size = 256; |
57 | |
58 | const dim_t enc_bidir_n_layers = 1; |
59 | const dim_t enc_unidir_n_layers = 3; |
60 | const dim_t dec_n_layers = 4; |
61 | |
62 | const int lstm_n_gates = 4; |
std::vector<float> weighted_src_layer(batch * feature_size, 1.0f);
std::vector<float> alignment_model(
        src_seq_length_max * batch * feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
67 | std::vector<float> exp_sums(batch, 1.0f); |
68 | |
69 | void compute_weighted_annotations(float *weighted_annotations, |
70 | dim_t src_seq_length_max, dim_t batch, dim_t feature_size, |
71 | float *weights_annot, float *annotations) { |
    // annotations (aka enc_dst_layer) is (t, n, c)
    // weights_annot is (c, c)
74 | |
75 | // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]); |
76 | dim_t num_weighted_annotations = src_seq_length_max * batch; |
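    // dnnl_sgemm computes C = alpha * A * B + beta * C on row-major data:
    // here A is annotations viewed as a (t * n, c) matrix, B is the (c, c)
    // annotation weights, and C is weighted_annotations, also (t * n, c).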
77 | dnnl_sgemm('N', 'N', num_weighted_annotations, feature_size, feature_size, |
78 | 1.f, annotations, feature_size, weights_annot, feature_size, 0.f, |
79 | weighted_annotations, feature_size); |
80 | } |
81 | |
82 | void compute_attention(float *context_vectors, dim_t src_seq_length_max, |
83 | dim_t batch, dim_t feature_size, float *weights_src_layer, |
84 | float *dec_src_layer, float *annotations, float *weighted_annotations, |
85 | float *weights_alignments) { |
86 | // dst_iter : (n, c) matrix |
87 | // src_layer: (n, c) matrix |
88 | // weighted_annotations (t, n, c) |
89 | |
90 | // weights_yi is (c, c) |
91 | // weights_ai is (c, 1) |
92 | // tmp[i] is (n, c) |
93 | // a[i] is (n, 1) |
94 | // p is (n, 1) |
95 | |
96 | // first we precompute the weighted_dec_src_layer |
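    // weighted_src_layer = dec_src_layer (n, c) x weights_src_layer (c, c)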
97 | dnnl_sgemm('N', 'N', batch, feature_size, feature_size, 1.f, dec_src_layer, |
98 | feature_size, weights_src_layer, feature_size, 0.f, |
99 | weighted_src_layer.data(), feature_size); |
100 | |
101 | // then we compute the alignment model |
102 | float *alignment_model_ptr = alignment_model.data(); |
103 | |
104 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) |
105 | for (dim_t i = 0; i < src_seq_length_max; i++) { |
106 | for (dim_t j = 0; j < batch * feature_size; j++) |
107 | alignment_model_ptr[i * batch * feature_size + j] = tanhf( |
108 | weighted_src_layer[j] |
109 | + weighted_annotations[i * batch * feature_size + j]); |
110 | } |
111 | |
112 | // gemv with alignments weights. the resulting alignments are in alignments |
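    // expressed as an sgemm with N = 1: (t * n, c) x (c, 1) -> (t * n, 1)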
113 | dim_t num_weighted_annotations = src_seq_length_max * batch; |
114 | dnnl_sgemm('N', 'N', num_weighted_annotations, 1, feature_size, 1.f, |
115 | alignment_model_ptr, feature_size, weights_alignments, 1, 0.f, |
116 | alignments.data(), 1); |
117 | |
118 | // softmax on alignments. the resulting context weights are in alignments |
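    // the softmax is taken over the src_seq_length_max time steps,
    // independently for each batch element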
119 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1) |
120 | for (dim_t i = 0; i < batch; i++) |
121 | exp_sums[i] = 0.0f; |
122 | |
123 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1) |
124 | for (dim_t j = 0; j < batch; j++) { |
125 | for (dim_t i = 0; i < src_seq_length_max; i++) { |
126 | alignments[i * batch + j] = expf(alignments[i * batch + j]); |
127 | exp_sums[j] += alignments[i * batch + j]; |
128 | } |
129 | } |
130 | |
131 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) |
132 | for (dim_t i = 0; i < src_seq_length_max; i++) |
133 | for (dim_t j = 0; j < batch; j++) |
134 | alignments[i * batch + j] /= exp_sums[j]; |
135 | |
136 | // then we compute the context vectors |
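    // context_vectors rows are (2 * feature_size) wide: the first feature_size
    // entries of each row hold the decoder hidden state (written by the RNN
    // primitive), the remaining feature_size entries receive the attention
    // context, hence the writes at offset feature_size below.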
137 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) |
138 | for (dim_t i = 0; i < batch; i++) |
139 | for (dim_t j = 0; j < feature_size; j++) |
140 | context_vectors[i * (feature_size + feature_size) + feature_size |
141 | + j] |
142 | = 0.0f; |
143 | |
144 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) |
145 | for (dim_t i = 0; i < batch; i++) |
146 | for (dim_t j = 0; j < feature_size; j++) |
147 | for (dim_t k = 0; k < src_seq_length_max; k++) |
148 | context_vectors[i * (feature_size + feature_size) + feature_size |
149 | + j] |
150 | += alignments[k * batch + i] |
151 | * annotations[j + feature_size * (i + batch * k)]; |
152 | } |
153 | |
154 | void copy_context( |
155 | float *src_iter, dim_t n_layers, dim_t batch, dim_t feature_size) { |
156 | // we copy the context from the first layer to all other layers |
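    // src_iter rows are (2 * feature_size) wide; only the context half
    // (the second feature_size entries) of each row is replicated from
    // layer 0 to layers 1..n_layers-1.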
157 | PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(3) |
158 | for (dim_t k = 1; k < n_layers; k++) |
159 | for (dim_t j = 0; j < batch; j++) |
160 | for (dim_t i = 0; i < feature_size; i++) |
161 | src_iter[(k * batch + j) * (feature_size + feature_size) |
162 | + feature_size + i] |
163 | = src_iter[j * (feature_size + feature_size) |
164 | + feature_size + i]; |
165 | } |
166 | |
167 | void simple_net() { |
168 | /// |
169 | /// Initialize a CPU engine and stream. The last parameter in the call represents |
170 | /// the index of the engine. |
171 | /// @snippet cpu_rnn_inference_f32.cpp Initialize engine and stream |
172 | /// |
173 | //[Initialize engine and stream] |
174 | auto cpu_engine = engine(engine::kind::cpu, 0); |
175 | stream s(cpu_engine); |
176 | //[Initialize engine and stream] |
177 | /// |
178 | /// Declare encoder net and decoder net |
179 | /// @snippet cpu_rnn_inference_f32.cpp declare net |
180 | /// |
181 | //[declare net] |
182 | std::vector<primitive> encoder_net, decoder_net; |
183 | std::vector<std::unordered_map<int, memory>> encoder_net_args, |
184 | decoder_net_args; |
185 | |
186 | std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f); |
187 | std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f); |
188 | //[declare net] |
189 | /// |
190 | /// **Encoder** |
191 | /// |
192 | /// |
193 | /// Initialize Encoder Memory |
194 | /// @snippet cpu_rnn_inference_f32.cpp Initialize encoder memory |
195 | /// |
196 | //[Initialize encoder memory] |
197 | memory::dims enc_bidir_src_layer_tz |
198 | = {src_seq_length_max, batch, feature_size}; |
199 | memory::dims enc_bidir_weights_layer_tz |
200 | = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size}; |
201 | memory::dims enc_bidir_weights_iter_tz |
202 | = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size}; |
203 | memory::dims enc_bidir_bias_tz |
204 | = {enc_bidir_n_layers, 2, lstm_n_gates, feature_size}; |
205 | memory::dims enc_bidir_dst_layer_tz |
206 | = {src_seq_length_max, batch, 2 * feature_size}; |
207 | //[Initialize encoder memory] |
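    // Dimension order follows the memory format tags used below:
    // tnc = (time, batch, channels), ldigo = (layers, directions,
    // input channels, gates, output channels), ldgo = (layers, directions,
    // gates, output channels).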
208 | |
209 | /// |
210 | /// |
    /// Encoder: 1 bidirectional layer and 3 unidirectional layers
    /// (enc_unidir_n_layers)
212 | /// |
213 | |
214 | std::vector<float> user_enc_bidir_wei_layer( |
215 | enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size, |
216 | 1.0f); |
217 | std::vector<float> user_enc_bidir_wei_iter( |
218 | enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size, |
219 | 1.0f); |
220 | std::vector<float> user_enc_bidir_bias( |
221 | enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f); |
222 | |
223 | /// |
224 | /// Create the memory for user data |
225 | /// @snippet cpu_rnn_inference_f32.cpp data memory creation |
226 | /// |
227 | //[data memory creation] |
228 | auto user_enc_bidir_src_layer_md = dnnl::memory::desc( |
229 | {enc_bidir_src_layer_tz}, dnnl::memory::data_type::f32, |
230 | dnnl::memory::format_tag::tnc); |
231 | |
232 | auto user_enc_bidir_wei_layer_md = dnnl::memory::desc( |
233 | {enc_bidir_weights_layer_tz}, dnnl::memory::data_type::f32, |
234 | dnnl::memory::format_tag::ldigo); |
235 | |
236 | auto user_enc_bidir_wei_iter_md = dnnl::memory::desc( |
237 | {enc_bidir_weights_iter_tz}, dnnl::memory::data_type::f32, |
238 | dnnl::memory::format_tag::ldigo); |
239 | |
240 | auto user_enc_bidir_bias_md = dnnl::memory::desc({enc_bidir_bias_tz}, |
241 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo); |
242 | |
243 | auto user_enc_bidir_src_layer_memory = dnnl::memory( |
244 | user_enc_bidir_src_layer_md, cpu_engine, net_src.data()); |
245 | auto user_enc_bidir_wei_layer_memory |
246 | = dnnl::memory(user_enc_bidir_wei_layer_md, cpu_engine, |
247 | user_enc_bidir_wei_layer.data()); |
248 | auto user_enc_bidir_wei_iter_memory |
249 | = dnnl::memory(user_enc_bidir_wei_iter_md, cpu_engine, |
250 | user_enc_bidir_wei_iter.data()); |
251 | auto user_enc_bidir_bias_memory = dnnl::memory( |
252 | user_enc_bidir_bias_md, cpu_engine, user_enc_bidir_bias.data()); |
253 | |
254 | //[data memory creation] |
255 | /// |
    /// Create memory descriptors for the RNN weights and destination without
    /// a specified layout (format_tag::any)
257 | /// @snippet cpu_rnn_inference_f32.cpp memory desc for RNN data |
258 | /// |
259 | //[memory desc for RNN data] |
260 | auto enc_bidir_wei_layer_md = memory::desc({enc_bidir_weights_layer_tz}, |
261 | memory::data_type::f32, memory::format_tag::any); |
262 | |
263 | auto enc_bidir_wei_iter_md = memory::desc({enc_bidir_weights_iter_tz}, |
264 | memory::data_type::f32, memory::format_tag::any); |
265 | |
266 | auto enc_bidir_dst_layer_md = memory::desc({enc_bidir_dst_layer_tz}, |
267 | memory::data_type::f32, memory::format_tag::any); |
268 | |
269 | //[memory desc for RNN data] |
270 | /// |
271 | /// Create bidirectional RNN |
272 | /// @snippet cpu_rnn_inference_f32.cpp create rnn |
273 | /// |
274 | //[create rnn] |
275 | |
276 | auto enc_bidir_prim_desc = lstm_forward::primitive_desc(cpu_engine, |
277 | prop_kind::forward_inference, rnn_direction::bidirectional_concat, |
278 | user_enc_bidir_src_layer_md, memory::desc(), memory::desc(), |
279 | enc_bidir_wei_layer_md, enc_bidir_wei_iter_md, |
280 | user_enc_bidir_bias_md, enc_bidir_dst_layer_md, memory::desc(), |
281 | memory::desc()); |
282 | //[create rnn] |
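    // With empty (default-constructed) memory::desc() for src_iter and
    // src_iter_c, the initial hidden and cell states are treated as zero;
    // with empty dst_iter and dst_iter_c, the final iteration states are
    // not requested.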
283 | |
284 | /// |
    /// Create memory for the weights in the primitive's preferred format and
    /// reorder the user data into that internal representation
287 | /// @snippet cpu_rnn_inference_f32.cpp reorder input data |
288 | /// |
289 | //[reorder input data] |
290 | auto enc_bidir_wei_layer_memory |
291 | = memory(enc_bidir_prim_desc.weights_layer_desc(), cpu_engine); |
292 | auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc( |
293 | user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory); |
294 | reorder(enc_bidir_wei_layer_reorder_pd) |
295 | .execute(s, user_enc_bidir_wei_layer_memory, |
296 | enc_bidir_wei_layer_memory); |
297 | //[reorder input data] |
298 | |
299 | auto enc_bidir_wei_iter_memory |
300 | = memory(enc_bidir_prim_desc.weights_iter_desc(), cpu_engine); |
301 | auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc( |
302 | user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory); |
303 | reorder(enc_bidir_wei_iter_reorder_pd) |
304 | .execute(s, user_enc_bidir_wei_iter_memory, |
305 | enc_bidir_wei_iter_memory); |
306 | |
307 | auto enc_bidir_dst_layer_memory |
308 | = dnnl::memory(enc_bidir_prim_desc.dst_layer_desc(), cpu_engine); |
309 | |
310 | /// |
311 | /// Encoder : add the bidirectional rnn primitive with related arguments into encoder_net |
312 | /// @snippet cpu_rnn_inference_f32.cpp push bi rnn to encoder net |
313 | /// |
314 | //[push bi rnn to encoder net] |
315 | encoder_net.push_back(lstm_forward(enc_bidir_prim_desc)); |
316 | encoder_net_args.push_back( |
317 | {{DNNL_ARG_SRC_LAYER, user_enc_bidir_src_layer_memory}, |
318 | {DNNL_ARG_WEIGHTS_LAYER, enc_bidir_wei_layer_memory}, |
319 | {DNNL_ARG_WEIGHTS_ITER, enc_bidir_wei_iter_memory}, |
320 | {DNNL_ARG_BIAS, user_enc_bidir_bias_memory}, |
321 | {DNNL_ARG_DST_LAYER, enc_bidir_dst_layer_memory}}); |
322 | //[push bi rnn to encoder net] |
323 | |
324 | /// |
325 | /// Encoder: unidirectional layers |
326 | /// |
327 | /// |
    /// The first unidirectional layer maps the 2 * feature_size output of the
    /// bidirectional layer down to a feature_size output
330 | /// @snippet cpu_rnn_inference_f32.cpp first uni layer |
331 | /// |
332 | //[first uni layer] |
333 | std::vector<float> user_enc_uni_first_wei_layer( |
334 | 1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f); |
335 | std::vector<float> user_enc_uni_first_wei_iter( |
336 | 1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f); |
337 | std::vector<float> user_enc_uni_first_bias( |
338 | 1 * 1 * lstm_n_gates * feature_size, 1.0f); |
339 | //[first uni layer] |
340 | memory::dims user_enc_uni_first_wei_layer_dims |
341 | = {1, 1, 2 * feature_size, lstm_n_gates, feature_size}; |
342 | memory::dims user_enc_uni_first_wei_iter_dims |
343 | = {1, 1, feature_size, lstm_n_gates, feature_size}; |
344 | memory::dims user_enc_uni_first_bias_dims |
345 | = {1, 1, lstm_n_gates, feature_size}; |
346 | memory::dims enc_uni_first_dst_layer_dims |
347 | = {src_seq_length_max, batch, feature_size}; |
348 | auto user_enc_uni_first_wei_layer_md = dnnl::memory::desc( |
349 | {user_enc_uni_first_wei_layer_dims}, dnnl::memory::data_type::f32, |
350 | dnnl::memory::format_tag::ldigo); |
351 | auto user_enc_uni_first_wei_iter_md = dnnl::memory::desc( |
352 | {user_enc_uni_first_wei_iter_dims}, dnnl::memory::data_type::f32, |
353 | dnnl::memory::format_tag::ldigo); |
354 | auto user_enc_uni_first_bias_md = dnnl::memory::desc( |
355 | {user_enc_uni_first_bias_dims}, dnnl::memory::data_type::f32, |
356 | dnnl::memory::format_tag::ldgo); |
357 | auto user_enc_uni_first_wei_layer_memory |
358 | = dnnl::memory(user_enc_uni_first_wei_layer_md, cpu_engine, |
359 | user_enc_uni_first_wei_layer.data()); |
360 | auto user_enc_uni_first_wei_iter_memory |
361 | = dnnl::memory(user_enc_uni_first_wei_iter_md, cpu_engine, |
362 | user_enc_uni_first_wei_iter.data()); |
363 | auto user_enc_uni_first_bias_memory |
364 | = dnnl::memory(user_enc_uni_first_bias_md, cpu_engine, |
365 | user_enc_uni_first_bias.data()); |
366 | |
367 | auto enc_uni_first_wei_layer_md |
368 | = memory::desc({user_enc_uni_first_wei_layer_dims}, |
369 | memory::data_type::f32, memory::format_tag::any); |
370 | auto enc_uni_first_wei_iter_md |
371 | = memory::desc({user_enc_uni_first_wei_iter_dims}, |
372 | memory::data_type::f32, memory::format_tag::any); |
373 | auto enc_uni_first_dst_layer_md |
374 | = memory::desc({enc_uni_first_dst_layer_dims}, |
375 | memory::data_type::f32, memory::format_tag::any); |
376 | |
377 | // TODO: add support for residual connections |
378 | // should it be a set residual in pd or a field to set manually? |
379 | // should be an integer to specify at which layer to start |
380 | /// |
    /// Encoder : Create the unidirectional RNN primitive for the first layer
382 | /// @snippet cpu_rnn_inference_f32.cpp create uni first |
383 | /// |
384 | //[create uni first] |
385 | auto enc_uni_first_prim_desc = lstm_forward::primitive_desc(cpu_engine, |
386 | prop_kind::forward_inference, |
387 | rnn_direction::unidirectional_left2right, enc_bidir_dst_layer_md, |
388 | memory::desc(), memory::desc(), enc_uni_first_wei_layer_md, |
389 | enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md, |
390 | enc_uni_first_dst_layer_md, memory::desc(), memory::desc()); |
391 | |
392 | //[create uni first] |
393 | auto enc_uni_first_wei_layer_memory |
394 | = memory(enc_uni_first_prim_desc.weights_layer_desc(), cpu_engine); |
395 | auto enc_uni_first_wei_layer_reorder_pd |
396 | = reorder::primitive_desc(user_enc_uni_first_wei_layer_memory, |
397 | enc_uni_first_wei_layer_memory); |
398 | reorder(enc_uni_first_wei_layer_reorder_pd) |
399 | .execute(s, user_enc_uni_first_wei_layer_memory, |
400 | enc_uni_first_wei_layer_memory); |
401 | |
402 | auto enc_uni_first_wei_iter_memory |
403 | = memory(enc_uni_first_prim_desc.weights_iter_desc(), cpu_engine); |
404 | auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc( |
405 | user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory); |
406 | reorder(enc_uni_first_wei_iter_reorder_pd) |
407 | .execute(s, user_enc_uni_first_wei_iter_memory, |
408 | enc_uni_first_wei_iter_memory); |
409 | |
410 | auto enc_uni_first_dst_layer_memory = dnnl::memory( |
411 | enc_uni_first_prim_desc.dst_layer_desc(), cpu_engine); |
412 | |
413 | /// Encoder : add the first unidirectional rnn primitive with related |
414 | /// arguments into encoder_net |
415 | /// |
416 | /// @snippet cpu_rnn_inference_f32.cpp push first uni rnn to encoder net |
417 | /// |
418 | //[push first uni rnn to encoder net] |
419 | // TODO: add a reorder when they will be available |
420 | encoder_net.push_back(lstm_forward(enc_uni_first_prim_desc)); |
421 | encoder_net_args.push_back( |
422 | {{DNNL_ARG_SRC_LAYER, enc_bidir_dst_layer_memory}, |
423 | {DNNL_ARG_WEIGHTS_LAYER, enc_uni_first_wei_layer_memory}, |
424 | {DNNL_ARG_WEIGHTS_ITER, enc_uni_first_wei_iter_memory}, |
425 | {DNNL_ARG_BIAS, user_enc_uni_first_bias_memory}, |
426 | {DNNL_ARG_DST_LAYER, enc_uni_first_dst_layer_memory}}); |
427 | //[push first uni rnn to encoder net] |
428 | |
429 | /// |
430 | /// Encoder : Remaining unidirectional layers |
431 | /// @snippet cpu_rnn_inference_f32.cpp remaining uni layers |
432 | /// |
433 | //[remaining uni layers] |
434 | std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1 |
435 | * feature_size * lstm_n_gates * feature_size, |
436 | 1.0f); |
437 | std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1 |
438 | * feature_size * lstm_n_gates * feature_size, |
439 | 1.0f); |
440 | std::vector<float> user_enc_uni_bias( |
441 | (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f); |
442 | //[remaining uni layers] |
443 | memory::dims user_enc_uni_wei_layer_dims = {(enc_unidir_n_layers - 1), 1, |
444 | feature_size, lstm_n_gates, feature_size}; |
445 | memory::dims user_enc_uni_wei_iter_dims = {(enc_unidir_n_layers - 1), 1, |
446 | feature_size, lstm_n_gates, feature_size}; |
447 | memory::dims user_enc_uni_bias_dims |
448 | = {(enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size}; |
449 | memory::dims enc_dst_layer_dims = {src_seq_length_max, batch, feature_size}; |
450 | auto user_enc_uni_wei_layer_md = dnnl::memory::desc( |
451 | {user_enc_uni_wei_layer_dims}, dnnl::memory::data_type::f32, |
452 | dnnl::memory::format_tag::ldigo); |
453 | auto user_enc_uni_wei_iter_md = dnnl::memory::desc( |
454 | {user_enc_uni_wei_iter_dims}, dnnl::memory::data_type::f32, |
455 | dnnl::memory::format_tag::ldigo); |
456 | auto user_enc_uni_bias_md = dnnl::memory::desc({user_enc_uni_bias_dims}, |
457 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo); |
458 | auto user_enc_uni_wei_layer_memory = dnnl::memory(user_enc_uni_wei_layer_md, |
459 | cpu_engine, user_enc_uni_wei_layer.data()); |
460 | auto user_enc_uni_wei_iter_memory = dnnl::memory( |
461 | user_enc_uni_wei_iter_md, cpu_engine, user_enc_uni_wei_iter.data()); |
462 | auto user_enc_uni_bias_memory = dnnl::memory( |
463 | user_enc_uni_bias_md, cpu_engine, user_enc_uni_bias.data()); |
464 | |
465 | auto enc_uni_wei_layer_md = memory::desc({user_enc_uni_wei_layer_dims}, |
466 | memory::data_type::f32, memory::format_tag::any); |
467 | auto enc_uni_wei_iter_md = memory::desc({user_enc_uni_wei_iter_dims}, |
468 | memory::data_type::f32, memory::format_tag::any); |
469 | auto enc_dst_layer_md = memory::desc({enc_dst_layer_dims}, |
470 | memory::data_type::f32, memory::format_tag::any); |
471 | |
472 | // TODO: add support for residual connections |
473 | // should it be a set residual in pd or a field to set manually? |
474 | // should be an integer to specify at which layer to start |
475 | /// |
    /// Encoder : Create the unidirectional RNN primitive for the remaining layers
477 | /// @snippet cpu_rnn_inference_f32.cpp create uni rnn |
478 | /// |
479 | //[create uni rnn] |
480 | auto enc_uni_prim_desc = lstm_forward::primitive_desc(cpu_engine, |
481 | prop_kind::forward_inference, |
482 | rnn_direction::unidirectional_left2right, |
483 | enc_uni_first_dst_layer_md, memory::desc(), memory::desc(), |
484 | enc_uni_wei_layer_md, enc_uni_wei_iter_md, user_enc_uni_bias_md, |
485 | enc_dst_layer_md, memory::desc(), memory::desc()); |
486 | //[create uni rnn] |
487 | |
488 | auto enc_uni_wei_layer_memory |
489 | = memory(enc_uni_prim_desc.weights_layer_desc(), cpu_engine); |
490 | auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc( |
491 | user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory); |
492 | reorder(enc_uni_wei_layer_reorder_pd) |
493 | .execute( |
494 | s, user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory); |
495 | |
496 | auto enc_uni_wei_iter_memory |
497 | = memory(enc_uni_prim_desc.weights_iter_desc(), cpu_engine); |
498 | auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc( |
499 | user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory); |
500 | reorder(enc_uni_wei_iter_reorder_pd) |
501 | .execute(s, user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory); |
502 | |
503 | auto enc_dst_layer_memory |
504 | = dnnl::memory(enc_uni_prim_desc.dst_layer_desc(), cpu_engine); |
505 | |
506 | // TODO: add a reorder when they will be available |
507 | /// |
508 | /// Encoder : add the unidirectional rnn primitive with related arguments into encoder_net |
509 | /// @snippet cpu_rnn_inference_f32.cpp push uni rnn to encoder net |
510 | /// |
511 | //[push uni rnn to encoder net] |
512 | encoder_net.push_back(lstm_forward(enc_uni_prim_desc)); |
513 | encoder_net_args.push_back( |
514 | {{DNNL_ARG_SRC_LAYER, enc_uni_first_dst_layer_memory}, |
515 | {DNNL_ARG_WEIGHTS_LAYER, enc_uni_wei_layer_memory}, |
516 | {DNNL_ARG_WEIGHTS_ITER, enc_uni_wei_iter_memory}, |
517 | {DNNL_ARG_BIAS, user_enc_uni_bias_memory}, |
518 | {DNNL_ARG_DST_LAYER, enc_dst_layer_memory}}); |
519 | //[push uni rnn to encoder net] |
520 | /// |
521 | /// **Decoder with attention mechanism** |
522 | /// |
523 | /// |
524 | /// Decoder : declare memory dimensions |
525 | /// @snippet cpu_rnn_inference_f32.cpp dec mem dim |
526 | /// |
527 | //[dec mem dim] |
528 | std::vector<float> user_dec_wei_layer( |
529 | dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size, |
530 | 1.0f); |
531 | std::vector<float> user_dec_wei_iter(dec_n_layers * 1 |
532 | * (feature_size + feature_size) * lstm_n_gates |
533 | * feature_size, |
534 | 1.0f); |
535 | std::vector<float> user_dec_bias( |
536 | dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f); |
537 | std::vector<float> user_dec_dst( |
538 | tgt_seq_length_max * batch * feature_size, 1.0f); |
539 | std::vector<float> user_weights_attention_src_layer( |
540 | feature_size * feature_size, 1.0f); |
541 | std::vector<float> user_weights_annotation( |
542 | feature_size * feature_size, 1.0f); |
543 | std::vector<float> user_weights_alignments(feature_size, 1.0f); |
544 | |
545 | memory::dims user_dec_wei_layer_dims |
546 | = {dec_n_layers, 1, feature_size, lstm_n_gates, feature_size}; |
547 | memory::dims user_dec_wei_iter_dims = {dec_n_layers, 1, |
548 | feature_size + feature_size, lstm_n_gates, feature_size}; |
549 | memory::dims user_dec_bias_dims |
550 | = {dec_n_layers, 1, lstm_n_gates, feature_size}; |
551 | |
552 | memory::dims dec_src_layer_dims = {1, batch, feature_size}; |
553 | memory::dims dec_dst_layer_dims = {1, batch, feature_size}; |
554 | memory::dims dec_dst_iter_c_dims = {dec_n_layers, 1, batch, feature_size}; |
555 | //[dec mem dim] |
556 | |
    /// We use the same memory for dec_src_iter and dec_dst_iter.
    /// However, dec_src_iter carries the attention context vector
    /// while dec_dst_iter does not.
    /// To resolve this, we create one memory that holds both the
    /// hidden state and the context vector.
    /// The dst_iter will be a sub-memory of this memory covering
    /// only the hidden-state part, so the context half is neither
    /// computed nor overwritten by the primitive.
566 | /// @snippet cpu_rnn_inference_f32.cpp noctx mem dim |
567 | //[noctx mem dim] |
568 | memory::dims dec_dst_iter_dims |
569 | = {dec_n_layers, 1, batch, feature_size + feature_size}; |
570 | memory::dims dec_dst_iter_noctx_dims |
571 | = {dec_n_layers, 1, batch, feature_size}; |
572 | //[noctx mem dim] |
573 | |
574 | /// |
    /// Decoder : create memory descriptors
576 | /// @snippet cpu_rnn_inference_f32.cpp dec mem desc |
577 | /// |
578 | //[dec mem desc] |
579 | auto user_dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims}, |
580 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo); |
581 | auto user_dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims}, |
582 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo); |
583 | auto user_dec_bias_md = dnnl::memory::desc({user_dec_bias_dims}, |
584 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo); |
585 | auto dec_dst_layer_md = dnnl::memory::desc({dec_dst_layer_dims}, |
586 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc); |
587 | auto dec_src_layer_md = dnnl::memory::desc({dec_src_layer_dims}, |
588 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc); |
589 | auto dec_dst_iter_md = dnnl::memory::desc({dec_dst_iter_dims}, |
590 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc); |
591 | auto dec_dst_iter_c_md = dnnl::memory::desc({dec_dst_iter_c_dims}, |
592 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc); |
593 | //[dec mem desc] |
594 | /// |
595 | /// Decoder : Create memory |
596 | /// @snippet cpu_rnn_inference_f32.cpp create dec memory |
597 | /// |
598 | //[create dec memory] |
599 | auto user_dec_wei_layer_memory = dnnl::memory( |
600 | user_dec_wei_layer_md, cpu_engine, user_dec_wei_layer.data()); |
601 | auto user_dec_wei_iter_memory = dnnl::memory( |
602 | user_dec_wei_iter_md, cpu_engine, user_dec_wei_iter.data()); |
603 | auto user_dec_bias_memory |
604 | = dnnl::memory(user_dec_bias_md, cpu_engine, user_dec_bias.data()); |
605 | auto user_dec_dst_layer_memory |
606 | = dnnl::memory(dec_dst_layer_md, cpu_engine, user_dec_dst.data()); |
607 | auto dec_src_layer_memory = dnnl::memory(dec_src_layer_md, cpu_engine); |
608 | auto dec_dst_iter_c_memory = dnnl::memory(dec_dst_iter_c_md, cpu_engine); |
609 | //[create dec memory] |
610 | |
611 | auto dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims}, |
612 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::any); |
613 | auto dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims}, |
614 | dnnl::memory::data_type::f32, dnnl::memory::format_tag::any); |
615 | |
616 | // As mentioned above, we create a view without context out of the |
617 | // memory with context. |
618 | /// |
619 | /// Decoder : As mentioned above, we create a view without context out of the memory with context. |
620 | /// @snippet cpu_rnn_inference_f32.cpp create noctx mem |
621 | /// |
622 | //[create noctx mem] |
623 | auto dec_dst_iter_memory = dnnl::memory(dec_dst_iter_md, cpu_engine); |
    auto dec_dst_iter_noctx_md = dec_dst_iter_md.submemory_desc(
            dec_dst_iter_noctx_dims, {0, 0, 0, 0});
626 | //[create noctx mem] |
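    // The sub-memory descriptor covers only the first feature_size channels
    // of each row, so the primitive writes its dst_iter (the new hidden
    // state) there and leaves the attention context in the second half
    // untouched.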
627 | |
628 | // TODO: add support for residual connections |
629 | // should it be a set residual in pd or a field to set manually? |
630 | // should be an integer to specify at which layer to start |
631 | /// |
632 | /// Decoder : Create RNN decoder cell |
633 | /// @snippet cpu_rnn_inference_f32.cpp create dec rnn |
634 | /// |
635 | //[create dec rnn] |
636 | auto dec_ctx_prim_desc = lstm_forward::primitive_desc(cpu_engine, |
637 | prop_kind::forward_inference, |
638 | rnn_direction::unidirectional_left2right, dec_src_layer_md, |
639 | dec_dst_iter_md, dec_dst_iter_c_md, dec_wei_layer_md, |
640 | dec_wei_iter_md, user_dec_bias_md, dec_dst_layer_md, |
641 | dec_dst_iter_noctx_md, dec_dst_iter_c_md); |
642 | //[create dec rnn] |
643 | |
644 | /// |
645 | /// Decoder : reorder weight memory |
646 | /// @snippet cpu_rnn_inference_f32.cpp reorder weight memory |
647 | /// |
648 | //[reorder weight memory] |
649 | auto dec_wei_layer_memory |
650 | = memory(dec_ctx_prim_desc.weights_layer_desc(), cpu_engine); |
651 | auto dec_wei_layer_reorder_pd = reorder::primitive_desc( |
652 | user_dec_wei_layer_memory, dec_wei_layer_memory); |
653 | reorder(dec_wei_layer_reorder_pd) |
654 | .execute(s, user_dec_wei_layer_memory, dec_wei_layer_memory); |
655 | |
656 | auto dec_wei_iter_memory |
657 | = memory(dec_ctx_prim_desc.weights_iter_desc(), cpu_engine); |
658 | auto dec_wei_iter_reorder_pd = reorder::primitive_desc( |
659 | user_dec_wei_iter_memory, dec_wei_iter_memory); |
660 | reorder(dec_wei_iter_reorder_pd) |
661 | .execute(s, user_dec_wei_iter_memory, dec_wei_iter_memory); |
662 | //[reorder weight memory] |
663 | |
664 | /// |
665 | /// Decoder : add the rnn primitive with related arguments into decoder_net |
666 | /// @snippet cpu_rnn_inference_f32.cpp push rnn to decoder net |
667 | /// |
668 | //[push rnn to decoder net] |
669 | // TODO: add a reorder when they will be available |
670 | decoder_net.push_back(lstm_forward(dec_ctx_prim_desc)); |
671 | decoder_net_args.push_back({{DNNL_ARG_SRC_LAYER, dec_src_layer_memory}, |
672 | {DNNL_ARG_SRC_ITER, dec_dst_iter_memory}, |
673 | {DNNL_ARG_SRC_ITER_C, dec_dst_iter_c_memory}, |
674 | {DNNL_ARG_WEIGHTS_LAYER, dec_wei_layer_memory}, |
675 | {DNNL_ARG_WEIGHTS_ITER, dec_wei_iter_memory}, |
676 | {DNNL_ARG_BIAS, user_dec_bias_memory}, |
677 | {DNNL_ARG_DST_LAYER, user_dec_dst_layer_memory}, |
678 | {DNNL_ARG_DST_ITER, dec_dst_iter_memory}, |
679 | {DNNL_ARG_DST_ITER_C, dec_dst_iter_c_memory}}); |
680 | //[push rnn to decoder net] |
681 | // allocating temporary buffer for attention mechanism |
682 | std::vector<float> weighted_annotations( |
683 | src_seq_length_max * batch * feature_size, 1.0f); |
684 | |
685 | /// |
686 | /// **Execution** |
687 | /// |
688 | auto execute = [&]() { |
        assert(encoder_net.size() == encoder_net_args.size()
                && "something is missing");
691 | /// |
692 | /// run encoder (1 stream) |
693 | /// @snippet cpu_rnn_inference_f32.cpp run enc |
694 | /// |
695 | //[run enc] |
696 | for (size_t p = 0; p < encoder_net.size(); ++p) |
697 | encoder_net.at(p).execute(s, encoder_net_args.at(p)); |
698 | //[run enc] |
699 | |
700 | /// |
701 | /// we compute the weighted annotations once before the decoder |
702 | /// @snippet cpu_rnn_inference_f32.cpp weight ano |
703 | /// |
704 | //[weight ano] |
705 | compute_weighted_annotations(weighted_annotations.data(), |
706 | src_seq_length_max, batch, feature_size, |
707 | user_weights_annotation.data(), |
708 | (float *)enc_dst_layer_memory.get_data_handle()); |
709 | //[weight ano] |
710 | |
711 | /// |
        /// We initialize src_layer to the embedding of the end-of-sequence
        /// character, which is assumed to be 0 here
714 | /// @snippet cpu_rnn_inference_f32.cpp init src_layer |
715 | /// |
716 | //[init src_layer] |
717 | memset(dec_src_layer_memory.get_data_handle(), 0, |
718 | dec_src_layer_memory.get_desc().get_size()); |
719 | //[init src_layer] |
720 | /// |
721 | /// From now on, src points to the output of the last iteration |
722 | /// |
723 | for (dim_t i = 0; i < tgt_seq_length_max; i++) { |
724 | float *src_att_layer_handle |
725 | = (float *)dec_src_layer_memory.get_data_handle(); |
726 | float *src_att_iter_handle |
727 | = (float *)dec_dst_iter_memory.get_data_handle(); |
728 | |
729 | /// |
730 | /// Compute attention context vector into the first layer src_iter |
731 | /// @snippet cpu_rnn_inference_f32.cpp att ctx |
732 | /// |
733 | //[att ctx] |
734 | compute_attention(src_att_iter_handle, src_seq_length_max, batch, |
735 | feature_size, user_weights_attention_src_layer.data(), |
736 | src_att_layer_handle, |
737 | (float *)enc_bidir_dst_layer_memory.get_data_handle(), |
738 | weighted_annotations.data(), |
739 | user_weights_alignments.data()); |
740 | //[att ctx] |
741 | |
742 | /// |
743 | /// copy the context vectors to all layers of src_iter |
744 | /// @snippet cpu_rnn_inference_f32.cpp cp ctx |
745 | /// |
746 | //[cp ctx] |
747 | copy_context( |
748 | src_att_iter_handle, dec_n_layers, batch, feature_size); |
749 | //[cp ctx] |
750 | |
            assert(decoder_net.size() == decoder_net_args.size()
                    && "something is missing");
753 | /// |
754 | /// run the decoder iteration |
755 | /// @snippet cpu_rnn_inference_f32.cpp run dec iter |
756 | /// |
757 | //[run dec iter] |
758 | for (size_t p = 0; p < decoder_net.size(); ++p) |
759 | decoder_net.at(p).execute(s, decoder_net_args.at(p)); |
760 | //[run dec iter] |
761 | |
762 | /// |
763 | /// Move the handle on the src/dst layer to the next iteration |
764 | /// @snippet cpu_rnn_inference_f32.cpp set handle |
765 | /// |
766 | //[set handle] |
767 | auto dst_layer_handle |
768 | = (float *)user_dec_dst_layer_memory.get_data_handle(); |
769 | dec_src_layer_memory.set_data_handle(dst_layer_handle); |
770 | user_dec_dst_layer_memory.set_data_handle( |
771 | dst_layer_handle + batch * feature_size); |
772 | //[set handle] |
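            // This iteration's output becomes the next iteration's input
            // (dec_src_layer now points at it), and the destination handle
            // advances by one time step (batch * feature_size floats) within
            // user_dec_dst.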
773 | } |
774 | }; |
775 | /// @page cpu_rnn_inference_f32_cpp |
776 | /// |
777 | std::cout << "Parameters:" << std::endl |
778 | << " batch = " << batch << std::endl |
779 | << " feature size = " << feature_size << std::endl |
780 | << " maximum source sequence length = " << src_seq_length_max |
781 | << std::endl |
782 | << " maximum target sequence length = " << tgt_seq_length_max |
783 | << std::endl |
784 | << " number of layers of the bidirectional encoder = " |
785 | << enc_bidir_n_layers << std::endl |
786 | << " number of layers of the unidirectional encoder = " |
787 | << enc_unidir_n_layers << std::endl |
788 | << " number of layers of the decoder = " << dec_n_layers |
789 | << std::endl; |
790 | |
791 | execute(); |
792 | s.wait(); |
793 | } |
794 | |
795 | int main(int argc, char **argv) { |
796 | return handle_example_errors({engine::kind::cpu}, simple_net); |
797 | } |
798 | |