1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @file
18/// C++ API
19
20#ifndef ONEAPI_DNNL_DNNL_HPP
21#define ONEAPI_DNNL_DNNL_HPP
22
23#include "oneapi/dnnl/dnnl_config.h"
24
25/// @cond DO_NOT_DOCUMENT_THIS
26#include <algorithm>
27#include <cstdlib>
28#include <iterator>
29#include <memory>
30#include <string>
31#include <vector>
32#include <unordered_map>
33
34#include "oneapi/dnnl/dnnl.h"
35#include "oneapi/dnnl/dnnl_common.hpp"
36
37/// @endcond
38
39/// @addtogroup dnnl_api oneDNN API
40/// @{
41
42/// oneDNN namespace
43namespace dnnl {
44
45/// @addtogroup dnnl_api_utils Utilities
46/// Utility types and definitions.
47/// @{
48
49/// @cond DO_NOT_DOCUMENT_THIS
50template <typename T>
51void validate_container_size(const T &v, const char *error_message,
52 int min_size = 1, int max_size = -1) {
53 const int size = (int)v.size();
54 if (size < min_size || (max_size >= 0 && size > max_size))
55 DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
56}
57/// @endcond
58
59/// @cond DO_NOT_DOCUMENT_THIS
60template <>
61struct handle_traits<dnnl_memory_desc_t> {
62 static dnnl_status_t destructor(dnnl_memory_desc_t p) {
63 return dnnl_memory_desc_destroy(p);
64 }
65};
66
67template <>
68struct handle_traits<dnnl_memory_t> {
69 static dnnl_status_t destructor(dnnl_memory_t p) {
70 return dnnl_memory_destroy(p);
71 }
72};
73
74template <>
75struct handle_traits<dnnl_primitive_desc_t> {
76 static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
77 return dnnl_primitive_desc_destroy(p);
78 }
79};
80
81template <>
82struct handle_traits<dnnl_primitive_t> {
83 static dnnl_status_t destructor(dnnl_primitive_t p) {
84 return dnnl_primitive_destroy(p);
85 }
86};
87
88/// @endcond
89
90/// @} dnnl_api_utils
91
92struct stream;
93struct memory;
94struct primitive_desc;
95
96/// @addtogroup dnnl_api_primitives Primitives
97/// Compute primitives
98/// @sa @ref dev_guide_basic_concepts
99/// @{
100
101/// @addtogroup dnnl_api_primitives_common Common
102/// Common operations to create, destroy and inspect primitives
103/// @{
104
105/// Base class for all computational primitives.
106struct primitive : public handle<dnnl_primitive_t> {
107 /// Kinds of primitives supported by the library.
108 enum class kind {
109 /// Undefined primitive
110 undef = dnnl_undefined_primitive,
111 /// A reorder primitive.
112 reorder = dnnl_reorder,
113 /// A shuffle primitive.
114 shuffle = dnnl_shuffle,
115 /// A (out-of-place) tensor concatenation primitive.
116 concat = dnnl_concat,
117 /// A summation primitive.
118 sum = dnnl_sum,
119 /// A convolution primitive.
120 convolution = dnnl_convolution,
121 /// A deconvolution primitive.
122 deconvolution = dnnl_deconvolution,
123 /// An element-wise primitive.
124 eltwise = dnnl_eltwise,
125 /// An LRN primitive.
126 lrn = dnnl_lrn,
127 /// A batch normalization primitive.
128 batch_normalization = dnnl_batch_normalization,
129 /// An inner product primitive.
130 inner_product = dnnl_inner_product,
131 /// An RNN primitive.
132 rnn = dnnl_rnn,
133 /// A binary primitive.
134 binary = dnnl_binary,
135 /// A matmul (matrix multiplication) primitive.
136 matmul = dnnl_matmul,
137 /// A resampling primitive.
138 resampling = dnnl_resampling,
139 /// A pooling primitive.
140 pooling = dnnl_pooling,
141 /// A reduction primitive.
142 reduction = dnnl_reduction,
143 /// A PReLU primitive.
144 prelu = dnnl_prelu,
145 /// A softmax primitive.
146 softmax = dnnl_softmax,
147 /// A layer normalization primitive.
148 layer_normalization = dnnl_layer_normalization,
149 };
150
151 using handle::handle;
152
153 /// Default constructor. Constructs an empty object.
154 primitive() = default;
155
156 /// Constructs a primitive from a C API primitive descriptor.
157 ///
158 /// @param c_pd C API primitive descriptor.
159 primitive(const_dnnl_primitive_desc_t c_pd);
160
161 /// Constructs a primitive from a C API primitive descriptor and a cache blob.
162 ///
163 /// @param c_pd C API primitive descriptor.
164 /// @param cache_blob Cache blob.
165 primitive(const_dnnl_primitive_desc_t c_pd,
166 const std::vector<uint8_t> &cache_blob);
167
168 /// Constructs a primitive from a primitive descriptor.
169 ///
170 /// @param pd Primitive descriptor.
171 primitive(const primitive_desc &pd);
172
173 /// Constructs a primitive from a primitive descriptor and a cache blob.
174 ///
175 /// @param pd Primitive descriptor.
176 /// @param cache_blob Cache blob.
177 primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob);
178
179 /// Returns the C API primitive descriptor of the underlying C API
180 /// primitive.
181 ///
182 /// @returns The underlying C API primitive descriptor.
183 inline const_dnnl_primitive_desc_t get_primitive_desc() const;
184
185 /// Returns the kind of the primitive.
186 ///
187 /// @returns The primitive kind.
188 inline kind get_kind() const;
189
190 /// Returns a cache blob for the primitive.
191 ///
192 /// @returns Vector containing the cache blob.
193 ///
194 /// @note The cache blob can be empty. It's the user's responsibility to
195 /// check whether it's empty prior to passing it to the primitive
196 /// constructor.
197 inline std::vector<uint8_t> get_cache_blob() const;
198
199 /// Executes computations specified by the primitive in a specified stream.
200 ///
201 /// Arguments are passed via an arguments map containing <index,
202 /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
203 /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
204 /// matching the one returned by
205 /// primitive_desc::query_md(#query::exec_arg_md, index) unless using
206 /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
207 ///
208 /// @param astream Stream object. The stream must belong to the same engine
209 /// as the primitive.
210 /// @param args Arguments map.
211 void execute(const stream &astream,
212 const std::unordered_map<int, memory> &args) const;
213};
214
215/// Converts primitive kind enum value from C++ API to C API type.
216///
217/// @param akind C++ API primitive kind enum value.
218/// @returns Corresponding C API primitive kind enum value.
219inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
220 return static_cast<dnnl_primitive_kind_t>(akind);
221}
222
223const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
224 const_dnnl_primitive_desc_t pd;
225 error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
226 "could not get a primitive descriptor from a primitive");
227 return pd;
228}
229
230dnnl::primitive::kind primitive::get_kind() const {
231 const_dnnl_primitive_desc_t pd = get_primitive_desc();
232 // TODO (Roma): the code below is only needed because get_primitive_desc
233 // returns a C type.
234 dnnl_primitive_kind_t kind;
235 error::wrap_c_api(dnnl_primitive_desc_query(
236 pd, dnnl_query_primitive_kind, 0, (void *)&kind),
237 "could not get a primitive kind from a primitive descriptor");
238 return static_cast<dnnl::primitive::kind>(kind);
239}
240
241std::vector<uint8_t> primitive::get_cache_blob() const {
242 size_t size;
243 error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr),
244 "could not get cache blob size from a primitive");
245
246 std::vector<uint8_t> cache_blob(size);
247 error::wrap_c_api(
248 dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()),
249 "could not get a cache blob from a primitive");
250 return cache_blob;
251}
252
253/// @} dnnl_api_primitives_common
254
255/// @addtogroup dnnl_api_attributes
256///
/// A container for parameters that extend the behavior of primitives.
258///
259/// Attributes can also contain Post-ops, which are computations executed
260/// after the primitive.
261///
262/// @sa @ref dev_guide_attributes
263/// @sa @ref dev_guide_attributes_post_ops
264///
265/// @{
266
267/// Scratchpad mode
268enum class scratchpad_mode {
269 /// The library manages the scratchpad allocation according to the policy
270 /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
271 /// [build option](@ref dev_guide_build_options) (default).
272 ///
273 /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
274 /// scratchpad is common to all primitives to reduce the memory footprint.
275 /// This configuration comes with limited thread-safety properties, namely
276 /// primitives can be created and executed in parallel but cannot migrate
277 /// between threads (in other words, each primitive should be executed in
278 /// the same thread it was created in).
279 ///
280 /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
281 /// private to each primitive. The memory footprint is larger than when
282 /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
283 /// created and run concurrently (the same primitive cannot be run
284 /// concurrently from two different threads though).
285 library = dnnl_scratchpad_mode_library,
286 /// The user manages the scratchpad allocation by querying and providing
287 /// the scratchpad memory to primitives. This mode is thread-safe as long
288 /// as the scratchpad buffers are not used concurrently by two primitive
289 /// executions.
290 user = dnnl_scratchpad_mode_user,
291};
292
293/// Converts a scratchpad mode enum value from C++ API to C API type.
294///
295/// @param mode C++ API scratchpad mode enum value.
296/// @returns Corresponding C API scratchpad mode enum value.
297inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
298 return static_cast<dnnl_scratchpad_mode_t>(mode);
299}
300
301/// Propagation kind.
302enum class prop_kind {
303 /// Undefined propagation kind.
304 undef = dnnl_prop_kind_undef,
305 /// Forward data propagation (training mode). In this mode, primitives
306 /// perform computations necessary for subsequent backward propagation.
307 forward_training = dnnl_forward_training,
308 /// Forward data propagation (inference mode). In this mode, primitives
309 /// perform only computations that are necessary for inference and omit
310 /// computations that are necessary only for backward propagation.
311 forward_inference = dnnl_forward_inference,
312 /// Forward data propagation,
313 /// alias for #dnnl::prop_kind::forward_training.
314 forward = dnnl_forward,
315 /// Backward propagation (with respect to all parameters).
316 backward = dnnl_backward,
317 /// Backward data propagation.
318 backward_data = dnnl_backward_data,
319 /// Backward weights propagation.
320 backward_weights = dnnl_backward_weights,
321 /// Backward bias propagation.
322 backward_bias = dnnl_backward_bias
323};
324
325/// Converts propagation kind enum value from C++ API to C API type.
326///
327/// @param akind C++ API propagation kind enum value.
328/// @returns Corresponding C API propagation kind enum value.
329inline dnnl_prop_kind_t convert_to_c(prop_kind akind) {
330 return static_cast<dnnl_prop_kind_t>(akind);
331}
332
333/// Kinds of algorithms.
334enum class algorithm {
335 /// Undefined algorithm
336 undef = dnnl_alg_kind_undef,
337 /// Convolution algorithm that is chosen to be either direct or Winograd
338 /// automatically
339 convolution_auto = dnnl_convolution_auto,
340 /// Direct convolution
341 convolution_direct = dnnl_convolution_direct,
342 /// Winograd convolution
343 convolution_winograd = dnnl_convolution_winograd,
344 /// Direct deconvolution
345 deconvolution_direct = dnnl_deconvolution_direct,
346 /// Winograd deconvolution
347 deconvolution_winograd = dnnl_deconvolution_winograd,
348 /// Elementwise: rectified linear unit (ReLU)
349 eltwise_relu = dnnl_eltwise_relu,
350 /// Elementwise: hyperbolic tangent non-linearity (tanh)
351 eltwise_tanh = dnnl_eltwise_tanh,
352 /// Elementwise: exponential linear unit (ELU)
353 eltwise_elu = dnnl_eltwise_elu,
354 /// Elementwise: square
355 eltwise_square = dnnl_eltwise_square,
356 /// Elementwise: abs
357 eltwise_abs = dnnl_eltwise_abs,
358 /// Elementwise: square root
359 eltwise_sqrt = dnnl_eltwise_sqrt,
360 /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
361 eltwise_swish = dnnl_eltwise_swish,
362 /// Elementwise: linear
363 eltwise_linear = dnnl_eltwise_linear,
364 /// Elementwise: soft_relu
365 eltwise_soft_relu = dnnl_eltwise_soft_relu,
366 /// Elementwise: mish
367 eltwise_mish = dnnl_eltwise_mish,
368 /// Elementwise: logistic
369 eltwise_logistic = dnnl_eltwise_logistic,
370 /// Elementwise: exponent
371 eltwise_exp = dnnl_eltwise_exp,
372 /// Elementwise: tanh-based gelu
373 eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
374 /// Elementwise: erf-based gelu
375 eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
376 /// Elementwise: natural logarithm
377 eltwise_log = dnnl_eltwise_log,
378 /// Elementwise: clip
379 eltwise_clip = dnnl_eltwise_clip,
380 /// Eltwise: clip version 2
381 eltwise_clip_v2 = dnnl_eltwise_clip_v2,
382 /// Elementwise: pow
383 eltwise_pow = dnnl_eltwise_pow,
384 /// Elementwise: round
385 eltwise_round = dnnl_eltwise_round,
386 /// Elementwise: hardswish
387 eltwise_hardswish = dnnl_eltwise_hardswish,
388 /// Elementwise: hardsigmoid
389 eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
390 /// Elementwise: rectified linar unit (ReLU) (dst for backward)
391 eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
392 /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
393 eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
394 /// Elementwise: exponential linear unit (ELU) (dst for backward)
395 eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
396 /// Elementwise: square root (dst for backward)
397 eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
398 /// Elementwise: logistic (dst for backward)
399 eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
400 /// Elementwise: exponent (dst for backward)
401 eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
402 /// Elementwise: clip version 2 (dst for backward)
403 eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd,
404 /// Local response normalization (LRN) across multiple channels
405 lrn_across_channels = dnnl_lrn_across_channels,
406 /// LRN within a single channel
407 lrn_within_channel = dnnl_lrn_within_channel,
408 /// Max pooling
409 pooling_max = dnnl_pooling_max,
410 /// Average pooling include padding
411 pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
412 /// Average pooling exclude padding
413 pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
414 /// RNN cell
415 vanilla_rnn = dnnl_vanilla_rnn,
416 /// LSTM cell
417 vanilla_lstm = dnnl_vanilla_lstm,
418 /// GRU cell
419 vanilla_gru = dnnl_vanilla_gru,
420 /// GRU cell with linear before reset. Differs from the vanilla GRU
421 /// in how the new memory gate is calculated:
422 /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
423 /// LRB GRU expects 4 bias tensors on input:
424 /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
425 lbr_gru = dnnl_lbr_gru,
426 /// AUGRU cell
427 vanilla_augru = dnnl_vanilla_augru,
428 /// AUGRU cell with linear before reset
429 lbr_augru = dnnl_lbr_augru,
430 /// Binary add
431 binary_add = dnnl_binary_add,
432 /// Binary mul
433 binary_mul = dnnl_binary_mul,
434 /// Binary max
435 binary_max = dnnl_binary_max,
436 /// Binary min
437 binary_min = dnnl_binary_min,
438 /// Binary div
439 binary_div = dnnl_binary_div,
440 /// Binary sub
441 binary_sub = dnnl_binary_sub,
442 /// Binary greater than or equal
443 binary_ge = dnnl_binary_ge,
444 /// Binary greater than
445 binary_gt = dnnl_binary_gt,
446 /// Binary less than or equal
447 binary_le = dnnl_binary_le,
448 /// Binary less than
449 binary_lt = dnnl_binary_lt,
450 /// Binary equal
451 binary_eq = dnnl_binary_eq,
452 /// Binary not equal
453 binary_ne = dnnl_binary_ne,
454 /// Nearest Neighbor resampling method
455 resampling_nearest = dnnl_resampling_nearest,
456 /// Linear (Bilinear, Trilinear) resampling method
457 resampling_linear = dnnl_resampling_linear,
458 /// Reduction using max operation
459 reduction_max = dnnl_reduction_max,
460 /// Reduction using min operation
461 reduction_min = dnnl_reduction_min,
462 /// Reduction using sum operation
463 reduction_sum = dnnl_reduction_sum,
464 /// Reduction using mul operation
465 reduction_mul = dnnl_reduction_mul,
466 /// Reduction using mean operation
467 reduction_mean = dnnl_reduction_mean,
468 /// Reduction using norm_lp_max operation
469 reduction_norm_lp_max = dnnl_reduction_norm_lp_max,
470 /// Reduction using norm_lp_sum operation
471 reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum,
472 /// Reduction using norm_lp_power_p_max operation
473 reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
474 /// Reduction using norm_lp_power_p_sum operation
475 reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
476 /// Softmax, numerically stable
477 softmax_accurate = dnnl_softmax_accurate,
478 /// LogSoftmax, numerically stable
479 softmax_log = dnnl_softmax_log,
480};
481
482/// Converts algorithm kind enum value from C++ API to C API type.
483/// @param aalgorithm C++ API algorithm kind enum value.
484/// @returns Corresponding C API algorithm kind enum value.
485inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
486 return static_cast<dnnl_alg_kind_t>(aalgorithm);
487}
488
489/// @} dnnl_api_attributes
490
491/// @addtogroup dnnl_api_primitives_common
492/// @{
493
494/// Flags for normalization primitives.
495enum class normalization_flags : unsigned {
496 /// Use no normalization flags. If specified, the library computes mean and
497 /// variance on forward propagation for training and inference, outputs them
498 /// on forward propagation for training, and computes the respective
499 /// derivatives on backward propagation.
500 none = dnnl_normalization_flags_none,
501
502 /// Use global statistics. If specified, the library uses mean and
503 /// variance provided by the user as an input on forward propagation and
504 /// does not compute their derivatives on backward propagation. Otherwise,
505 /// the library computes mean and variance on forward propagation for
506 /// training and inference, outputs them on forward propagation for
507 /// training, and computes the respective derivatives on backward
508 /// propagation.
509 use_global_stats = dnnl_use_global_stats,
510
511 /// Use scale parameter. If specified, the user is expected to pass scale as
512 /// input on forward propagation. On backward propagation of type
513 /// #dnnl::prop_kind::backward, the library computes its derivative.
514 use_scale = dnnl_use_scale,
515
516 /// Use shift parameter. If specified, the user is expected to pass shift as
517 /// input on forward propagation. On backward propagation of type
518 /// #dnnl::prop_kind::backward, the library computes its derivative.
519 use_shift = dnnl_use_shift,
520
521 /// Fuse normalization with ReLU. On training, normalization will require
522 /// the workspace to implement backward propagation. On inference, the
523 /// workspace is not required and behavior is the same as when normalization
524 /// is fused with ReLU using the post-ops API.
525 fuse_norm_relu = dnnl_fuse_norm_relu,
526
527 /// Fuse normalization with elementwise binary Add and then fuse with ReLU.
528 /// On training, normalization will require the workspace to implement
529 /// backward propagation. On inference, the workspace is not required.
530 fuse_norm_add_relu = dnnl_fuse_norm_add_relu,
531};
532
533/// Converts normalization flags enum value from C++ API to C API type.
534/// @param flags C++ API normalization flags enum value.
535/// @returns Corresponding C API normalization flags enum value.
536inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
537 return static_cast<dnnl_normalization_flags_t>(flags);
538}
539
540/// @} dnnl_api_primitives_common
541
542/// @addtogroup dnnl_api_rnn
543/// @{
544
545/// RNN cell flags.
546enum class rnn_flags : unsigned {
547 /// Undefined RNN flags
548 undef = dnnl_rnn_flags_undef
549};
550
551/// Converts RNN cell flags enum value from C++ API to C API type.
552/// @param flags C++ API RNN cell flags enum value.
553/// @returns Corresponding C API RNN cell flags enum value.
554inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
555 return static_cast<dnnl_rnn_flags_t>(flags);
556}
557
// Macro defined outside this part of the file. Judging by its name it
// generates the bitmask operators for the two flag enums.
// NOTE(review): confirm the exact set of operators at the macro definition.
DNNL_DEFINE_BITMASK_OPS(normalization_flags)
DNNL_DEFINE_BITMASK_OPS(rnn_flags)
560
561/// A direction of RNN primitive execution
562enum class rnn_direction {
563 /// Undefined RNN direction.
564 undef = dnnl_rnn_direction_undef,
565 /// Unidirectional execution of RNN primitive from left to right.
566 unidirectional_left2right = dnnl_unidirectional_left2right,
567 /// Unidirectional execution of RNN primitive from right to left.
568 unidirectional_right2left = dnnl_unidirectional_right2left,
569 /// Bidirectional execution of RNN primitive with concatenation of the
570 /// results.
571 bidirectional_concat = dnnl_bidirectional_concat,
572 /// Bidirectional execution of RNN primitive with summation of the
573 /// results.
574 bidirectional_sum = dnnl_bidirectional_sum,
575};
576
577/// Converts RNN direction enum value from C++ API to C API type.
578/// @param dir C++ API RNN direction enum value.
579/// @returns Corresponding C API RNN direction enum value.
580inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
581 return static_cast<dnnl_rnn_direction_t>(dir);
582}
583
584/// @} dnnl_api_rnn
585
586/// @addtogroup dnnl_api_primitives_common
587/// @{
588
589/// Primitive descriptor query specification.
590///
591/// In general, queries are not used with the C++ API because most queries are
592/// implemented as class members.
593///
594/// See @ref dnnl_query_t for more information.
595enum class query {
596 /// no query
597 undef = dnnl_query_undef,
598
599 /// execution engine
600 engine = dnnl_query_engine,
601 /// primitive kind
602 primitive_kind = dnnl_query_primitive_kind,
603
604 /// number of inputs expected
605 num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
606 /// number of outputs expected
607 num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
608
609 /// runtime estimation (seconds), unimplemented
610 time_estimate_f64 = dnnl_query_time_estimate_f64,
611 /// memory required for scratchpad (bytes)
612 ///
613 /// @sa @ref dev_guide_attributes_scratchpad
614 memory_consumption_s64 = dnnl_query_memory_consumption_s64,
615
616 /// scratchpad engine
617 ///
618 /// engine to be used for creating scratchpad memory
619 scratchpad_engine = dnnl_query_scratchpad_engine,
620
621 /// reorder source engine
622 reorder_src_engine = dnnl_query_reorder_src_engine,
623 /// reorder destination engine
624 reorder_dst_engine = dnnl_query_reorder_dst_engine,
625
626 /// implementation name
627 impl_info_str = dnnl_query_impl_info_str,
628
629 /// propagation kind
630 prop_kind = dnnl_query_prop_kind,
631
632 /// size of cache blob ID in bytes
633 cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64,
634
635 /// cache blob ID (pointer to array)
636 cache_blob_id = dnnl_query_cache_blob_id,
637
638 /// strides
639 strides = dnnl_query_strides,
640 /// dilations
641 dilations = dnnl_query_dilations,
642 /// left padding
643 padding_l = dnnl_query_padding_l,
644 /// right padding
645 padding_r = dnnl_query_padding_r,
646 /// epsilon
647 epsilon_f32 = dnnl_query_epsilon_f32,
648 /// flags
649 flags = dnnl_query_flags,
650 /// algorithm kind
651 alg_kind = dnnl_query_alg_kind,
652 /// alpha
653 alpha_f32 = dnnl_query_alpha_f32,
654 /// beta
655 beta_f32 = dnnl_query_beta_f32,
656 /// axis
657 axis_s32 = dnnl_query_axis_s32,
658 /// LRN parameter local size
659 local_size_s64 = dnnl_query_local_size_s64,
660 /// LRN parameter K
661 k_f32 = dnnl_query_k_f32,
662 /// Reduction parameter P
663 p_f32 = dnnl_query_p_f32,
664 /// Resampling parameter factors
665 factors = dnnl_query_factors,
666 /// RNN parameter cell kind
667 cell_kind = dnnl_query_cell_kind,
668 /// RNN parameter direction
669 direction = dnnl_query_direction,
670 /// RNN parameter activation kind
671 activation_kind = dnnl_query_activation_kind,
672 /// Pooling parameter kernel
673 kernel = dnnl_query_kernel,
674 /// Shuffle parameter group size
675 group_size_s64 = dnnl_query_group_size_s64,
676
677 /// source memory desc
678 src_md = dnnl_query_src_md,
679 /// source gradient (diff) memory desc
680 diff_src_md = dnnl_query_diff_src_md,
681 /// weights memory descriptor desc
682 weights_md = dnnl_query_weights_md,
683 /// weights gradient (diff) memory desc
684 diff_weights_md = dnnl_query_diff_weights_md,
685 /// destination memory desc
686 dst_md = dnnl_query_dst_md,
687 /// destination gradient (diff) memory desc
688 diff_dst_md = dnnl_query_diff_dst_md,
689 /// workspace memory desc
690 workspace_md = dnnl_query_workspace_md,
691 /// scratchpad memory desc
692 scratchpad_md = dnnl_query_scratchpad_md,
693 /// memory desc of an execute argument
694 exec_arg_md = dnnl_query_exec_arg_md,
695
696 /// number of dimensions
697 ndims_s32 = dnnl_query_ndims_s32,
698 /// vector of dimensions
699 dims = dnnl_query_dims,
700 /// data type
701 data_type = dnnl_query_data_type,
702 /// submemory offset
703 submemory_offset_s64 = dnnl_query_submemory_offset_s64,
704 /// vector of padded dimensions
705 padded_dims = dnnl_query_padded_dims,
706 /// vector of padded offsets
707 padded_offsets = dnnl_query_padded_offsets,
708 /// format kind
709 format_kind = dnnl_query_format_kind,
710 /// number of innermost blocks
711 inner_nblks_s32 = dnnl_query_inner_nblks_s32,
712 /// vector of sizes of the innermost blocks
713 inner_blks = dnnl_query_inner_blks,
714 /// vector of logical indices of the blocks
715 inner_idxs = dnnl_query_inner_idxs,
716};
717
718/// Converts query enum value from C++ API to C API type.
719/// @param aquery C++ API query enum value.
720/// @returns Corresponding C API query enum value.
721inline dnnl_query_t convert_to_c(query aquery) {
722 return static_cast<dnnl_query_t>(aquery);
723}
724
725/// @} dnnl_api_primitives_common
726
727/// @} dnnl_api_primitives
728
729/// @addtogroup dnnl_api_memory Memory
730///
731/// A container that describes and stores data. Memory objects can contain
732/// data of various types and formats. There are two levels of abstraction:
733///
734/// 1. **Memory descriptor** -- engine-agnostic logical description of data
735/// (number of dimensions, dimension sizes, and data type), and,
736/// optionally, the information about the physical format of data in
737/// memory. If this information is not known yet, a memory descriptor can
738/// be created with #dnnl::memory::format_tag::any. This allows
739/// compute-intensive primitives to choose the best format for
740/// computation. The user is responsible for reordering the data into the
741/// chosen format when formats do not match.
742///
743/// A memory descriptor can be initialized either by specifying dimensions
744/// and a memory format tag or strides for each of them, or by
745/// manipulating the dnnl_memory_desc_t structure directly.
746///
747/// @warning
748/// The latter approach requires understanding how the physical data
749/// representation is mapped to the structure and is discouraged. This
750/// topic is discussed in @ref dev_guide_understanding_memory_formats.
751///
752/// The user can query the amount of memory required by a memory
753/// descriptor using the #dnnl::memory::desc::get_size() function. The
754/// size of data in general cannot be computed as the product of
755/// dimensions multiplied by the size of the data type. So users are
756/// required to use this function for better code portability.
757///
758/// Two memory descriptors can be compared using the equality and
759/// inequality operators. The comparison is especially useful when
760/// checking whether it is necessary to reorder data from the user's data
761/// format to a primitive's format.
762///
763/// 2. **Memory object** -- an engine-specific object that handles the memory
764/// buffer and its description (a memory descriptor). For the CPU engine or
765/// with USM, the memory buffer handle is simply a pointer to @c void. The
766/// memory buffer can be queried using #dnnl::memory::get_data_handle() and
767/// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer,
768/// when used, can be queried using #dnnl::sycl_interop::get_buffer and set
769/// using #dnnl::sycl_interop::set_buffer. A memory object can also be
770/// queried for the underlying memory descriptor and for its engine using
771/// #dnnl::memory::get_desc() and dnnl::memory::get_engine().
772///
773/// Along with ordinary memory descriptors with all dimensions being positive,
774/// the library supports *zero-volume* memory descriptors with one or more
775/// dimensions set to zero. This is used to support the NumPy\* convention.
776/// If a zero-volume memory is passed to a primitive, the primitive typically
777/// does not perform any computations with this memory. For example:
778///
/// - A concatenation primitive would ignore all memory objects with zeroes in
///   the concat dimension / axis.
781///
782/// - A forward convolution with a source memory object with zero in the
783/// minibatch dimension would always produce a destination memory object
784/// with a zero in the minibatch dimension and perform no computations.
785///
786/// - However, a forward convolution with a zero in one of the weights
787/// dimensions is ill-defined and is considered to be an error by the
788/// library because there is no clear definition of what the output values
789/// should be.
790///
791/// Memory buffer of a zero-volume memory is never accessed.
792///
793/// @{
794
795/// Memory object.
796///
797/// A memory object encapsulates a handle to a memory buffer allocated on a
798/// specific engine, tensor dimensions, data type, and memory format, which is
799/// the way tensor indices map to offsets in linear memory space. Memory
800/// objects are passed to primitives during execution.
801struct memory : public handle<dnnl_memory_t> {
802 using handle::handle;
803
804 /// Integer type for representing dimension sizes and indices.
805 typedef dnnl_dim_t dim;
806 /// Vector of dimensions. Implementations are free to force a limit on the
807 /// vector's length.
808 typedef std::vector<dim> dims;
809
810 /// Helper function that validates that an `std::vector` of dimensions can
811 /// be safely converted to the C API array ::dnnl_dims_t. Throws if
812 /// validation fails.
813 ///
814 /// @param v Vector of dimensions.
815 /// @param min_size Minimum expected size of the vector.
816 template <typename T>
817 static void validate_dims(const std::vector<T> &v, int min_size = 0) {
818 validate_container_size(
819 v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
820 }
821
822 /// Data type specification.
823 enum class data_type {
824 /// Undefined data type (used for empty memory descriptors).
825 undef = dnnl_data_type_undef,
826 /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
827 f16 = dnnl_f16,
828 /// non-standard
829 /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format).
830 bf16 = dnnl_bf16,
831 /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format).
832 f32 = dnnl_f32,
833 //// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format).
834 f64 = dnnl_f64,
835 /// 32-bit signed integer.
836 s32 = dnnl_s32,
837 /// 8-bit signed integer.
838 s8 = dnnl_s8,
839 /// 8-bit unsigned integer.
840 u8 = dnnl_u8,
841 };
842
843 /// Returns size of data type in bytes.
844 /// @returns The number of bytes occupied by data type.
845 static size_t data_type_size(data_type adata_type) {
846 return dnnl_data_type_size(convert_to_c(adata_type));
847 }
848
849 /// Memory format kind
850 enum class format_kind {
851 /// Undefined memory format kind, used for empty memory descriptors.
852 undef = dnnl_format_kind_undef,
853 /// A special format kind that indicates that the actual format will be
854 /// selected by a primitive automatically.
855 any = dnnl_format_kind_any,
856 /// A tensor in a generic format described by the stride and blocking
857 /// values in each dimension.
858 blocked = dnnl_blocked,
859 /// A special format kind that indicates that tensor format is opaque.
860 opaque = dnnl_format_kind_opaque,
861 };
862
863 /// Memory format tag specification.
864 ///
865 /// Memory format tags can be further divided into two categories:
866 ///
867 /// - Domain-agnostic names, i.e. names that do not depend on the tensor
868 /// usage in the specific primitive. These names use letters from `a`
869 /// to `l` to denote logical dimensions and form the order in which the
870 /// dimensions are laid in memory. For example,
871 /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
872 /// second logical dimension (denoted as `b`) is the innermost, i.e.
873 /// has stride = 1, and the first logical dimension (`a`) is laid out in
874 /// memory with stride equal to the size of the second dimension. On the
875 /// other hand, #dnnl::memory::format_tag::ba is the transposed version
876 /// of the same tensor: the outermost dimension (`a`) becomes the
877 /// innermost one.
878 ///
879 /// - Domain-specific names, i.e. names that make sense only in the
880 /// context of a certain domain, such as CNN. These names are
881 /// aliases to the corresponding domain-agnostic tags and used mostly
882 /// for convenience. For example, #dnnl::memory::format_tag::nc
883 /// is used to denote 2D CNN activations tensor memory format, where
884 /// the channels dimension is the innermost one and the batch dimension
885 /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
886 /// an alias for #dnnl::memory::format_tag::ab, because for
887 /// CNN primitives the logical dimensions of activations tensors come
888 /// in order: batch, channels, spatial. In other words, batch
889 /// corresponds to the first logical dimension (`a`), and channels
890 /// correspond to the second one (`b`).
891 ///
892 /// The following domain-specific notation applies to memory format tags:
893 /// - @c 'n' denotes the mini-batch dimension
894 /// - @c 'c' denotes a channels dimension
895 /// - When there are multiple channel dimensions (for example,
896 /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
897 /// of input and output channels
898 /// - @c 'g' denotes a groups dimension for convolution weights
899 /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
900 /// respectively
901 ///
902 /// See @ref dnnl_format_tag_t for a detailed description.
903 enum class format_tag {
904 /// Undefined memory format tag
905 undef = dnnl_format_tag_undef,
906 /// Placeholder memory format tag. Used to instruct the primitive to
907 /// select a format automatically.
908 any = dnnl_format_tag_any,
909
910 /// plain 1D tensor
911 a = dnnl_a,
912
913 /// plain 2D tensor
914 ab = dnnl_ab,
915 /// permuted 2D tensor
916 ba = dnnl_ba,
917
918 /// plain 3D tensor
919 abc = dnnl_abc,
920 /// permuted 3D tensor
921 acb = dnnl_acb,
922 /// permuted 3D tensor
923 bac = dnnl_bac,
924 /// permuted 3D tensor
925 bca = dnnl_bca,
926 /// permuted 3D tensor
927 cba = dnnl_cba,
928
929 /// plain 4D tensor
930 abcd = dnnl_abcd,
931 /// permuted 4D tensor
932 abdc = dnnl_abdc,
933 /// permuted 4D tensor
934 acbd = dnnl_acbd,
935 /// permuted 4D tensor
936 acdb = dnnl_acdb,
937 /// permuted 4D tensor
938 adbc = dnnl_adbc,
939 /// permuted 4D tensor
940 bacd = dnnl_bacd,
941 /// permuted 4D tensor
942 bcda = dnnl_bcda,
943 /// permuted 4D tensor
944 cdba = dnnl_cdba,
945 /// permuted 4D tensor
946 dcab = dnnl_dcab,
947
948 /// plain 5D tensor
949 abcde = dnnl_abcde,
950 /// permuted 5D tensor
951 abdec = dnnl_abdec,
952 /// permuted 5D tensor
953 acbde = dnnl_acbde,
954 /// permuted 5D tensor
955 acdeb = dnnl_acdeb,
956 /// permuted 5D tensor
957 bacde = dnnl_bacde,
958 /// permuted 5D tensor
959 bcdea = dnnl_bcdea,
960 /// permuted 5D tensor
961 cdeba = dnnl_cdeba,
962 /// permuted 5D tensor
963 decab = dnnl_decab,
964 /// permuted 5D tensor
965 abced = dnnl_abced,
966
967 /// plain 6D tensor
968 abcdef = dnnl_abcdef,
969 /// permuted 6D tensor
970 abdfce = dnnl_abdfce,
971 /// permuted 6D tensor
972 acbdef = dnnl_acbdef,
973 /// permuted 6D tensor
974 abdefc = dnnl_abdefc,
975 /// permuted 6D tensor
976 defcab = dnnl_defcab,
977 /// permuted 6D tensor
978 abcdfe = dnnl_abcdfe,
979
980 /// plain 7D tensor
981 abcdefg = dnnl_abcdefg,
982 /// permuted 7D tensor
983 abcdegf = dnnl_abcdegf,
984
985 /// plain 8D tensor
986 abcdefgh = dnnl_abcdefgh,
987 /// permuted 8D tensor
988 abcdefhg = dnnl_abcdefhg,
989
990 /// plain 9D tensor
991 abcdefghi = dnnl_abcdefghi,
992 /// permuted 9D tensor
993 abcdefgih = dnnl_abcdefgih,
994
995 /// plain 10D tensor
996 abcdefghij = dnnl_abcdefghij,
997 /// permuted 10D tensor
998 abcdefghji = dnnl_abcdefghji,
999
1000 /// plain 11D tensor
1001 abcdefghijk = dnnl_abcdefghijk,
1002 /// permuted 11D tensor
1003 abcdefghikj = dnnl_abcdefghikj,
1004
1005 /// plain 12D tensor
1006 abcdefghijkl = dnnl_abcdefghijkl,
1007 /// permuted 12D tensor
1008 abcdefghijlk = dnnl_abcdefghijlk,
1009
1010 /// 1D tensor; an alias for #dnnl::memory::format_tag::a
1011 x = a,
1012 /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
1013 nc = ab,
1014 /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
1015 cn = ba,
1016 /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
1017 tn = ab,
1018 /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
1019 nt = ba,
1020 /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
1021 ncw = abc,
1022 /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
1023 nwc = acb,
1024 /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
1025 nchw = abcd,
1026 /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
1027 nhwc = acdb,
1028 /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
1029 chwn = bcda,
1030 /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
1031 ncdhw = abcde,
1032 /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
1033 ndhwc = acdeb,
1034
1035 /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
1036 oi = ab,
1037 /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
1038 io = ba,
1039 /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
1040 oiw = abc,
1041 /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
1042 owi = acb,
1043 /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
1044 wio = cba,
1045 /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
1046 iwo = bca,
1047 /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
1048 oihw = abcd,
1049 /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
1050 hwio = cdba,
1051 /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
1052 ohwi = acdb,
1053 /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
1054 ihwo = bcda,
1055 /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
1056 iohw = bacd,
1057 /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
1058 oidhw = abcde,
1059 /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
1060 dhwio = cdeba,
1061 /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
1062 odhwi = acdeb,
1063 /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde
1064 iodhw = bacde,
1065 /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
1066 idhwo = bcdea,
1067
1068 /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
1069 goiw = abcd,
1070 /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc
1071 gowi = abdc,
1072 /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
1073 wigo = dcab,
1074 /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec
1075 gohwi = abdec,
1076 /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
1077 goihw = abcde,
1078 /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
1079 hwigo = decab,
1080 /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
1081 giohw = acbde,
1082 /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
1083 goidhw = abcdef,
1084 /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
1085 giodhw = acbdef,
1086 /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc
1087 godhwi = abdefc,
1088 /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
1089 dhwigo = defcab,
1090
1091 /// 3D RNN data tensor in the format (seq_length, batch, input
1092 /// channels); an alias for #dnnl::memory::format_tag::abc.
1093 tnc = abc,
1094 /// 3D RNN data tensor in the format (batch, seq_length, input
1095 /// channels); an alias for #dnnl::memory::format_tag::bac.
1096 ntc = bac,
1097 /// 4D RNN states tensor in the format (num_layers, num_directions,
1098 /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd.
1099 ldnc = abcd,
1100 /// 5D RNN weights tensor in the format (num_layers, num_directions,
1101 /// input_channels, num_gates, output_channels);
1102 /// an alias for #dnnl::memory::format_tag::abcde.
1103 ///
1104 /// - For LSTM cells, the gates order is input, forget, candidate
1105 /// and output gate.
1106 /// - For GRU cells, the gates order is update, reset and output gate.
1107 ldigo = abcde,
1108 /// 5D RNN weights tensor in the format (num_layers, num_directions,
1109 /// num_gates, output_channels, input_channels);
1110 /// an alias for #dnnl::memory::format_tag::abdec.
1111 ///
1112 /// - For LSTM cells, the gates order is input, forget, candidate
1113 /// and output gate.
1114 /// - For GRU cells, the gates order is update, reset and output gate.
1115 ldgoi = abdec,
1116 /// 4D LSTM projection tensor in the format (num_layers, num_directions,
1117 /// num_channels_in_hidden_state, num_channels_in_recurrent_projection);
1118 /// an alias for #dnnl::memory::format_tag::abcd.
1119 ldio = abcd,
1120 /// 4D LSTM projection tensor in the format (num_layers, num_directions,
1121 /// num_channels_in_recurrent_projection, num_channels_in_hidden_state);
1122 /// an alias for #dnnl::memory::format_tag::abdc.
1123 ldoi = abdc,
1124 /// 4D RNN bias tensor in the format (num_layers, num_directions,
1125 /// num_gates, output_channels);
1126 /// an alias for #dnnl::memory::format_tag::abcd.
1127 ///
1128 /// - For LSTM cells, the gates order is input, forget, candidate
1129 /// and output gate.
1130 /// - For GRU cells, the gates order is update, reset and output gate.
1131 ldgo = abcd,
1132
1133 // Opaque blocked formats
1134
1135 AB16b16a = dnnl_AB16b16a,
1136 AB16b32a = dnnl_AB16b32a,
1137 AB16b64a = dnnl_AB16b64a,
1138 AB8b16a2b = dnnl_AB8b16a2b,
1139 AB8b32a2b = dnnl_AB8b32a2b,
1140 AB8b64a2b = dnnl_AB8b64a2b,
1141 AB4b16a4b = dnnl_AB4b16a4b,
1142 AB4b32a4b = dnnl_AB4b32a4b,
1143 AB4b64a4b = dnnl_AB4b64a4b,
1144 AB16b16a4b = dnnl_AB16b16a4b,
1145 AB16b32a4b = dnnl_AB16b32a4b,
1146 AB16b48a4b = dnnl_AB16b48a4b,
1147 AB16b64a4b = dnnl_AB16b64a4b,
1148 AB16b16a2b = dnnl_AB16b16a2b,
1149 AB16b32a2b = dnnl_AB16b32a2b,
1150 AB16b48a2b = dnnl_AB16b48a2b,
1151 AB16b64a2b = dnnl_AB16b64a2b,
1152 Abc16a = dnnl_Abc16a,
1153 ABc16a16b = dnnl_ABc16a16b,
1154 ABc4a4b = dnnl_ABc4a4b,
1155 aBc16b = dnnl_aBc16b,
1156 aBc32b = dnnl_aBc32b,
1157 ABc16b16a = dnnl_ABc16b16a,
1158 ABc16b32a = dnnl_ABc16b32a,
1159 ABc16b64a = dnnl_ABc16b64a,
1160 Abc4a = dnnl_Abc4a,
1161 aBc4b = dnnl_aBc4b,
1162 ABc4b16a4b = dnnl_ABc4b16a4b,
1163 ABc4b32a4b = dnnl_ABc4b32a4b,
1164 ABc4b64a4b = dnnl_ABc4b64a4b,
1165 ABc2b8a4b = dnnl_ABc2b8a4b,
1166 ABc16a16b2a = dnnl_ABc16a16b2a,
1167 ABc16b16a4b = dnnl_ABc16b16a4b,
1168 ABc16b32a4b = dnnl_ABc16b32a4b,
1169 ABc16b48a4b = dnnl_ABc16b48a4b,
1170 ABc16b64a4b = dnnl_ABc16b64a4b,
1171 ABc16b16a2b = dnnl_ABc16b16a2b,
1172 ABc16b32a2b = dnnl_ABc16b32a2b,
1173 ABc16b48a2b = dnnl_ABc16b48a2b,
1174 ABc16b64a2b = dnnl_ABc16b64a2b,
1175 ABc4b4a = dnnl_ABc4b4a,
1176 ABc8a16b2a = dnnl_ABc8a16b2a,
1177 ABc8a8b = dnnl_ABc8a8b,
1178 ABc8a4b = dnnl_ABc8a4b,
1179 aBc8b = dnnl_aBc8b,
1180 ABc8b16a2b = dnnl_ABc8b16a2b,
1181 ABc8b32a2b = dnnl_ABc8b32a2b,
1182 ABc8b64a2b = dnnl_ABc8b64a2b,
1183 ABc8b8a = dnnl_ABc8b8a,
1184 Abcd8a = dnnl_Abcd8a,
1185 Abcd16a = dnnl_Abcd16a,
1186 Abcd32a = dnnl_Abcd32a,
1187 ABcd16a16b = dnnl_ABcd16a16b,
1188 aBcd16b = dnnl_aBcd16b,
1189 aBcd32b = dnnl_aBcd32b,
1190 ABcd16b16a = dnnl_ABcd16b16a,
1191 ABcd16b32a = dnnl_ABcd16b32a,
1192 ABcd16b64a = dnnl_ABcd16b64a,
1193 aBCd16b16c = dnnl_aBCd16b16c,
1194 aBCd16c16b = dnnl_aBCd16c16b,
1195 Abcd4a = dnnl_Abcd4a,
1196 aBcd4b = dnnl_aBcd4b,
1197 ABcd4b16a4b = dnnl_ABcd4b16a4b,
1198 ABcd4b32a4b = dnnl_ABcd4b32a4b,
1199 ABcd4b64a4b = dnnl_ABcd4b64a4b,
1200 ABcd2b8a4b = dnnl_ABcd2b8a4b,
1201 ABcd4b4a = dnnl_ABcd4b4a,
1202 ABcd4a4b = dnnl_ABcd4a4b,
1203 aBCd4c16b4c = dnnl_aBCd4c16b4c,
1204 aBCd2c8b4c = dnnl_aBCd2c8b4c,
1205 ABcd16a16b2a = dnnl_ABcd16a16b2a,
1206 ABcd16b16a4b = dnnl_ABcd16b16a4b,
1207 ABcd16b32a4b = dnnl_ABcd16b32a4b,
1208 ABcd16b48a4b = dnnl_ABcd16b48a4b,
1209 ABcd16b64a4b = dnnl_ABcd16b64a4b,
1210 ABcd16b16a2b = dnnl_ABcd16b16a2b,
1211 ABcd16b32a2b = dnnl_ABcd16b32a2b,
1212 ABcd16b48a2b = dnnl_ABcd16b48a2b,
1213 ABcd16b64a2b = dnnl_ABcd16b64a2b,
1214 aBCd16b16c2b = dnnl_aBCd16b16c2b,
1215 aBCd16c16b4c = dnnl_aBCd16c16b4c,
1216 aBCd16c16b2c = dnnl_aBCd16c16b2c,
1217 aBCd4c4b = dnnl_aBCd4c4b,
1218 aBCd4b4c = dnnl_aBCd4b4c,
1219 ABcd8a16b2a = dnnl_ABcd8a16b2a,
1220 ABcd8a8b = dnnl_ABcd8a8b,
1221 ABcd8a4b = dnnl_ABcd8a4b,
1222 ABcd8a2b = dnnl_ABcd8a2b,
1223 /// 4D tensor blocked by 2nd dimension with block size 8
1224 aBcd8b = dnnl_aBcd8b,
1225 ABcd8b16a2b = dnnl_ABcd8b16a2b,
1226 ABcd8b32a2b = dnnl_ABcd8b32a2b,
1227 ABcd8b64a2b = dnnl_ABcd8b64a2b,
1228 aBCd8b16c2b = dnnl_aBCd8b16c2b,
1229 /// 4D tensor blocked by 1st and 2nd dimension with block size 8
1230 ABcd8b8a = dnnl_ABcd8b8a,
1231 aBCd8b8c = dnnl_aBCd8b8c,
1232 aBCd8b4c = dnnl_aBCd8b4c,
1233 aBCd8c16b2c = dnnl_aBCd8c16b2c,
1234 aBCd8c8b = dnnl_aBCd8c8b,
1235 Abcde16a = dnnl_Abcde16a,
1236 Abcde32a = dnnl_Abcde32a,
1237 ABcde16a16b = dnnl_ABcde16a16b,
1238 aBcde16b = dnnl_aBcde16b,
1239 aBcde32b = dnnl_aBcde32b,
1240 ABcde16b16a = dnnl_ABcde16b16a,
1241 ABcde16b32a = dnnl_ABcde16b32a,
1242 ABcde16b64a = dnnl_ABcde16b64a,
1243 aBCde16b16c = dnnl_aBCde16b16c,
1244 aBCde16c16b = dnnl_aBCde16c16b,
1245 aBCde2c8b4c = dnnl_aBCde2c8b4c,
1246 Abcde4a = dnnl_Abcde4a,
1247 aBcde4b = dnnl_aBcde4b,
1248 ABcde4b4a = dnnl_ABcde4b4a,
1249 ABcde4a4b = dnnl_ABcde4a4b,
1250 aBCde4b4c = dnnl_aBCde4b4c,
1251 aBCde4c16b4c = dnnl_aBCde4c16b4c,
1252 aBCde16b16c2b = dnnl_aBCde16b16c2b,
1253 aBCde16c16b4c = dnnl_aBCde16c16b4c,
1254 aBCde16c16b2c = dnnl_aBCde16c16b2c,
1255 aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
1256 aBCde4c4b = dnnl_aBCde4c4b,
1257 Abcde8a = dnnl_Abcde8a,
1258 ABcde8a8b = dnnl_ABcde8a8b,
1259 ABcde8a4b = dnnl_ABcde8a4b,
1260 aBcde8b = dnnl_aBcde8b,
1261 ABcde8b16a2b = dnnl_ABcde8b16a2b,
1262 ABcde8b32a2b = dnnl_ABcde8b32a2b,
1263 ABcde8b64a2b = dnnl_ABcde8b64a2b,
1264 ABcde4b16a4b = dnnl_ABcde4b16a4b,
1265 ABcde4b32a4b = dnnl_ABcde4b32a4b,
1266 ABcde4b64a4b = dnnl_ABcde4b64a4b,
1267 ABcde16b16a4b = dnnl_ABcde16b16a4b,
1268 ABcde16b32a4b = dnnl_ABcde16b32a4b,
1269 ABcde16b48a4b = dnnl_ABcde16b48a4b,
1270 ABcde16b64a4b = dnnl_ABcde16b64a4b,
1271 ABcde16b16a2b = dnnl_ABcde16b16a2b,
1272 ABcde16b32a2b = dnnl_ABcde16b32a2b,
1273 ABcde16b48a2b = dnnl_ABcde16b48a2b,
1274 ABcde16b64a2b = dnnl_ABcde16b64a2b,
1275 ABcde2b8a4b = dnnl_ABcde2b8a4b,
1276 aBCde8b16c2b = dnnl_aBCde8b16c2b,
1277 ABcde8b8a = dnnl_ABcde8b8a,
1278 aBCde8b8c = dnnl_aBCde8b8c,
1279 aBCde8b4c = dnnl_aBCde8b4c,
1280 ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1281 ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1282 aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1283 aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1284 aBCde8c16b2c = dnnl_aBCde8c16b2c,
1285 aBCde8c8b = dnnl_aBCde8c8b,
1286 aBcdef16b = dnnl_aBcdef16b,
1287 aBCdef16b16c = dnnl_aBCdef16b16c,
1288 aBCdef16c16b = dnnl_aBCdef16c16b,
1289 aBcdef4b = dnnl_aBcdef4b,
1290 aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
1291 aBCdef4c4b = dnnl_aBCdef4c4b,
1292 aBCdef4b4c = dnnl_aBCdef4b4c,
1293 aBCdef8b8c = dnnl_aBCdef8b8c,
1294 aBCdef8b4c = dnnl_aBCdef8b4c,
1295 aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1296 aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1297 aBCdef8c8b = dnnl_aBCdef8c8b,
1298 aBdc16b = dnnl_aBdc16b,
1299 aBdc4b = dnnl_aBdc4b,
1300 aBdc8b = dnnl_aBdc8b,
1301 aBdec16b = dnnl_aBdec16b,
1302 aBdec4b = dnnl_aBdec4b,
1303 aBdec8b = dnnl_aBdec8b,
1304 aBdefc16b = dnnl_aBdefc16b,
1305 aCBdef16c16b = dnnl_aCBdef16c16b,
1306 aCBdef16b16c = dnnl_aCBdef16b16c,
1307 aBdefc4b = dnnl_aBdefc4b,
1308 aBdefc8b = dnnl_aBdefc8b,
1309 Acb16a = dnnl_Acb16a,
1310 Acb4a = dnnl_Acb4a,
1311 Acb8a = dnnl_Acb8a,
1312 aCBd16b16c = dnnl_aCBd16b16c,
1313 aCBd16c16b = dnnl_aCBd16c16b,
1314 aCBde16b16c = dnnl_aCBde16b16c,
1315 aCBde16c16b = dnnl_aCBde16c16b,
1316 Acdb16a = dnnl_Acdb16a,
1317 Acdb4a = dnnl_Acdb4a,
1318 Acdb8a = dnnl_Acdb8a,
1319 Acdeb16a = dnnl_Acdeb16a,
1320 Acdeb4a = dnnl_Acdeb4a,
1321 Acdeb8a = dnnl_Acdeb8a,
1322 BAc16a16b = dnnl_BAc16a16b,
1323 BAc16b16a = dnnl_BAc16b16a,
1324 BAcd16a16b = dnnl_BAcd16a16b,
1325 BAcd16b16a = dnnl_BAcd16b16a,
1326 ABcd32a32b = dnnl_ABcd32a32b,
1327 BAcde16b16a = dnnl_BAcde16b16a,
1328 BAcde16a16b = dnnl_BAcde16a16b,
1329 aBdec32b = dnnl_aBdec32b,
1330 Abcdef16a = dnnl_Abcdef16a,
1331 Abcdef32a = dnnl_Abcdef32a,
1332 Acdb32a = dnnl_Acdb32a,
1333 aBCd2b4c2b = dnnl_aBCd2b4c2b,
1334 aBCde2b4c2b = dnnl_aBCde2b4c2b,
1335 aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
1336 aBCd2c4b2c = dnnl_aBCd2c4b2c,
1337 aBCde2c4b2c = dnnl_aBCde2c4b2c,
1338 aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1339 aBCd4b8c2b = dnnl_aBCd4b8c2b,
1340 aBCde4b8c2b = dnnl_aBCde4b8c2b,
1341 aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1342 aBCd4c8b2c = dnnl_aBCd4c8b2c,
1343 aBCde4c8b2c = dnnl_aBCde4c8b2c,
1344 aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1345 AB32a32b8a4b = dnnl_AB32a32b8a4b,
1346 AB32a32b8a2b = dnnl_AB32a32b8a2b,
1347 AB8a4b = dnnl_AB8a4b,
1348 AB8a2b = dnnl_AB8a2b,
1349 abDc32d = dnnl_abDc32d,
1350 abDC32d4c = dnnl_abDC32d4c,
1351 abCd32c = dnnl_abCd32c,
1352 abdEc32e = dnnl_abdEc32e,
1353 abdEC32e2c = dnnl_abdEC32e2c,
1354 abdEC32e4c = dnnl_abdEC32e4c,
1355 abdCe32c = dnnl_abdCe32c,
1356 abdCE32c2e = dnnl_abdCE32c2e,
1357 aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
1358 aBdC16b4c = dnnl_aBdC16b4c,
1359 aBdeC16b4c = dnnl_aBdeC16b4c,
1360 AcB16a4b = dnnl_AcB16a4b,
1361 AcdB16a2b = dnnl_AcdB16a2b,
1362 aBdefC16b4c = dnnl_aBdefC16b4c,
1363 AcdeB16a4b = dnnl_AcdeB16a4b,
1364
1365 Acb32a = dnnl_Acb32a,
1366 AcB32a2b = dnnl_AcB32a2b,
1367 AcB32a4b = dnnl_AcB32a4b,
1368 Acb48a = dnnl_Acb48a,
1369 AcB48a2b = dnnl_AcB48a2b,
1370 AcB48a4b = dnnl_AcB48a4b,
1371 Acb64a = dnnl_Acb64a,
1372 AcB64a2b = dnnl_AcB64a2b,
1373 AcB64a4b = dnnl_AcB64a4b,
1374 cBa2b = dnnl_cBa2b,
1375 cBa4b = dnnl_cBa4b,
1376 aBdc32b = dnnl_aBdc32b,
1377 aBdC32b2c = dnnl_aBdC32b2c,
1378 aBdC32b4c = dnnl_aBdC32b4c,
1379 aBdc48b = dnnl_aBdc48b,
1380 aBdC48b2c = dnnl_aBdC48b2c,
1381 aBdC48b4c = dnnl_aBdC48b4c,
1382 aBdc64b = dnnl_aBdc64b,
1383 aBdC64b2c = dnnl_aBdC64b2c,
1384 aBdC64b4c = dnnl_aBdC64b4c,
1385 adcb = dnnl_adcb,
1386 adCb2c = dnnl_adCb2c,
1387 adCb4c = dnnl_adCb4c,
1388 AcdB32a2b = dnnl_AcdB32a2b,
1389 AcdB32a4b = dnnl_AcdB32a4b,
1390 Acdb48a = dnnl_Acdb48a,
1391 AcdB48a2b = dnnl_AcdB48a2b,
1392 AcdB48a4b = dnnl_AcdB48a4b,
1393 Acdb64a = dnnl_Acdb64a,
1394 AcdB64a2b = dnnl_AcdB64a2b,
1395 AcdB64a4b = dnnl_AcdB64a4b,
1396 cdBa2b = dnnl_cdBa2b,
1397 cdBa4b = dnnl_cdBa4b,
1398 aBdeC32b2c = dnnl_aBdeC32b2c,
1399 aBdeC32b4c = dnnl_aBdeC32b4c,
1400 aBdec48b = dnnl_aBdec48b,
1401 aBdeC48b2c = dnnl_aBdeC48b2c,
1402 aBdeC48b4c = dnnl_aBdeC48b4c,
1403 aBdec64b = dnnl_aBdec64b,
1404 aBdeC64b2c = dnnl_aBdeC64b2c,
1405 aBdeC64b4c = dnnl_aBdeC64b4c,
1406 adecb = dnnl_adecb,
1407 adeCb2c = dnnl_adeCb2c,
1408 adeCb4c = dnnl_adeCb4c,
1409 Acdeb32a = dnnl_Acdeb32a,
1410 AcdeB32a2b = dnnl_AcdeB32a2b,
1411 AcdeB32a4b = dnnl_AcdeB32a4b,
1412 Acdeb48a = dnnl_Acdeb48a,
1413 AcdeB48a2b = dnnl_AcdeB48a2b,
1414 AcdeB48a4b = dnnl_AcdeB48a4b,
1415 Acdeb64a = dnnl_Acdeb64a,
1416 AcdeB64a2b = dnnl_AcdeB64a2b,
1417 AcdeB64a4b = dnnl_AcdeB64a4b,
1418 cdeBa2b = dnnl_cdeBa2b,
1419 cdeBa4b = dnnl_cdeBa4b,
1420 aBdefc32b = dnnl_aBdefc32b,
1421 aBdefC32b2c = dnnl_aBdefC32b2c,
1422 aBdefC32b4c = dnnl_aBdefC32b4c,
1423 aBdefc48b = dnnl_aBdefc48b,
1424 aBdefC48b2c = dnnl_aBdefC48b2c,
1425 aBdefC48b4c = dnnl_aBdefC48b4c,
1426 aBdefc64b = dnnl_aBdefc64b,
1427 aBdefC64b2c = dnnl_aBdefC64b2c,
1428 aBdefC64b4c = dnnl_aBdefC64b4c,
1429 adefcb = dnnl_adefcb,
1430 adefCb2c = dnnl_adefCb2c,
1431 adefCb4c = dnnl_adefCb4c,
1432 ABc32a32b = dnnl_ABc32a32b,
1433 BAc8a16b2a = dnnl_BAc8a16b2a,
1434 BAcd8a16b2a = dnnl_BAcd8a16b2a,
1435 ABcde8a16b2a = dnnl_ABcde8a16b2a,
1436 aCBd8b16c2b = dnnl_aCBd8b16c2b,
1437 BAcde8a16b2a = dnnl_BAcde8a16b2a,
1438 aCBde8b16c2b = dnnl_aCBde8b16c2b,
1439 ABcde32a32b = dnnl_ABcde32a32b,
1440 ABc4a8b8a4b = dnnl_ABc4a8b8a4b,
1441 ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b,
1442 BAc4b8a8b4a = dnnl_BAc4b8a8b4a,
1443 BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a,
1444 BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a,
1445 aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c,
1446 aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c,
1447 aBCdef8b16c2b = dnnl_aBCdef8b16c2b,
1448 aCBdef8b16c2b = dnnl_aCBdef8b16c2b,
1449 aBdC16b2c = dnnl_aBdC16b2c,
1450 aBdeC16b2c = dnnl_aBdeC16b2c,
1451 aBdefC16b2c = dnnl_aBdefC16b2c,
1452 aBedc16b = dnnl_aBedc16b,
1453 AcB16a2b = dnnl_AcB16a2b,
1454 AcdB16a4b = dnnl_AcdB16a4b,
1455 AcdeB16a2b = dnnl_AcdeB16a2b,
1456 Adcb16a = dnnl_Adcb16a,
1457 aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b,
1458 aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b,
1459 aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b,
1460 ABc32a16b = dnnl_ABc32a16b,
1461 ABcd16a32b = dnnl_ABcd16a32b,
1462 ABcd32a16b = dnnl_ABcd32a16b,
1463 ABcde32a16b = dnnl_ABcde32a16b,
1464 AB48a16b = dnnl_AB48a16b,
1465 AB48a32b = dnnl_AB48a32b,
1466 ABc40a16b = dnnl_ABc40a16b,
1467 ABc40a32b = dnnl_ABc40a32b,
1468 aBC48b16c = dnnl_aBC48b16c,
1469 aBC48b32c = dnnl_aBC48b32c,
1470 ABcd40a16b = dnnl_ABcd40a16b,
1471 ABcd40a32b = dnnl_ABcd40a32b,
1472 BA16a16b = dnnl_BA16a16b,
1473 BA16a32b = dnnl_BA16a32b,
1474 BA16a48b = dnnl_BA16a48b,
1475 BA16a64b = dnnl_BA16a64b,
1476 BA16a16b2a = dnnl_BA16a16b2a,
1477 BA16a32b2a = dnnl_BA16a32b2a,
1478 BA16a48b2a = dnnl_BA16a48b2a,
1479 BA16a64b2a = dnnl_BA16a64b2a,
1480 BA16a16b4a = dnnl_BA16a16b4a,
1481 BA16a32b4a = dnnl_BA16a32b4a,
1482 BA16a48b4a = dnnl_BA16a48b4a,
1483 BA16a64b4a = dnnl_BA16a64b4a,
1484 decbA16a = dnnl_decbA16a,
1485 decbA8a = dnnl_decbA8a,
1486 aCB16b16c = dnnl_aCB16b16c,
1487 aCB16b32c = dnnl_aCB16b32c,
1488 aCB16b48c = dnnl_aCB16b48c,
1489 aCB16b64c = dnnl_aCB16b64c,
1490 aCB16b16c2b = dnnl_aCB16b16c2b,
1491 aCB16b32c2b = dnnl_aCB16b32c2b,
1492 aCB16b48c2b = dnnl_aCB16b48c2b,
1493 aCB16b64c2b = dnnl_aCB16b64c2b,
1494 aCB16b16c4b = dnnl_aCB16b16c4b,
1495 aCB16b32c4b = dnnl_aCB16b32c4b,
1496 aCB16b48c4b = dnnl_aCB16b48c4b,
1497 aCB16b64c4b = dnnl_aCB16b64c4b,
1498
1499 format_tag_last = dnnl_format_tag_last,
1500
1501 nCdhw16c = dnnl_nCdhw16c,
1502 nCdhw4c = dnnl_nCdhw4c,
1503 nCdhw8c = dnnl_nCdhw8c,
1504 nChw16c = dnnl_nChw16c,
1505 nChw4c = dnnl_nChw4c,
1506 nChw8c = dnnl_nChw8c,
1507 nCw16c = dnnl_nCw16c,
1508 nCw4c = dnnl_nCw4c,
1509 nCw8c = dnnl_nCw8c,
1510 NCw16n16c = dnnl_NCw16n16c,
1511 NChw16n16c = dnnl_NChw16n16c,
1512 NCdhw16n16c = dnnl_NCdhw16n16c,
1513 NCdhw32n32c = dnnl_NCdhw32n32c,
1514 NChw32n32c = dnnl_NChw32n32c,
1515 IOhw16i16o = dnnl_IOhw16i16o,
1516 OI16i16o = dnnl_OI16i16o,
1517 OI16i32o = dnnl_OI16i32o,
1518 OI16i64o = dnnl_OI16i64o,
1519 OI8i16o2i = dnnl_OI8i16o2i,
1520 OI8i32o2i = dnnl_OI8i32o2i,
1521 OI8i64o2i = dnnl_OI8i64o2i,
1522 OI4i16o4i = dnnl_OI4i16o4i,
1523 OI4i32o4i = dnnl_OI4i32o4i,
1524 OI4i64o4i = dnnl_OI4i64o4i,
1525 Ohwi32o = dnnl_Ohwi32o,
1526 IOdhw16i16o = dnnl_IOdhw16i16o,
1527 gIOhw16i16o = dnnl_gIOhw16i16o,
1528 gOhwi32o = dnnl_gOhwi32o,
1529 Goidhw16g = dnnl_Goidhw16g,
1530 IOw16o16i = dnnl_IOw16o16i,
1531 OIw16i16o = dnnl_OIw16i16o,
1532 OIw16i32o = dnnl_OIw16i32o,
1533 OIw16i64o = dnnl_OIw16i64o,
1534 IOw16i16o = dnnl_IOw16i16o,
1535 gIOw16i16o = dnnl_gIOw16i16o,
1536 OIw16o16i = dnnl_OIw16o16i,
1537 Oiw16o = dnnl_Oiw16o,
1538 OIw4i16o4i = dnnl_OIw4i16o4i,
1539 OIw4i32o4i = dnnl_OIw4i32o4i,
1540 OIw4i64o4i = dnnl_OIw4i64o4i,
1541 OIw2i8o4i = dnnl_OIw2i8o4i,
1542 OIw4i4o = dnnl_OIw4i4o,
1543 OIw4o4i = dnnl_OIw4o4i,
1544 Oiw4o = dnnl_Oiw4o,
1545 OIw8i16o2i = dnnl_OIw8i16o2i,
1546 OIw8i32o2i = dnnl_OIw8i32o2i,
1547 OIw8i64o2i = dnnl_OIw8i64o2i,
1548 OIw8i8o = dnnl_OIw8i8o,
1549 OIw8o16i2o = dnnl_OIw8o16i2o,
1550 OIw8o8i = dnnl_OIw8o8i,
1551 OIw8o4i = dnnl_OIw8o4i,
1552 OIw16i16o4i = dnnl_OIw16i16o4i,
1553 OIw16i32o4i = dnnl_OIw16i32o4i,
1554 OIw16i48o4i = dnnl_OIw16i48o4i,
1555 OIw16i64o4i = dnnl_OIw16i64o4i,
1556 OIw16i16o2i = dnnl_OIw16i16o2i,
1557 OIw16i32o2i = dnnl_OIw16i32o2i,
1558 OIw16i48o2i = dnnl_OIw16i48o2i,
1559 OIw16i64o2i = dnnl_OIw16i64o2i,
1560 OIw16o16i2o = dnnl_OIw16o16i2o,
1561 Owi16o = dnnl_Owi16o,
1562 OwI16o2i = dnnl_OwI16o2i,
1563 Iwo16i = dnnl_Iwo16i,
1564 IwO16i2o = dnnl_IwO16i2o,
1565 IwO16i4o = dnnl_IwO16i4o,
1566 Owi4o = dnnl_Owi4o,
1567 Owi8o = dnnl_Owi8o,
1568 IOhw16o16i = dnnl_IOhw16o16i,
1569 Ohwi16o = dnnl_Ohwi16o,
1570 OhwI16o2i = dnnl_OhwI16o2i,
1571 Ihwo16i = dnnl_Ihwo16i,
1572 IhwO16i2o = dnnl_IhwO16i2o,
1573 IhwO16i4o = dnnl_IhwO16i4o,
1574 Ohwi4o = dnnl_Ohwi4o,
1575 Ohwi8o = dnnl_Ohwi8o,
1576 OIhw16i16o = dnnl_OIhw16i16o,
1577 OIhw16i32o = dnnl_OIhw16i32o,
1578 OIhw16i64o = dnnl_OIhw16i64o,
1579 OIhw16o16i = dnnl_OIhw16o16i,
1580 Oihw16o = dnnl_Oihw16o,
1581 OIhw4i16o4i = dnnl_OIhw4i16o4i,
1582 OIhw4i32o4i = dnnl_OIhw4i32o4i,
1583 OIhw4i64o4i = dnnl_OIhw4i64o4i,
1584 OIhw4i4o = dnnl_OIhw4i4o,
1585 OIhw4o4i = dnnl_OIhw4o4i,
1586 Oihw4o = dnnl_Oihw4o,
1587 OIhw8i16o2i = dnnl_OIhw8i16o2i,
1588 OIhw8i32o2i = dnnl_OIhw8i32o2i,
1589 OIhw8i64o2i = dnnl_OIhw8i64o2i,
1590 OIhw8i8o = dnnl_OIhw8i8o,
1591 OIhw8o16i2o = dnnl_OIhw8o16i2o,
1592 OIhw8o8i = dnnl_OIhw8o8i,
1593 OIhw8o4i = dnnl_OIhw8o4i,
1594 OIhw2i8o4i = dnnl_OIhw2i8o4i,
1595 IOdhw16o16i = dnnl_IOdhw16o16i,
1596 Odhwi16o = dnnl_Odhwi16o,
1597 OdhwI16o2i = dnnl_OdhwI16o2i,
1598 Idhwo16i = dnnl_Idhwo16i,
1599 IdhwO16i2o = dnnl_IdhwO16i2o,
1600 IdhwO16i4o = dnnl_IdhwO16i4o,
1601 Odhwi4o = dnnl_Odhwi4o,
1602 Odhwi8o = dnnl_Odhwi8o,
1603 OIdhw16i16o = dnnl_OIdhw16i16o,
1604 OIdhw16i32o = dnnl_OIdhw16i32o,
1605 OIdhw16i64o = dnnl_OIdhw16i64o,
1606 OIdhw16o16i = dnnl_OIdhw16o16i,
1607 OIdhw16o16i2o = dnnl_OIdhw16o16i2o,
1608 Oidhw16o = dnnl_Oidhw16o,
1609 OIdhw4i4o = dnnl_OIdhw4i4o,
1610 OIdhw4o4i = dnnl_OIdhw4o4i,
1611 Oidhw4o = dnnl_Oidhw4o,
1612 OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1613 OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
1614 OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
1615 OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1616 OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
1617 OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
1618 OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
1619 OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
1620 OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
1621 OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
1622 OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
1623 OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
1624 OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
1625 OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
1626 OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1627 OIdhw8i8o = dnnl_OIdhw8i8o,
1628 OIdhw8o8i = dnnl_OIdhw8o8i,
1629 OIdhw8o4i = dnnl_OIdhw8o4i,
1630 gIOw16o16i = dnnl_gIOw16o16i,
1631 gOIw16i16o = dnnl_gOIw16i16o,
1632 gOIw16o16i = dnnl_gOIw16o16i,
1633 gOiw16o = dnnl_gOiw16o,
1634 gOIw4i16o4i = dnnl_gOIw4i16o4i,
1635 gOIw2i8o4i = dnnl_gOIw2i8o4i,
1636 gOIw4i4o = dnnl_gOIw4i4o,
1637 gOIw4o4i = dnnl_gOIw4o4i,
1638 gOiw4o = dnnl_gOiw4o,
1639 gOIw8i16o2i = dnnl_gOIw8i16o2i,
1640 gOIw8i8o = dnnl_gOIw8i8o,
1641 gOIw8o16i2o = dnnl_gOIw8o16i2o,
1642 gOIw8o8i = dnnl_gOIw8o8i,
1643 gOIw8o4i = dnnl_gOIw8o4i,
1644 gOIw16i16o4i = dnnl_gOIw16i16o4i,
1645 gOIw16i16o2i = dnnl_gOIw16i16o2i,
1646 gOIw16o16i2o = dnnl_gOIw16o16i2o,
1647 gOwi16o = dnnl_gOwi16o,
1648 gOwI16o2i = dnnl_gOwI16o2i,
1649 gIwo16i = dnnl_gIwo16i,
1650 gIwO16i2o = dnnl_gIwO16i2o,
1651 gIwO16i4o = dnnl_gIwO16i4o,
1652 gOwi4o = dnnl_gOwi4o,
1653 gOwi8o = dnnl_gOwi8o,
1654 Goiw8g = dnnl_Goiw8g,
1655 Goiw16g = dnnl_Goiw16g,
1656 gIOhw16o16i = dnnl_gIOhw16o16i,
1657 gOhwi16o = dnnl_gOhwi16o,
1658 gOhwI16o2i = dnnl_gOhwI16o2i,
1659 gIhwo16i = dnnl_gIhwo16i,
1660 gIhwO16i2o = dnnl_gIhwO16i2o,
1661 gIhwO16i4o = dnnl_gIhwO16i4o,
1662 gOhwi4o = dnnl_gOhwi4o,
1663 gOhwi8o = dnnl_gOhwi8o,
1664 Goihw16g = dnnl_Goihw16g,
1665 gOIhw16i16o = dnnl_gOIhw16i16o,
1666 gOIhw16o16i = dnnl_gOIhw16o16i,
1667 gOihw16o = dnnl_gOihw16o,
1668 gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1669 gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1670 gOIhw4i4o = dnnl_gOIhw4i4o,
1671 gOIhw4o4i = dnnl_gOIhw4o4i,
1672 gOihw4o = dnnl_gOihw4o,
1673 Goihw8g = dnnl_Goihw8g,
1674 gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1675 gOIhw8i8o = dnnl_gOIhw8i8o,
1676 gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1677 OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
1678 OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
1679 OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1680 OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1681 gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
1682 gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
1683 gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1684 gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1685 OIhw16i16o4i = dnnl_OIhw16i16o4i,
1686 OIhw16i32o4i = dnnl_OIhw16i32o4i,
1687 OIhw16i48o4i = dnnl_OIhw16i48o4i,
1688 OIhw16i64o4i = dnnl_OIhw16i64o4i,
1689 OIhw16i16o2i = dnnl_OIhw16i16o2i,
1690 OIhw16i32o2i = dnnl_OIhw16i32o2i,
1691 OIhw16i48o2i = dnnl_OIhw16i48o2i,
1692 OIhw16i64o2i = dnnl_OIhw16i64o2i,
1693 OIhw16o16i2o = dnnl_OIhw16o16i2o,
1694 gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
1695 gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
1696 gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
1697 gOIhw8o8i = dnnl_gOIhw8o8i,
1698 gOIhw8o4i = dnnl_gOIhw8o4i,
1699 gIOdhw16i16o = dnnl_gIOdhw16i16o,
1700 gIOdhw16o16i = dnnl_gIOdhw16o16i,
1701 gOdhwi16o = dnnl_gOdhwi16o,
1702 gOdhwI16o2i = dnnl_gOdhwI16o2i,
1703 gIdhwo16i = dnnl_gIdhwo16i,
1704 gIdhwO16i2o = dnnl_gIdhwO16i2o,
1705 gIdhwO16i4o = dnnl_gIdhwO16i4o,
1706 gOdhwi4o = dnnl_gOdhwi4o,
1707 gOdhwi8o = dnnl_gOdhwi8o,
1708 gOIdhw16i16o = dnnl_gOIdhw16i16o,
1709 gOIdhw16o16i = dnnl_gOIdhw16o16i,
1710 gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o,
1711 gOidhw16o = dnnl_gOidhw16o,
1712 gOIdhw4i4o = dnnl_gOIdhw4i4o,
1713 gOIdhw4o4i = dnnl_gOIdhw4o4i,
1714 gOidhw4o = dnnl_gOidhw4o,
1715 gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1716 gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1717 gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
1718 gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
1719 gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1720 gOIdhw8i8o = dnnl_gOIdhw8i8o,
1721 gOIdhw8o8i = dnnl_gOIdhw8o8i,
1722 gOIdhw8o4i = dnnl_gOIdhw8o4i,
1723 gOIw2i4o2i = dnnl_gOIw2i4o2i,
1724 gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1725 gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1726 gOIw2o4i2o = dnnl_gOIw2o4i2o,
1727 gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1728 gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1729 gOIw4i8o2i = dnnl_gOIw4i8o2i,
1730 gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1731 gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1732 gOIw4o8i2o = dnnl_gOIw4o8i2o,
1733 gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1734 gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1735 ldOi32o = abDc32d,
1736 ldOI32o4i = abDC32d4c,
1737 ldgOi32o = abdEc32e,
1738 ldgOI32o2i = abdEC32e2c,
1739 ldgOI32o4i = abdEC32e4c,
1740 OwI16o4i = dnnl_OwI16o4i,
1741 OhwI16o4i = dnnl_OhwI16o4i,
1742 gOwI16o4i = dnnl_gOwI16o4i,
1743 gOhwI16o4i = dnnl_gOhwI16o4i,
1744 OdhwI16o4i = dnnl_OdhwI16o4i,
1745 gOdhwI16o4i = dnnl_gOdhwI16o4i,
1746
1747 Owi32o = dnnl_Owi32o,
1748 OwI32o2i = dnnl_OwI32o2i,
1749 OwI32o4i = dnnl_OwI32o4i,
1750 Owi48o = dnnl_Owi48o,
1751 OwI48o2i = dnnl_OwI48o2i,
1752 OwI48o4i = dnnl_OwI48o4i,
1753 Owi64o = dnnl_Owi64o,
1754 OwI64o2i = dnnl_OwI64o2i,
1755 OwI64o4i = dnnl_OwI64o4i,
1756 Iwo32i = dnnl_Iwo32i,
1757 IwO32i2o = dnnl_IwO32i2o,
1758 IwO32i4o = dnnl_IwO32i4o,
1759 Iwo48i = dnnl_Iwo48i,
1760 IwO48i2o = dnnl_IwO48i2o,
1761 IwO48i4o = dnnl_IwO48i4o,
1762 Iwo64i = dnnl_Iwo64i,
1763 IwO64i2o = dnnl_IwO64i2o,
1764 IwO64i4o = dnnl_IwO64i4o,
1765 wIo2i = dnnl_wIo2i,
1766 wIo4i = dnnl_wIo4i,
1767 gOwi32o = dnnl_gOwi32o,
1768 gOwI32o2i = dnnl_gOwI32o2i,
1769 gOwI32o4i = dnnl_gOwI32o4i,
1770 gOwi48o = dnnl_gOwi48o,
1771 gOwI48o2i = dnnl_gOwI48o2i,
1772 gOwI48o4i = dnnl_gOwI48o4i,
1773 gOwi64o = dnnl_gOwi64o,
1774 gOwI64o2i = dnnl_gOwI64o2i,
1775 gOwI64o4i = dnnl_gOwI64o4i,
1776 gIwo32i = dnnl_gIwo32i,
1777 gIwO32i2o = dnnl_gIwO32i2o,
1778 gIwO32i4o = dnnl_gIwO32i4o,
1779 gIwo48i = dnnl_gIwo48i,
1780 gIwO48i2o = dnnl_gIwO48i2o,
1781 gIwO48i4o = dnnl_gIwO48i4o,
1782 gIwo64i = dnnl_gIwo64i,
1783 gIwO64i2o = dnnl_gIwO64i2o,
1784 gIwO64i4o = dnnl_gIwO64i4o,
1785 gwio = dnnl_gwio,
1786 gwIo2i = dnnl_gwIo2i,
1787 gwIo4i = dnnl_gwIo4i,
1788 OhwI32o = dnnl_OhwI32o,
1789 OhwI32o2i = dnnl_OhwI32o2i,
1790 OhwI32o4i = dnnl_OhwI32o4i,
1791 Ohwi48o = dnnl_Ohwi48o,
1792 OhwI48o2i = dnnl_OhwI48o2i,
1793 OhwI48o4i = dnnl_OhwI48o4i,
1794 Ohwi64o = dnnl_Ohwi64o,
1795 OhwI64o2i = dnnl_OhwI64o2i,
1796 OhwI64o4i = dnnl_OhwI64o4i,
1797 Ihwo32i = dnnl_Ihwo32i,
1798 IhwO32i2o = dnnl_IhwO32i2o,
1799 IhwO32i4o = dnnl_IhwO32i4o,
1800 Ihwo48i = dnnl_Ihwo48i,
1801 IhwO48i2o = dnnl_IhwO48i2o,
1802 IhwO48i4o = dnnl_IhwO48i4o,
1803 Ihwo64i = dnnl_Ihwo64i,
1804 IhwO64i2o = dnnl_IhwO64i2o,
1805 IhwO64i4o = dnnl_IhwO64i4o,
1806 hwIo2i = dnnl_hwIo2i,
1807 hwIo4i = dnnl_hwIo4i,
1808 gOhwI32o = dnnl_gOhwI32o,
1809 gOhwI32o2i = dnnl_gOhwI32o2i,
1810 gOhwI32o4i = dnnl_gOhwI32o4i,
1811 gOhwi48o = dnnl_gOhwi48o,
1812 gOhwI48o2i = dnnl_gOhwI48o2i,
1813 gOhwI48o4i = dnnl_gOhwI48o4i,
1814 gOhwi64o = dnnl_gOhwi64o,
1815 gOhwI64o2i = dnnl_gOhwI64o2i,
1816 gOhwI64o4i = dnnl_gOhwI64o4i,
1817 gIhwo32i = dnnl_gIhwo32i,
1818 gIhwO32i2o = dnnl_gIhwO32i2o,
1819 gIhwO32i4o = dnnl_gIhwO32i4o,
1820 gIhwo48i = dnnl_gIhwo48i,
1821 gIhwO48i2o = dnnl_gIhwO48i2o,
1822 gIhwO48i4o = dnnl_gIhwO48i4o,
1823 gIhwo64i = dnnl_gIhwo64i,
1824 gIhwO64i2o = dnnl_gIhwO64i2o,
1825 gIhwO64i4o = dnnl_gIhwO64i4o,
1826 ghwio = dnnl_ghwio,
1827 ghwIo2i = dnnl_ghwIo2i,
1828 ghwIo4i = dnnl_ghwIo4i,
1829 Odhwi32o = dnnl_Odhwi32o,
1830 OdhwI32o2i = dnnl_OdhwI32o2i,
1831 OdhwI32o4i = dnnl_OdhwI32o4i,
1832 Odhwi48o = dnnl_Odhwi48o,
1833 OdhwI48o2i = dnnl_OdhwI48o2i,
1834 OdhwI48o4i = dnnl_OdhwI48o4i,
1835 Odhwi64o = dnnl_Odhwi64o,
1836 OdhwI64o2i = dnnl_OdhwI64o2i,
1837 OdhwI64o4i = dnnl_OdhwI64o4i,
1838 Idhwo32i = dnnl_Idhwo32i,
1839 IdhwO32i2o = dnnl_IdhwO32i2o,
1840 IdhwO32i4o = dnnl_IdhwO32i4o,
1841 Idhwo48i = dnnl_Idhwo48i,
1842 IdhwO48i2o = dnnl_IdhwO48i2o,
1843 IdhwO48i4o = dnnl_IdhwO48i4o,
1844 Idhwo64i = dnnl_Idhwo64i,
1845 IdhwO64i2o = dnnl_IdhwO64i2o,
1846 IdhwO64i4o = dnnl_IdhwO64i4o,
1847 dhwIo2i = dnnl_dhwIo2i,
1848 dhwIo4i = dnnl_dhwIo4i,
1849 gOdhwi32o = dnnl_gOdhwi32o,
1850 gOdhwI32o2i = dnnl_gOdhwI32o2i,
1851 gOdhwI32o4i = dnnl_gOdhwI32o4i,
1852 gOdhwi48o = dnnl_gOdhwi48o,
1853 gOdhwI48o2i = dnnl_gOdhwI48o2i,
1854 gOdhwI48o4i = dnnl_gOdhwI48o4i,
1855 gOdhwi64o = dnnl_gOdhwi64o,
1856 gOdhwI64o2i = dnnl_gOdhwI64o2i,
1857 gOdhwI64o4i = dnnl_gOdhwI64o4i,
1858 gIdhwo32i = dnnl_gIdhwo32i,
1859 gIdhwO32i2o = dnnl_gIdhwO32i2o,
1860 gIdhwO32i4o = dnnl_gIdhwO32i4o,
1861 gIdhwo48i = dnnl_gIdhwo48i,
1862 gIdhwO48i2o = dnnl_gIdhwO48i2o,
1863 gIdhwO48i4o = dnnl_gIdhwO48i4o,
1864 gIdhwo64i = dnnl_gIdhwo64i,
1865 gIdhwO64i2o = dnnl_gIdhwO64i2o,
1866 gIdhwO64i4o = dnnl_gIdhwO64i4o,
1867 gdhwio = dnnl_gdhwio,
1868 gdhwIo2i = dnnl_gdhwIo2i,
1869 gdhwIo4i = dnnl_gdhwIo4i,
1870 ldIo32i = dnnl_ldIo32i,
1871 ldgIo32i = dnnl_ldgIo32i,
1872 ldgIO32i2o = dnnl_ldgIO32i2o,
1873 nCdhw32c = dnnl_nCdhw32c,
1874 nChw32c = dnnl_nChw32c,
1875 nCw32c = dnnl_nCw32c,
1876 NCw32n16c = dnnl_NCw32n16c,
1877 NChw32n16c = dnnl_NChw32n16c,
1878 NCdhw32n16c = dnnl_NCdhw32n16c,
1879 NCw32n32c = dnnl_NCw32n32c,
1880 OI16i16o4i = dnnl_OI16i16o4i,
1881 IOw8o16i2o = dnnl_IOw8o16i2o,
1882 IOhw8o16i2o = dnnl_IOhw8o16i2o,
1883 Owhi16o = dnnl_Owhi16o,
1884 OIdhw8o16i2o = dnnl_OIdhw8o16i2o,
1885 IOdhw8o16i2o = dnnl_IOdhw8o16i2o,
1886 Goiw4g = dnnl_Goiw4g,
1887 gIOw8o16i2o = dnnl_gIOw8o16i2o,
1888 Goiw32g = dnnl_Goiw32g,
1889 Goihw4g = dnnl_Goihw4g,
1890 gIOhw8o16i2o = dnnl_gIOhw8o16i2o,
1891 Goihw32g = dnnl_Goihw32g,
1892 gOwhi16o = dnnl_gOwhi16o,
1893 IOw4i8o8i4o = dnnl_IOw4i8o8i4o,
1894 IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o,
1895 IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o,
1896 gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o,
1897 gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o,
1898 gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o,
1899 gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o,
1900 gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o,
1901 Goidhw32g = dnnl_Goidhw32g,
1902 OI16i32o4i = dnnl_OI16i32o4i,
1903 OI16i48o4i = dnnl_OI16i48o4i,
1904 OI16i64o4i = dnnl_OI16i64o4i,
1905 OI16i16o2i = dnnl_OI16i16o2i,
1906 OI16i32o2i = dnnl_OI16i32o2i,
1907 OI16i48o2i = dnnl_OI16i48o2i,
1908 OI16i64o2i = dnnl_OI16i64o2i,
1909 aBdeC16c16b4c = dnnl_aBdeC16c16b4c,
1910 AcB16b16a2b = dnnl_AcB16b16a2b,
1911 aBdC16c16b2c = dnnl_aBdC16c16b2c,
1912 AcB16b16a4b = dnnl_AcB16b16a4b,
1913 aBdC16c16b4c = dnnl_aBdC16c16b4c,
1914 AcdB16b16a2b = dnnl_AcdB16b16a2b,
1915 aBdefC16c16b4c = dnnl_aBdefC16c16b4c,
1916 AcdeB16b16a4b = dnnl_AcdeB16b16a4b,
1917 AcB16b32a2b = dnnl_AcB16b32a2b,
1918 AcB16b32a4b = dnnl_AcB16b32a4b,
1919 AcB16b48a2b = dnnl_AcB16b48a2b,
1920 AcB16b48a4b = dnnl_AcB16b48a4b,
1921 AcB16b64a2b = dnnl_AcB16b64a2b,
1922 AcB16b64a4b = dnnl_AcB16b64a4b,
1923 aBdC16c32b2c = dnnl_aBdC16c32b2c,
1924 aBdC16c32b4c = dnnl_aBdC16c32b4c,
1925 aBdC16c48b2c = dnnl_aBdC16c48b2c,
1926 aBdC16c48b4c = dnnl_aBdC16c48b4c,
1927 aBdC16c64b2c = dnnl_aBdC16c64b2c,
1928 aBdC16c64b4c = dnnl_aBdC16c64b4c,
1929 AcdB16b32a2b = dnnl_AcdB16b32a2b,
1930 AcdB16b32a4b = dnnl_AcdB16b32a4b,
1931 AcdB16b48a2b = dnnl_AcdB16b48a2b,
1932 AcdB16b48a4b = dnnl_AcdB16b48a4b,
1933 AcdB16b64a2b = dnnl_AcdB16b64a2b,
1934 AcdB16b64a4b = dnnl_AcdB16b64a4b,
1935 aBdeC16c32b2c = dnnl_aBdeC16c32b2c,
1936 aBdeC16c32b4c = dnnl_aBdeC16c32b4c,
1937 aBdeC16c48b2c = dnnl_aBdeC16c48b2c,
1938 aBdeC16c48b4c = dnnl_aBdeC16c48b4c,
1939 aBdeC16c64b2c = dnnl_aBdeC16c64b2c,
1940 aBdeC16c64b4c = dnnl_aBdeC16c64b4c,
1941 AcdeB16b32a2b = dnnl_AcdeB16b32a2b,
1942 AcdeB16b32a4b = dnnl_AcdeB16b32a4b,
1943 AcdeB16b48a2b = dnnl_AcdeB16b48a2b,
1944 AcdeB16b48a4b = dnnl_AcdeB16b48a4b,
1945 AcdeB16b64a2b = dnnl_AcdeB16b64a2b,
1946 AcdeB16b64a4b = dnnl_AcdeB16b64a4b,
1947 aBdefC16c32b2c = dnnl_aBdefC16c32b2c,
1948 aBdefC16c32b4c = dnnl_aBdefC16c32b4c,
1949 aBdefC16c48b2c = dnnl_aBdefC16c48b2c,
1950 aBdefC16c48b4c = dnnl_aBdefC16c48b4c,
1951 aBdefC16c64b2c = dnnl_aBdefC16c64b2c,
1952 aBdefC16c64b4c = dnnl_aBdefC16c64b4c,
1953 OwI16i16o2i = dnnl_OwI16i16o2i,
1954 gOwI16i16o2i = dnnl_gOwI16i16o2i,
1955 OhwI16i16o2i = dnnl_OhwI16i16o2i,
1956 gOhwI16i16o2i = dnnl_gOhwI16i16o2i,
1957 OdhwI16i16o2i = dnnl_OdhwI16i16o2i,
1958 gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i,
1959 OwI16i16o4i = dnnl_OwI16i16o4i,
1960 gOwI16i16o4i = dnnl_gOwI16i16o4i,
1961 OhwI16i16o4i = dnnl_OhwI16i16o4i,
1962 gOhwI16i16o4i = dnnl_gOhwI16i16o4i,
1963 OdhwI16i16o4i = dnnl_OdhwI16i16o4i,
1964 gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i,
1965 OwI16i32o2i = dnnl_OwI16i32o2i,
1966 OwI16i32o4i = dnnl_OwI16i32o4i,
1967 OwI16i48o2i = dnnl_OwI16i48o2i,
1968 OwI16i48o4i = dnnl_OwI16i48o4i,
1969 OwI16i64o2i = dnnl_OwI16i64o2i,
1970 OwI16i64o4i = dnnl_OwI16i64o4i,
1971 gOwI16i32o2i = dnnl_gOwI16i32o2i,
1972 gOwI16i32o4i = dnnl_gOwI16i32o4i,
1973 gOwI16i48o2i = dnnl_gOwI16i48o2i,
1974 gOwI16i48o4i = dnnl_gOwI16i48o4i,
1975 gOwI16i64o2i = dnnl_gOwI16i64o2i,
1976 gOwI16i64o4i = dnnl_gOwI16i64o4i,
1977 OhwI16i32o2i = dnnl_OhwI16i32o2i,
1978 OhwI16i32o4i = dnnl_OhwI16i32o4i,
1979 OhwI16i48o2i = dnnl_OhwI16i48o2i,
1980 OhwI16i48o4i = dnnl_OhwI16i48o4i,
1981 OhwI16i64o2i = dnnl_OhwI16i64o2i,
1982 OhwI16i64o4i = dnnl_OhwI16i64o4i,
1983 gOhwI16i32o2i = dnnl_gOhwI16i32o2i,
1984 gOhwI16i32o4i = dnnl_gOhwI16i32o4i,
1985 gOhwI16i48o2i = dnnl_gOhwI16i48o2i,
1986 gOhwI16i48o4i = dnnl_gOhwI16i48o4i,
1987 gOhwI16i64o2i = dnnl_gOhwI16i64o2i,
1988 gOhwI16i64o4i = dnnl_gOhwI16i64o4i,
1989 OdhwI16i32o2i = dnnl_OdhwI16i32o2i,
1990 OdhwI16i32o4i = dnnl_OdhwI16i32o4i,
1991 OdhwI16i48o2i = dnnl_OdhwI16i48o2i,
1992 OdhwI16i48o4i = dnnl_OdhwI16i48o4i,
1993 OdhwI16i64o2i = dnnl_OdhwI16i64o2i,
1994 OdhwI16i64o4i = dnnl_OdhwI16i64o4i,
1995 IdhwO16o32i2o = dnnl_IdhwO16o32i2o,
1996 IdhwO16o32i4o = dnnl_IdhwO16o32i4o,
1997 IdhwO16o48i2o = dnnl_IdhwO16o48i2o,
1998 IdhwO16o48i4o = dnnl_IdhwO16o48i4o,
1999 IdhwO16o64i2o = dnnl_IdhwO16o64i2o,
2000 IdhwO16o64i4o = dnnl_IdhwO16o64i4o,
2001 gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i,
2002 gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i,
2003 gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i,
2004 gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i,
2005 gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i,
2006 gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i,
2007 gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o,
2008 gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o,
2009 gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o,
2010 gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o,
2011 gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o,
2012 gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o,
2013 IwO16o16i2o = dnnl_IwO16o16i2o,
2014 IwO16o16i4o = dnnl_IwO16o16i4o,
2015 IhwO16o16i2o = dnnl_IhwO16o16i2o,
2016 IhwO16o16i4o = dnnl_IhwO16o16i4o,
2017 IdhwO16o16i2o = dnnl_IdhwO16o16i2o,
2018 IdhwO16o16i4o = dnnl_IdhwO16o16i4o,
2019 gIwO16o16i2o = dnnl_gIwO16o16i2o,
2020 gIwO16o16i4o = dnnl_gIwO16o16i4o,
2021 gIhwO16o16i2o = dnnl_gIhwO16o16i2o,
2022 gIhwO16o16i4o = dnnl_gIhwO16o16i4o,
2023 gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o,
2024 gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o,
2025 IwO16o32i2o = dnnl_IwO16o32i2o,
2026 IwO16o32i4o = dnnl_IwO16o32i4o,
2027 IwO16o48i2o = dnnl_IwO16o48i2o,
2028 IwO16o48i4o = dnnl_IwO16o48i4o,
2029 IwO16o64i2o = dnnl_IwO16o64i2o,
2030 IwO16o64i4o = dnnl_IwO16o64i4o,
2031 gIwO16o32i2o = dnnl_gIwO16o32i2o,
2032 gIwO16o32i4o = dnnl_gIwO16o32i4o,
2033 gIwO16o48i2o = dnnl_gIwO16o48i2o,
2034 gIwO16o48i4o = dnnl_gIwO16o48i4o,
2035 gIwO16o64i2o = dnnl_gIwO16o64i2o,
2036 gIwO16o64i4o = dnnl_gIwO16o64i4o,
2037 IhwO16o32i2o = dnnl_IhwO16o32i2o,
2038 IhwO16o32i4o = dnnl_IhwO16o32i4o,
2039 IhwO16o48i2o = dnnl_IhwO16o48i2o,
2040 IhwO16o48i4o = dnnl_IhwO16o48i4o,
2041 IhwO16o64i2o = dnnl_IhwO16o64i2o,
2042 IhwO16o64i4o = dnnl_IhwO16o64i4o,
2043 gIhwO16o32i2o = dnnl_gIhwO16o32i2o,
2044 gIhwO16o32i4o = dnnl_gIhwO16o32i4o,
2045 gIhwO16o48i2o = dnnl_gIhwO16o48i2o,
2046 gIhwO16o48i4o = dnnl_gIhwO16o48i4o,
2047 gIhwO16o64i2o = dnnl_gIhwO16o64i2o,
2048 gIhwO16o64i4o = dnnl_gIhwO16o64i4o,
2049 aBdeC16c16b2c = dnnl_aBdeC16c16b2c,
2050 aBdefC16c16b2c = dnnl_aBdefC16c16b2c,
2051 AcdB16b16a4b = dnnl_AcdB16b16a4b,
2052 AcdeB16b16a2b = dnnl_AcdeB16b16a2b,
2053 hwioG16g = dnnl_hwioG16g,
2054 hwioG8g = dnnl_hwioG8g,
2055 ABc4a2b = dnnl_ABc4a2b,
2056 ABc8a2b = dnnl_ABc8a2b,
2057 ABcd4a2b = dnnl_ABcd4a2b,
2058 ABcde4a2b = dnnl_ABcde4a2b,
2059 ABcde8a2b = dnnl_ABcde8a2b,
2060 ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b,
2061 NCdhw40n32c = dnnl_NCdhw40n32c,
2062 NChw40n32c = dnnl_NChw40n32c,
2063 NCw40n32c = dnnl_NCw40n32c,
2064 OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i,
2065 OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i,
2066 OIw4o8i8o2i = dnnl_OIw4o8i8o2i,
2067 gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i,
2068 gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i,
2069 gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i,
2070 IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o,
2071 IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o,
2072 IOw4i8o8i2o = dnnl_IOw4i8o8i2o,
2073 gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o,
2074 gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o,
2075 gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o,
2076 aBCd8b2c = dnnl_aBCd8b2c,
2077 ABcde40a16b = dnnl_ABcde40a16b,
2078 ABcde40a32b = dnnl_ABcde40a32b,
2079 aBCde8b2c = dnnl_aBCde8b2c,
2080 ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b,
2081 ABc4a8b8a2b = dnnl_ABc4a8b8a2b,
2082 aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c,
2083 aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c,
2084 aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c,
2085 BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a,
2086 BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a,
2087 BAc4b8a8b2a = dnnl_BAc4b8a8b2a,
2088 aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b,
2089 aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b,
2090 aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b,
2091 aBCdef8b2c = dnnl_aBCdef8b2c,
2092 AB32a16b = dnnl_AB32a16b,
2093 AB32a32b = dnnl_AB32a32b,
2094 BA4b8a8b2a = dnnl_BA4b8a8b2a,
2095 BA4b8a8b4a = dnnl_BA4b8a8b4a,
2096 aBC32b16c = dnnl_aBC32b16c,
2097 aBC32b32c = dnnl_aBC32b32c,
2098 aCB4c8b8c2b = dnnl_aCB4c8b8c2b,
2099 aCB4c8b8c4b = dnnl_aCB4c8b8c4b,
2100 ABc2b8a16b4a = dnnl_ABc2b8a16b4a,
2101 ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a,
2102 ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a,
2103 ABc2a8b16a4b = dnnl_ABc2a8b16a4b,
2104 ABc2a8b16a2b = dnnl_ABc2a8b16a2b,
2105 ABc2b32a8b = dnnl_ABc2b32a8b,
2106 ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b,
2107 ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b,
2108 aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b,
2109 ABcd2b32a8b = dnnl_ABcd2b32a8b,
2110 aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b,
2111 ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b,
2112 ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b,
2113 aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b,
2114 ABcde2b32a8b = dnnl_ABcde2b32a8b,
2115 aBC2b8c16b2c = dnnl_aBC2b8c16b2c,
2116 aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c,
2117 aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c,
2118 aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c,
2119 BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a,
2120 BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a,
2121 BAc2b8a16b4a = dnnl_BAc2b8a16b4a,
2122 BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a,
2123 BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a,
2124 BAc2b8a16b2a = dnnl_BAc2b8a16b2a,
2125 aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b,
2126 aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b,
2127 aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b,
2128 aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c,
2129 aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c,
2130 NCdhw40n16c = dnnl_NCdhw40n16c,
2131 NCw40n16c = dnnl_NCw40n16c,
2132 NChw40n16c = dnnl_NChw40n16c,
2133 NCw2c32n8c = dnnl_NCw2c32n8c,
2134 NChw2c32n8c = dnnl_NChw2c32n8c,
2135 NCdhw2c32n8c = dnnl_NCdhw2c32n8c,
2136 OIw2i8o16i4o = dnnl_OIw2i8o16i4o,
2137 OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o,
2138 OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o,
2139 OIw2o8i16o4i = dnnl_OIw2o8i16o4i,
2140 OIw2o8i16o2i = dnnl_OIw2o8i16o2i,
2141 IOw2i8o16i4o = dnnl_IOw2i8o16i4o,
2142 IOw2i8o16i2o = dnnl_IOw2i8o16i2o,
2143 OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i,
2144 OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i,
2145 IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o,
2146 IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o,
2147 OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i,
2148 OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i,
2149 IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o,
2150 IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o,
2151 gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i,
2152 gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o,
2153 gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o,
2154 gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o,
2155 gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i,
2156 gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i,
2157 gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i,
2158 gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i,
2159 BA4b8a16b2a = dnnl_BA4b8a16b2a,
2160 BA4b8a16b4a = dnnl_BA4b8a16b4a,
2161 aCB4c8b16c2b = dnnl_aCB4c8b16c2b,
2162 aCB4c8b16c4b = dnnl_aCB4c8b16c4b,
2163 aCB16c2b = dnnl_aCB16c2b,
2164 aCB16c4b = dnnl_aCB16c4b,
2165 BA16b2a = dnnl_BA16b2a,
2166 BA16b4a = dnnl_BA16b4a,
2167 aBC16b16c = dnnl_aBC16b16c,
2168 aBC16b32c = dnnl_aBC16b32c,
2169 AB16a16b = dnnl_AB16a16b,
2170 AB16a32b = dnnl_AB16a32b,
2171 ABcde16a16b2a = dnnl_ABcde16a16b2a,
2172 aBCdef16b16c2b = dnnl_aBCdef16b16c2b,
2173 Acedb16a = dnnl_Acedb16a,
2174 aBdfec16b = dnnl_aBdfec16b,
2175 Odwhi16o = dnnl_Odwhi16o,
2176 gOdwhi16o = dnnl_gOdwhi16o,
2177 abdEC64e2c = dnnl_abdEC64e2c,
2178 abdEC64e4c = dnnl_abdEC64e4c,
2179 ldgOI64o2i = abdEC64e2c,
2180 ldgOI64o4i = abdEC64e4c,
2181 abCd4c = dnnl_abCd4c,
2182 abCde4c = dnnl_abCde4c,
2183 abCdef4c = dnnl_abCdef4c,
2184 abCde32c = dnnl_abCde32c,
2185 abCdef32c = dnnl_abCdef32c,
2186 aCdefB16b32c2b = dnnl_aCdefB16b32c2b,
2187 aCdefB16b32c4b = dnnl_aCdefB16b32c4b,
2188 aCdefB16b48c2b = dnnl_aCdefB16b48c2b,
2189 aCdefB16b48c4b = dnnl_aCdefB16b48c4b,
2190 aCdefB16b64c2b = dnnl_aCdefB16b64c2b,
2191 aCdefB16b64c4b = dnnl_aCdefB16b64c4b,
2192 BcdeA16a32b2a = dnnl_BcdeA16a32b2a,
2193 BcdeA16a32b4a = dnnl_BcdeA16a32b4a,
2194 BcdeA16a48b2a = dnnl_BcdeA16a48b2a,
2195 BcdeA16a48b4a = dnnl_BcdeA16a48b4a,
2196 BcdeA16a64b2a = dnnl_BcdeA16a64b2a,
2197 BcdeA16a64b4a = dnnl_BcdeA16a64b4a,
2198 aCdefb32c = dnnl_aCdefb32c,
2199 aCdefB32c2b = dnnl_aCdefB32c2b,
2200 aCdefB32c4b = dnnl_aCdefB32c4b,
2201 aCdefb48c = dnnl_aCdefb48c,
2202 aCdefB48c2b = dnnl_aCdefB48c2b,
2203 aCdefB48c4b = dnnl_aCdefB48c4b,
2204 aCdefb64c = dnnl_aCdefb64c,
2205 aCdefB64c2b = dnnl_aCdefB64c2b,
2206 aCdefB64c4b = dnnl_aCdefB64c4b,
2207 Bcdea32b = dnnl_Bcdea32b,
2208 BcdeA32b2a = dnnl_BcdeA32b2a,
2209 BcdeA32b4a = dnnl_BcdeA32b4a,
2210 Bcdea48b = dnnl_Bcdea48b,
2211 BcdeA48b2a = dnnl_BcdeA48b2a,
2212 BcdeA48b4a = dnnl_BcdeA48b4a,
2213 Bcdea64b = dnnl_Bcdea64b,
2214 BcdeA64b2a = dnnl_BcdeA64b2a,
2215 BcdeA64b4a = dnnl_BcdeA64b4a,
2216 Bca32b = dnnl_Bca32b,
2217 BcA32b2a = dnnl_BcA32b2a,
2218 BcA32b4a = dnnl_BcA32b4a,
2219 Bca48b = dnnl_Bca48b,
2220 BcA48b2a = dnnl_BcA48b2a,
2221 BcA48b4a = dnnl_BcA48b4a,
2222 Bca64b = dnnl_Bca64b,
2223 BcA64b2a = dnnl_BcA64b2a,
2224 BcA64b4a = dnnl_BcA64b4a,
2225 aCdb32c = dnnl_aCdb32c,
2226 aCdB32c2b = dnnl_aCdB32c2b,
2227 aCdB32c4b = dnnl_aCdB32c4b,
2228 aCdb48c = dnnl_aCdb48c,
2229 aCdB48c2b = dnnl_aCdB48c2b,
2230 aCdB48c4b = dnnl_aCdB48c4b,
2231 aCdb64c = dnnl_aCdb64c,
2232 aCdB64c2b = dnnl_aCdB64c2b,
2233 aCdB64c4b = dnnl_aCdB64c4b,
2234 BcA16a16b2a = dnnl_BcA16a16b2a,
2235 BcA16a16b4a = dnnl_BcA16a16b4a,
2236 BcdA16a16b2a = dnnl_BcdA16a16b2a,
2237 BcdA16a16b4a = dnnl_BcdA16a16b4a,
2238 BcdeA16a16b2a = dnnl_BcdeA16a16b2a,
2239 BcdeA16a16b4a = dnnl_BcdeA16a16b4a,
2240 aCdB16b16c2b = dnnl_aCdB16b16c2b,
2241 aCdB16b16c4b = dnnl_aCdB16b16c4b,
2242 aCdeB16b16c2b = dnnl_aCdeB16b16c2b,
2243 aCdeB16b16c4b = dnnl_aCdeB16b16c4b,
2244 aCdefB16b16c2b = dnnl_aCdefB16b16c2b,
2245 aCdefB16b16c4b = dnnl_aCdefB16b16c4b,
2246 BcA16a32b2a = dnnl_BcA16a32b2a,
2247 BcA16a32b4a = dnnl_BcA16a32b4a,
2248 BcA16a48b2a = dnnl_BcA16a48b2a,
2249 BcA16a48b4a = dnnl_BcA16a48b4a,
2250 BcA16a64b2a = dnnl_BcA16a64b2a,
2251 BcA16a64b4a = dnnl_BcA16a64b4a,
2252 aCdB16b32c2b = dnnl_aCdB16b32c2b,
2253 aCdB16b32c4b = dnnl_aCdB16b32c4b,
2254 aCdB16b48c2b = dnnl_aCdB16b48c2b,
2255 aCdB16b48c4b = dnnl_aCdB16b48c4b,
2256 aCdB16b64c2b = dnnl_aCdB16b64c2b,
2257 aCdB16b64c4b = dnnl_aCdB16b64c4b,
2258 BcdA16a32b2a = dnnl_BcdA16a32b2a,
2259 BcdA16a32b4a = dnnl_BcdA16a32b4a,
2260 BcdA16a48b2a = dnnl_BcdA16a48b2a,
2261 BcdA16a48b4a = dnnl_BcdA16a48b4a,
2262 BcdA16a64b2a = dnnl_BcdA16a64b2a,
2263 BcdA16a64b4a = dnnl_BcdA16a64b4a,
2264 aCdeB16b32c2b = dnnl_aCdeB16b32c2b,
2265 aCdeB16b32c4b = dnnl_aCdeB16b32c4b,
2266 aCdeB16b48c2b = dnnl_aCdeB16b48c2b,
2267 aCdeB16b48c4b = dnnl_aCdeB16b48c4b,
2268 aCdeB16b64c2b = dnnl_aCdeB16b64c2b,
2269 aCdeB16b64c4b = dnnl_aCdeB16b64c4b,
2270 Bca16b = dnnl_Bca16b,
2271 BcA16b2a = dnnl_BcA16b2a,
2272 BcA16b4a = dnnl_BcA16b4a,
2273 Bcda16b = dnnl_Bcda16b,
2274 BcdA16b2a = dnnl_BcdA16b2a,
2275 BcdA16b4a = dnnl_BcdA16b4a,
2276 Bcdea16b = dnnl_Bcdea16b,
2277 BcdeA16b2a = dnnl_BcdeA16b2a,
2278 BcdeA16b4a = dnnl_BcdeA16b4a,
2279 aCdb16c = dnnl_aCdb16c,
2280 aCdB16c2b = dnnl_aCdB16c2b,
2281 aCdB16c4b = dnnl_aCdB16c4b,
2282 aCdeb16c = dnnl_aCdeb16c,
2283 aCdeB16c2b = dnnl_aCdeB16c2b,
2284 aCdeB16c4b = dnnl_aCdeB16c4b,
2285 aCdefb16c = dnnl_aCdefb16c,
2286 aCdefB16c2b = dnnl_aCdefB16c2b,
2287 aCdefB16c4b = dnnl_aCdefB16c4b,
2288 Bcda32b = dnnl_Bcda32b,
2289 BcdA32b2a = dnnl_BcdA32b2a,
2290 BcdA32b4a = dnnl_BcdA32b4a,
2291 Bcda48b = dnnl_Bcda48b,
2292 BcdA48b2a = dnnl_BcdA48b2a,
2293 BcdA48b4a = dnnl_BcdA48b4a,
2294 Bcda64b = dnnl_Bcda64b,
2295 BcdA64b2a = dnnl_BcdA64b2a,
2296 BcdA64b4a = dnnl_BcdA64b4a,
2297 aCdeb32c = dnnl_aCdeb32c,
2298 aCdeB32c2b = dnnl_aCdeB32c2b,
2299 aCdeB32c4b = dnnl_aCdeB32c4b,
2300 aCdeb48c = dnnl_aCdeb48c,
2301 aCdeB48c2b = dnnl_aCdeB48c2b,
2302 aCdeB48c4b = dnnl_aCdeB48c4b,
2303 aCdeb64c = dnnl_aCdeb64c,
2304 aCdeB64c2b = dnnl_aCdeB64c2b,
2305 aCdeB64c4b = dnnl_aCdeB64c4b,
2306 NChw16n32c = dnnl_NChw16n32c,
2307 goIw4i = dnnl_goIw4i,
2308 goIw32i = dnnl_goIw32i,
2309 goIhw4i = dnnl_goIhw4i,
2310 goIhw32i = dnnl_goIhw32i,
2311 goIdhw4i = dnnl_goIdhw4i,
2312 goIdhw32i = dnnl_goIdhw32i,
2313 cab = dnnl_cab,
2314 cdab = dnnl_cdab,
2315 cdeab = dnnl_cdeab,
2316 woi = dnnl_woi,
2317 hwoi = dnnl_hwoi,
2318 dhwoi = dnnl_dhwoi,
2319 };
2320
2321 /// A memory descriptor.
2322 struct desc : public handle<dnnl_memory_desc_t> {
2323 using handle<dnnl_memory_desc_t>::handle;
2324
2325 friend struct memory;
2326
2327 /// Constructs a zero (empty) memory descriptor. Such a memory
2328 /// descriptor can be used to indicate absence of an argument.
2329 desc() {
2330 dnnl_memory_desc_t zero_md = nullptr;
2331 error::wrap_c_api(
2332 dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr,
2333 dnnl_data_type_undef, dnnl_format_tag_undef),
2334 "could not create a zero memory descriptor");
2335 reset(zero_md);
2336 }
2337
2338 /// Constructs a memory descriptor.
2339 ///
2340 /// @note
2341 /// The logical order of dimensions corresponds to the `abc...`
2342 /// format tag, and the physical meaning of the dimensions depends
2343 /// both on the primitive that would operate on this memory and
2344 /// the operation context.
2345 ///
2346 /// @param adims Tensor dimensions.
2347 /// @param adata_type Data precision/type.
2348 /// @param aformat_tag Memory format tag.
2349 /// @param allow_empty A flag signifying whether construction is
2350 /// allowed to fail without throwing an exception. In this case a
2351 /// zero memory descriptor will be constructed. This flag is
2352 /// optional and defaults to false.
2353 desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
2354 bool allow_empty = false) {
2355 validate_dims(adims);
2356 dnnl_memory_desc_t md = nullptr;
2357 dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md,
2358 (int)adims.size(), adims.data(), convert_to_c(adata_type),
2359 convert_to_c(aformat_tag));
2360 if (!allow_empty)
2361 error::wrap_c_api(status,
2362 "could not construct a memory descriptor using a "
2363 "format tag");
2364 reset(md);
2365 }
2366
2367 /// Constructs a memory descriptor by strides.
2368 ///
2369 /// @note
2370 /// The logical order of dimensions corresponds to the `abc...`
2371 /// format tag, and the physical meaning of the dimensions depends
2372 /// both on the primitive that would operate on this memory and
2373 /// the operation context.
2374 ///
2375 /// @param adims Tensor dimensions.
2376 /// @param adata_type Data precision/type.
2377 /// @param strides Strides for each dimension.
2378 /// @param allow_empty A flag signifying whether construction is
2379 /// allowed to fail without throwing an exception. In this case a
2380 /// zero memory descriptor will be constructed. This flag is
2381 /// optional and defaults to false.
2382 desc(const dims &adims, data_type adata_type, const dims &strides,
2383 bool allow_empty = false) {
2384 validate_dims(adims);
2385 if (!strides.empty()) validate_dims(strides, (int)adims.size());
2386 dnnl_memory_desc_t md = nullptr;
2387 dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md,
2388 (int)adims.size(), adims.data(), convert_to_c(adata_type),
2389 strides.empty() ? nullptr : &strides[0]);
2390 if (!allow_empty)
2391 error::wrap_c_api(status,
2392 "could not construct a memory descriptor using "
2393 "strides");
2394 reset(md);
2395 }
2396
2397 /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
2398 /// handle. The resulting handle is not weak and the C handle will be
2399 /// destroyed during the destruction of the C++ object.
2400 ///
2401 /// @param md The C API memory descriptor.
2402 desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {}
2403
2404 /// Constructs a memory descriptor for a region inside an area
2405 /// described by this memory descriptor.
2406 //
2407 /// @param adims Sizes of the region.
2408 /// @param offsets Offsets to the region from the encompassing
2409 /// memory object in each dimension.
2410 /// @param allow_empty A flag signifying whether construction is
2411 /// allowed to fail without throwing an exception. In this case a
2412 /// zero memory descriptor will be returned. This flag is optional
2413 /// and defaults to false.
2414 /// @returns A memory descriptor for the region.
2415 desc submemory_desc(const dims &adims, const dims &offsets,
2416 bool allow_empty = false) const {
2417 validate_dims(adims, get_ndims());
2418 validate_dims(offsets, get_ndims());
2419 dnnl_memory_desc_t sub_md = nullptr;
2420 dnnl_status_t status = dnnl_memory_desc_create_submemory(
2421 &sub_md, get(), adims.data(), offsets.data());
2422 if (!allow_empty)
2423 error::wrap_c_api(status, "could not construct a sub-memory");
2424 return desc(sub_md);
2425 }
2426
2427 /// Constructs a memory descriptor by reshaping an existing one. The
2428 /// new memory descriptor inherits the data type. This operation is
2429 /// valid only for memory descriptors that have format_kind set to
2430 /// #dnnl::memory::format_kind::blocked or
2431 /// #dnnl::memory::format_kind::any.
2432 ///
2433 /// The operation ensures that the transformation of the physical memory
2434 /// format corresponds to the transformation of the logical dimensions.
2435 /// If such transformation is impossible, the function either throws an
2436 /// exception (default) or returns a zero memory descriptor depending on
2437 /// the `allow_empty` flag.
2438 ///
2439 /// The reshape operation can be described as a combination of the
2440 /// following basic operations:
2441 /// 1. Add a dimension of size `1`. This is always possible.
2442 /// 2. Remove a dimension of size `1`. This is possible only if the
2443 /// dimension has no padding (i.e.
2444 /// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
2445 /// 3. Split a dimension into multiple ones. This is possible only if
2446 /// the product of all tensor dimensions stays constant and the
2447 /// dimension being split does not have padding (i.e.
2448 /// `padded_dims[dim] = dims[dim]`).
2449 /// 4. Join multiple consecutive dimensions into a single one. As in
2450 /// the cases above, this requires that the dimensions do not have
2451 /// padding and that the memory format is such that in physical
2452 /// memory these dimensions are dense and have the same order as
2453 /// their logical counterparts. This also assumes that these
2454 /// dimensions are not blocked.
2455 /// - Here, 'dense' means:
2456 /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
2457 /// - And 'same order' means:
2458 /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
2459 ///
2460 /// @warning
2461 /// Some combinations of physical memory layout and/or offsets or
2462 /// dimensions may result in a failure to make a reshape.
2463 ///
2464 /// @param adims New dimensions. The product of dimensions must
2465 /// remain constant.
2466 /// @param allow_empty A flag signifying whether construction is
2467 /// allowed to fail without throwing an exception. In this case a
2468 /// zero memory descriptor will be returned. This flag is optional
2469 /// and defaults to false.
2470 /// @returns A new memory descriptor with new dimensions.
2471 desc reshape(const dims &adims, bool allow_empty = false) const {
2472 if (get_ndims()) validate_dims(adims, 1);
2473 dnnl_memory_desc_t out_md = nullptr;
2474 dnnl_status_t status = dnnl_memory_desc_reshape(
2475 &out_md, get(), (int)adims.size(), adims.data());
2476 if (!allow_empty)
2477 error::wrap_c_api(
2478 status, "could not reshape a memory descriptor");
2479 return desc(out_md);
2480 }
2481
2482 /// Constructs a memory descriptor by permuting axes in an existing
2483 /// one.
2484 ///
2485 /// The physical memory layout representation is adjusted accordingly
2486 /// to maintain the consistency between the logical and physical parts
2487 /// of the memory descriptor. The new memory descriptor inherits the
2488 /// data type.
2489 ///
2490 /// The new memory descriptor inherits the data type. This operation is
2491 /// valid only for memory descriptors that have format_kind set to
2492 /// #dnnl::memory::format_kind::blocked or
2493 /// #dnnl::memory::format_kind::any.
2494 ///
2495 /// The logical axes will be permuted in the following manner:
2496 /// @code
2497 /// for (i = 0; i < get_ndims(); i++)
2498 /// new_desc.dims()[permutation[i]] = dims()[i];
2499 /// @endcode
2500 ///
2501 /// Example:
2502 /// @code
2503 /// std::vector<int> permutation = {1, 0}; // swap the first and
2504 /// // the second axes
2505 /// dnnl::memory::desc in_md(
2506 /// {2, 3}, data_type, memory::format_tag::ab);
2507 /// dnnl::memory::desc expect_out_md(
2508 /// {3, 2}, data_type, memory::format_tag::ba);
2509 ///
2510 /// assert(in_md.permute_axes(permutation) == expect_out_md);
2511 /// @endcode
2512 ///
2513 /// @param permutation Axes permutation.
2514 /// @param allow_empty A flag signifying whether construction is
2515 /// allowed to fail without throwing an exception. In this case a
2516 /// zero memory descriptor will be returned. This flag is optional
2517 /// and defaults to false.
2518 /// @returns A new memory descriptor with new dimensions.
2519 desc permute_axes(const std::vector<int> &permutation,
2520 bool allow_empty = false) const {
2521 validate_dims(permutation, get_ndims());
2522 dnnl_memory_desc_t out_md = nullptr;
2523 dnnl_status_t status = dnnl_memory_desc_permute_axes(
2524 &out_md, get(), permutation.data());
2525 if (!allow_empty)
2526 error::wrap_c_api(status,
2527 "could not permute axes of a memory descriptor");
2528 return desc(out_md);
2529 }
2530
2531 /// Returns a number of dimensions of the memory descriptor.
2532 ///
2533 /// @returns A number of dimensions.
2534 int get_ndims() const { return query_s32(query::ndims_s32); }
2535
2536 /// Returns padded dimensions of the memory descriptor.
2537 ///
2538 /// @returns A copy of the padded dimensions vector.
2539 memory::dims get_padded_dims() const {
2540 return query_dims(query::padded_dims);
2541 }
2542
2543 /// Returns padded offsets of the memory descriptor.
2544 ///
2545 /// @returns A copy of the padded offsets vector.
2546 memory::dims get_padded_offsets() const {
2547 return query_dims(query::padded_offsets);
2548 }
2549
2550 /// Returns a submemory offset of the memory descriptor.
2551 ///
2552 /// @returns A submemory offset.
2553 memory::dim get_submemory_offset() const {
2554 dnnl_dim_t submemory_offset;
2555 dnnl_status_t status = dnnl_memory_desc_query(
2556 get(), dnnl_query_submemory_offset_s64, &submemory_offset);
2557 return status == dnnl_success ? submemory_offset : 0;
2558 }
2559
2560 /// Returns strides of the memory descriptor.
2561 ///
2562 /// @note
2563 /// This API is only applicable to memory descriptors with format
2564 /// kind #dnnl_blocked.
2565 ///
2566 /// @returns A copy of the strides vector.
2567 /// @returns An empty #dnnl::memory::dims if the memory descriptor
2568 /// does not have strides.
2569 memory::dims get_strides() const { return query_dims(query::strides); }
2570
2571 /// Returns a number of inner blocks of the memory descriptor.
2572 ///
2573 /// @note
2574 /// This API is only applicable to memory descriptors with format
2575 /// kind #dnnl_blocked.
2576 ///
2577 /// @returns A number of inner blocks.
2578 int get_inner_nblks() const {
2579 return query_s32(query::inner_nblks_s32);
2580 }
2581
2582 /// Returns inner blocks of the memory descriptor.
2583 ///
2584 /// @note
2585 /// This API is only applicable to memory descriptors with format
2586 /// kind #dnnl_blocked.
2587 ///
2588 /// @returns A copy of the inner blocks vector.
2589 /// @returns An empty #dnnl::memory::dims if the memory descriptor
2590 /// does not have inner blocks.
2591 memory::dims get_inner_blks() const {
2592 return query_dims(query::inner_blks);
2593 }
2594
2595 /// Returns inner indices of the memory descriptor.
2596 ///
2597 /// @note
2598 /// This API is only applicable to memory descriptors with format
2599 /// kind #dnnl_blocked.
2600 ///
2601 /// @returns A copy of the inner indices vector.
2602 /// @returns An empty #dnnl::memory::dims if the memory descriptor
2603 /// does not have inner indices.
2604 memory::dims get_inner_idxs() const {
2605 return query_dims(query::inner_idxs);
2606 }
2607
2608 /// Returns the format kind of the memory descriptor.
2609 ///
2610 /// @returns the format kind.
2611 memory::format_kind get_format_kind() const {
2612 dnnl_format_kind_t format_kind;
2613 dnnl_status_t status = dnnl_memory_desc_query(
2614 get(), dnnl_query_format_kind, &format_kind);
2615 return status == dnnl_success
2616 ? static_cast<dnnl::memory::format_kind>(format_kind)
2617 : dnnl::memory::format_kind::undef;
2618 }
2619
2620 /// Returns the data type of the memory descriptor.
2621 ///
2622 /// @returns The data type.
2623 memory::data_type get_data_type() const {
2624 dnnl_data_type_t data_type;
2625 dnnl_status_t status = dnnl_memory_desc_query(
2626 get(), dnnl_query_data_type, &data_type);
2627 return status == dnnl_success
2628 ? static_cast<dnnl::memory::data_type>(data_type)
2629 : dnnl::memory::data_type::undef;
2630 }
2631
2632 /// Returns dimensions of the memory descriptor.
2633 ///
2634 /// Potentially expensive due to the data copy involved.
2635 /// @returns A copy of the dimensions vector.
2636 memory::dims get_dims() const { return query_dims(query::dims); }
2637
2638 /// Returns size of the memory descriptor in bytes.
2639 /// @returns The number of bytes required to allocate a memory buffer
2640 /// for the memory object described by this memory descriptor
2641 /// including the padding area.
2642 size_t get_size() const { return dnnl_memory_desc_get_size(get()); }
2643
2644 /// Checks whether the memory descriptor is zero (empty).
2645 /// @returns @c true if the memory descriptor describes an empty
2646 /// memory and @c false otherwise.
2647 bool is_zero() const { return get_ndims() == 0; }
2648
2649 /// An equality operator.
2650 /// @param other Another memory descriptor.
2651 /// @returns Whether this and the other memory descriptors have
2652 /// the same format tag, dimensions, strides, blocking, etc.
2653 bool operator==(const desc &other) const {
2654 return dnnl_memory_desc_equal(get(), other.get()) != 0;
2655 }
2656
2657 /// An inequality operator.
2658 /// @param other Another memory descriptor.
2659 /// @returns Whether this and the other memory descriptors describe
2660 /// different memory.
2661 bool operator!=(const desc &other) const { return !operator==(other); }
2662
2663 private:
2664 int query_s32(query what) const {
2665 int res;
2666 dnnl_status_t status = dnnl_memory_desc_query(
2667 get(), dnnl::convert_to_c(what), &res);
2668 return status == dnnl_success ? res : 0;
2669 }
2670
2671 memory::dims query_dims(query what) const {
2672 dnnl_dims_t *c_dims;
2673 dnnl_status_t status = dnnl_memory_desc_query(
2674 get(), dnnl::convert_to_c(what), &c_dims);
2675
2676 const int ndims
2677 = (what == query::inner_idxs || what == query::inner_blks)
2678 ? get_inner_nblks()
2679 : get_ndims();
2680
2681 return status == dnnl_success
2682 ? memory::dims(*c_dims, *c_dims + ndims)
2683 : memory::dims {};
2684 }
2685 };
2686
/// Default constructor.
///
/// Constructs an empty memory object, which can be used to indicate
/// absence of a parameter. The constructor is compiler-generated
/// (`= default`) and performs no C API call.
memory() = default;
2692
2693 /// Constructs a memory object.
2694 ///
2695 /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
2696 /// object will have the underlying buffer set. In this case, the buffer
2697 /// will be initialized as if #dnnl::memory::set_data_handle() had been
2698 /// called.
2699 ///
2700 /// @sa memory::set_data_handle()
2701 ///
2702 /// @param md Memory descriptor.
2703 /// @param aengine Engine to store the data on.
2704 /// @param handle Handle of the memory buffer to use.
2705 /// - A pointer to the user-allocated buffer. In this case the library
2706 /// doesn't own the buffer.
2707 /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
2708 /// allocate the buffer for the memory object. In this case the
2709 /// library owns the buffer.
2710 /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
2711 /// buffer.
2712 memory(const desc &md, const engine &aengine, void *handle) {
2713 dnnl_memory_t result;
2714 error::wrap_c_api(
2715 dnnl_memory_create(&result, md.get(), aengine.get(), handle),
2716 "could not create a memory object");
2717 reset(result);
2718 }
2719
/// Constructs a memory object.
///
/// The underlying buffer for the memory will be allocated by the library:
/// this constructor delegates to the three-argument constructor and
/// passes #DNNL_MEMORY_ALLOCATE as the buffer handle.
///
/// @param md Memory descriptor.
/// @param aengine Engine to store the data on.
memory(const desc &md, const engine &aengine)
    : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
2728
2729 /// Returns the associated memory descriptor.
2730 desc get_desc() const {
2731 const_dnnl_memory_desc_t cdesc;
2732 error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
2733 "could not get a memory descriptor from a memory object");
2734 dnnl_memory_desc_t cloned_md = nullptr;
2735 error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
2736 "could not clone a memory descriptor");
2737 return desc(cloned_md);
2738 }
2739
2740 /// Returns the associated engine.
2741 engine get_engine() const {
2742 dnnl_engine_t c_engine;
2743 error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
2744 "could not get an engine from a memory object");
2745 return engine(c_engine, true);
2746 }
2747
2748 /// Returns the underlying memory buffer.
2749 ///
2750 /// On the CPU engine, or when using USM, this is a pointer to the
2751 /// allocated memory.
2752 void *get_data_handle() const {
2753 void *handle;
2754 error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle),
2755 "could not get a native handle from a memory object");
2756 return handle;
2757 }
2758
2759 /// Sets the underlying memory buffer.
2760 ///
2761 /// @param handle Memory buffer to use. On the CPU engine or when USM is
2762 /// used, the memory buffer is a pointer to the actual data. For OpenCL
2763 /// it is a cl_mem. It must have at least
2764 /// #dnnl::memory::desc::get_size() bytes allocated.
2765 void set_data_handle(void *handle) const {
2766 error::wrap_c_api(dnnl_memory_set_data_handle(get(), handle),
2767 "could not set native handle of a memory object");
2768 }
2769
2770 /// Maps a memory object and returns a host-side pointer to a memory
2771 /// buffer with a copy of its contents.
2772 ///
2773 /// Mapping enables read/write directly from/to the memory contents for
2774 /// engines that do not support direct memory access.
2775 ///
2776 /// Mapping is an exclusive operation - a memory object cannot be used in
2777 /// other operations until it is unmapped via #dnnl::memory::unmap_data()
2778 /// call.
2779 ///
2780 /// @note
2781 /// Any primitives working with the memory should be completed before
2782 /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the
2783 /// corresponding execution stream.
2784 ///
2785 /// @note
2786 /// The map_data and unmap_data functions are provided mainly for
2787 /// debug and testing purposes and their performance may be suboptimal.
2788 ///
2789 /// @tparam T Data type to return a pointer to.
2790 /// @returns Pointer to the mapped memory.
2791 template <typename T = void>
2792 T *map_data() const {
2793 void *mapped_ptr;
2794 error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
2795 "could not map memory object data");
2796 return static_cast<T *>(mapped_ptr);
2797 }
2798
2799 /// Unmaps a memory object and writes back any changes made to the
2800 /// previously mapped memory buffer.
2801 ///
2802 /// @note
2803 /// The map_data and unmap_data functions are provided mainly for
2804 /// debug and testing purposes and their performance may be
2805 /// suboptimal.
2806 ///
2807 /// @param mapped_ptr A pointer previously returned by
2808 /// #dnnl::memory::map_data().
2809 void unmap_data(void *mapped_ptr) const {
2810 error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
2811 "could not unmap memory object data");
2812 }
2813
2814 static dnnl_data_type_t convert_to_c(data_type adata_type) {
2815 return static_cast<dnnl_data_type_t>(adata_type);
2816 }
2817 static dnnl_format_tag_t convert_to_c(format_tag format) {
2818 return static_cast<dnnl_format_tag_t>(format);
2819 }
2820};
2821
2822inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
2823 return a == memory::convert_to_c(b);
2824}
2825inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
2826 return !(a == b);
2827}
2828inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
2829 return b == a;
2830}
2831inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
2832 return !(a == b);
2833}
2834
2835inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
2836 return a == memory::convert_to_c(b);
2837}
2838inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
2839 return !(a == b);
2840}
2841inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
2842 return b == a;
2843}
2844inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
2845 return !(a == b);
2846}
2847
2848/// @} dnnl_api_memory
2849
2850/// @addtogroup dnnl_api_primitives
2851/// @{
2852/// @addtogroup dnnl_api_attributes Attributes
2853///
2854/// A container for parameters that extend primitives behavior.
2855///
2856/// @{
2857
2858/// @cond DO_NOT_DOCUMENT_THIS
2859template <>
2860struct handle_traits<dnnl_post_ops_t> {
2861 static dnnl_status_t destructor(dnnl_post_ops_t p) {
2862 return dnnl_post_ops_destroy(p);
2863 }
2864};
2865/// @endcond
2866
2867/// Post-ops.
2868///
2869/// Post-ops are computations executed after the main primitive computations
2870/// and are attached to the primitive via primitive attributes.
2871///
2872/// @sa @ref dev_guide_attributes_post_ops
2873///
2874struct post_ops : public handle<dnnl_post_ops_t> {
2875 using handle<dnnl_post_ops_t>::handle;
2876
2877 /// Constructs an empty sequence of post-ops.
2878 post_ops() {
2879 dnnl_post_ops_t result;
2880 error::wrap_c_api(
2881 dnnl_post_ops_create(&result), "could not create post-ops");
2882 reset(result);
2883 }
2884
/// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t
/// handle. The resulting handle is not weak and the C handle will be
/// destroyed during the destruction of the C++ object.
///
/// The C handle is forwarded unchanged to the handle<dnnl_post_ops_t>
/// base-class constructor.
///
/// @param post_ops The C API post-ops primitive attribute.
post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {}
2891
2892 /// Returns the number of post-ops entries.
2893 int len() const { return dnnl_post_ops_len(get()); }
2894
2895 /// Returns the primitive kind of post-op at entry with a certain index.
2896 /// @param index Index of the post-op to return the kind for.
2897 /// @returns Primitive kind of the post-op at the specified index.
2898 primitive::kind kind(int index) const {
2899 error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
2900 "post-ops index is out of range");
2901 return static_cast<primitive::kind>(
2902 dnnl_post_ops_get_kind(get(), index));
2903 }
2904
2905 /// Appends an accumulation (sum) post-op. Prior to accumulating the
2906 /// result, the previous value will be will be reduced by zero point
2907 /// @p zero_point and multiplied by a scaling factor @p scale.
2908 ///
2909 /// The kind of this post-op is #dnnl::primitive::kind::sum.
2910 ///
2911 /// This feature may improve performance for cases like dequantize the
2912 /// asymmetrically quantized sum's src1 tensor to f32 domain before
2913 /// performing the sum operation by subtracting @p zero_point before the
2914 /// scaling.
2915 ///
2916 /// In the simplest case when the accumulation is the only post-op,
2917 /// the computations will be `dst[:] := scale * (dst[:] - zero_point) +
2918 /// op(...)` instead of `dst[:] := op(...)`.
2919 ///
2920 /// If @p data_type is specified, the original dst tensor will be
2921 /// reinterpreted as a tensor with the provided data type. Because it is a
2922 /// reinterpretation, data_type and dst data type should have the same size.
2923 /// As a result, computations will be `dst[:] <- scale *
2924 /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of
2925 /// `dst[:] <- op(...)`.
2926 ///
2927 /// @note
2928 /// This post-op executes in-place and does not change the
2929 /// destination layout.
2930 ///
2931 /// @param scale Scaling factor.
2932 /// @param zero_point Zero point.
2933 /// @param data_type Data type.
2934 void append_sum(float scale = 1.f, int32_t zero_point = 0,
2935 memory::data_type data_type = memory::data_type::undef) {
2936 error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point,
2937 memory::convert_to_c(data_type)),
2938 "could not append a sum post-op");
2939 }
2940
2941 /// Returns the parameters of an accumulation (sum) post-op.
2942 ///
2943 /// @param index Index of the sum post-op.
2944 /// @param scale Scaling factor of the sum post-op.
2945 void get_params_sum(int index, float &scale) const {
2946 error::wrap_c_api(dnnl_post_ops_get_params_sum(
2947 get(), index, &scale, nullptr, nullptr),
2948 "could not get parameters of a sum post-op");
2949 }
2950
2951 /// Returns the parameters of an accumulation (sum) post-op.
2952 ///
2953 /// @param index Index of the sum post-op.
2954 /// @param scale Scaling factor of the sum post-op.
2955 /// @param data_type Data type of the sum post-op.
2956 void get_params_sum(
2957 int index, float &scale, memory::data_type &data_type) const {
2958 dnnl_data_type_t c_data_type;
2959 error::wrap_c_api(dnnl_post_ops_get_params_sum(
2960 get(), index, &scale, nullptr, &c_data_type),
2961 "could not get parameters of a sum post-op");
2962 data_type = static_cast<memory::data_type>(c_data_type);
2963 }
2964
2965 /// Returns the parameters of an accumulation (sum) post-op.
2966 ///
2967 /// @param index Index of the sum post-op.
2968 /// @param scale Scaling factor of the sum post-op.
2969 /// @param zero_point Single scalar int32_t value of zeropoint.
2970 /// @param data_type Data type of the sum post-op.
2971 void get_params_sum(int index, float &scale, int32_t &zero_point,
2972 memory::data_type &data_type) const {
2973 dnnl_data_type_t c_data_type;
2974 error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale,
2975 &zero_point, &c_data_type),
2976 "could not get parameters of a sum post-op");
2977 data_type = static_cast<memory::data_type>(c_data_type);
2978 }
2979
2980 /// Appends an elementwise post-op.
2981 ///
2982 /// The kind of this post-op is #dnnl::primitive::kind::eltwise.
2983 ///
2984 /// In the simplest case when the elementwise is the only post-op, the
2985 /// computations would be `dst[:] := eltwise_op (op(...))` instead
2986 /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given
2987 /// parameters.
2988 ///
2989 /// @param aalgorithm Elementwise algorithm.
2990 /// @param alpha Alpha parameter for the elementwise algorithm.
2991 /// @param beta Beta parameter for the elementwise algorithm.
2992 void append_eltwise(algorithm aalgorithm, float alpha, float beta) {
2993 error::wrap_c_api(dnnl_post_ops_append_eltwise(
2994 get(), convert_to_c(aalgorithm), alpha, beta),
2995 "could not append an elementwise post-op");
2996 }
2997
2998 /// Returns parameters of an elementwise post-op.
2999 ///
3000 /// @param index Index of the post-op.
3001 /// @param aalgorithm Output elementwise algorithm kind.
3002 /// @param alpha Output alpha parameter for the elementwise algorithm.
3003 /// @param beta Output beta parameter for the elementwise algorithm.
3004 void get_params_eltwise(
3005 int index, algorithm &aalgorithm, float &alpha, float &beta) const {
3006 dnnl_alg_kind_t c_alg;
3007 error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
3008 get(), index, &c_alg, &alpha, &beta),
3009 "could not get parameters of an elementwise post-op");
3010 aalgorithm = static_cast<dnnl::algorithm>(c_alg);
3011 }
3012
3013 /// Appends a depthwise post-op convolution.
3014 ///
3015 /// This post-op can only be fused with a 2D 1x1 convolution (convolution
3016 /// with weights spatial dimension equal to 1 i.e., kh=kw=1).
3017 ///
3018 /// The kind of this post-op is #dnnl_convolution.
3019 ///
3020 /// The number of outputs for primitive remain same as before. The output
3021 /// spatial size can be derived as below:
3022 ///
3023 /// output_height = ceil(output_height_1x1_convolution, stride)
3024 /// output_width = ceil(output_width_1x1_convolution, stride)
3025 ///
3026 /// See @ref dev_guide_attributes_post_ops_depthwise and
3027 /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
3028 ///
3029 /// @param weights_data_type Weights data type of depthwise post-op
3030 /// @param bias_data_type Bias data type of depthwise post-op
3031 /// @param dst_data_type Output data type of depthwise post-op
3032 /// @param kernel_size Size of kernel of depthwise post-op
3033 /// @param stride_size Size of stride of depthwise post-op
3034 /// @param padding_l_size Size of left and top paddings of depthwise post-op
3035 void append_dw(memory::data_type weights_data_type,
3036 memory::data_type bias_data_type, memory::data_type dst_data_type,
3037 memory::dim kernel_size, memory::dim stride_size,
3038 memory::dim padding_l_size) {
3039
3040 error::wrap_c_api(dnnl_post_ops_append_dw(get(),
3041 memory::convert_to_c(weights_data_type),
3042 memory::convert_to_c(bias_data_type),
3043 memory::convert_to_c(dst_data_type),
3044 kernel_size, stride_size, padding_l_size),
3045 "could not append depthwise post-op");
3046 }
3047
3048 /// Returns the parameters of an depthwise post-op.
3049 ///
3050 /// @param index Index of the elementwise post-op.
3051 /// @param weights_data_type Weights data type of depthwise post-op
3052 /// @param bias_data_type Bias data type of depthwise post-op
3053 /// @param dst_data_type Output data type of depthwise post-op
3054 /// @param kernel_size Size of kernel of depthwise post-op
3055 /// @param stride_size Size of stride of depthwise post-op
3056 /// @param padding_l_size Size of left and top paddings of depthwise post-op
3057 void get_params_dw(int index, memory::data_type &weights_data_type,
3058 memory::data_type &bias_data_type, memory::data_type &dst_data_type,
3059 memory::dim &kernel_size, memory::dim &stride_size,
3060 memory::dim &padding_l_size) const {
3061
3062 dnnl_data_type_t c_weights_data_type;
3063 dnnl_data_type_t c_bias_data_type;
3064 dnnl_data_type_t c_dst_data_type;
3065 dnnl_dim_t c_kernel_size;
3066 dnnl_dim_t c_stride_size;
3067 dnnl_dim_t c_padding_l_size;
3068 error::wrap_c_api(
3069 dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type,
3070 &c_bias_data_type, &c_dst_data_type, &c_kernel_size,
3071 &c_stride_size, &c_padding_l_size),
3072 "could not get parameters of depthwise post-op");
3073
3074 weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
3075 bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
3076 dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
3077 kernel_size = c_kernel_size;
3078 stride_size = c_stride_size;
3079 padding_l_size = c_padding_l_size;
3080 }
3081
3082 /// Appends a binary post-op.
3083 ///
3084 /// The kind of this post operation is #dnnl_binary.
3085 ///
3086 /// In the simplest case when the binary is the only post operation, the
3087 /// computations would be:
3088 ///
3089 /// dst[:] <- binary_op (dst[:], another_input[:])
3090 ///
3091 /// where binary_op is configured with the given parameters. binary_op
3092 /// supports broadcast semantics for a second operand.
3093 ///
3094 /// @param aalgorithm Binary algorithm for the post-op.
3095 /// @param src1_desc Memory descriptor of a second operand.
3096 void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
3097 error::wrap_c_api(dnnl_post_ops_append_binary(get(),
3098 convert_to_c(aalgorithm), src1_desc.get()),
3099 "could not append a binary post-op");
3100 }
3101
3102 /// Returns the parameters of a binary post-op.
3103 ///
3104 /// @param index Index of the binary post-op.
3105 /// @param aalgorithm Output binary algorithm kind.
3106 /// @param src1_desc Output memory descriptor of a second operand.
3107 void get_params_binary(
3108 int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
3109 dnnl_alg_kind_t c_alg;
3110 const_dnnl_memory_desc_t cdesc;
3111 error::wrap_c_api(
3112 dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc),
3113 "could not get parameters of a binary post-op");
3114 aalgorithm = static_cast<dnnl::algorithm>(c_alg);
3115 dnnl_memory_desc_t cloned_md = nullptr;
3116 error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
3117 "could not clone a memory descriptor");
3118 src1_desc = memory::desc(cloned_md);
3119 }
3120
3121 /// Appends a prelu forward post-op.
3122 ///
3123 /// The kind of this post-op is #dnnl::primitive::kind::prelu.
3124 ///
3125 /// The post-op can be defined as:
3126 ///
3127 /// dst[:] <- prelu(dst[:], weights[:])
3128 /// prelu:
3129 /// dst[:] <- dst[:] if dst[:] > 0
3130 /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
3131 ///
3132 ///
3133 /// Example usage:
3134 /// @code
3135 /// int mb = 32, oc = 32,
3136 /// oh = 14, ow = 14; // convolution output params
3137 /// // unique weights per output channel
3138 /// vector<float> weights = { ... };
3139 /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
3140 ///
3141 /// // construct a convolution descriptor
3142 /// dnnl::convolution::desc conv_d;
3143 ///
3144 /// dnnl::primitive_attr attr;
3145 /// attr.append_prelu(1 << oc_dim);
3146 ///
3147 /// dnnl::primitive_desc conv_pd(conv_d, attr, engine);
3148 /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data());
3149 ///
3150 /// std::unordered_map<int, memory> conv_args;
3151 ///
3152 /// conv_args.insert(
3153 /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights})
3154 /// @endcode
3155 ///
3156 /// @note
3157 /// The order of dimensions does not depend on how elements are laid
3158 /// out in memory. For example:
3159 /// - for a 2D CNN activations tensor the order is always (n, c)
3160 /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
3161 /// - for a 5D CNN weights tensor the order is always
3162 /// (g, oc, ic, kh, kw)
3163 ///
3164 /// Prelu weights tensor is passed in runtime execution phase. Prelu
3165 /// weights tensor data type is implicitly assumed as f32 using plain
3166 /// layout (a, ab, acb, acdb, acdeb).
3167 ///
3168 /// @param mask Defines the correspondence between the output tensor
3169 /// dimensions and the prelu weights tensor. The set i-th bit indicates
3170 /// that a dedicated weights value is used for each index along that
3171 /// dimension. Set the mask to 0 to use a common weights value
3172 /// for the whole output tensor.
3173 void append_prelu(int mask) {
3174 error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask),
3175 "could not append a prelu post-op");
3176 }
3177
3178 /// Returns the parameters of a prelu post-op.
3179 ///
3180 /// @param index Index of the prelu post-op.
3181 /// @param mask Weights mask of prelu post-op.
3182 void get_params_prelu(int index, int &mask) const {
3183 error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
3184 "could not get parameters of a binary post-op");
3185 }
3186};
3187
3188/// @cond DO_NOT_DOCUMENT_THIS
3189template <>
3190struct handle_traits<dnnl_primitive_attr_t> {
3191 static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
3192 return dnnl_primitive_attr_destroy(p);
3193 }
3194};
3195/// @endcond
3196
3197/// Primitive attributes.
3198///
3199/// @sa @ref dev_guide_attributes
3200struct primitive_attr : public handle<dnnl_primitive_attr_t> {
3201 using handle<dnnl_primitive_attr_t>::handle;
3202
3203 /// Constructs default (empty) primitive attributes.
3204 primitive_attr() {
3205 dnnl_primitive_attr_t result;
3206 error::wrap_c_api(dnnl_primitive_attr_create(&result),
3207 "could not create primitive attribute");
3208 reset(result);
3209 }
3210
/// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
/// handle. The resulting handle is not weak and the C handle will be
/// destroyed during the destruction of the C++ object.
///
/// The C handle is forwarded unchanged to the
/// handle<dnnl_primitive_attr_t> base-class constructor.
///
/// @param attr The C API primitive attributes.
primitive_attr(dnnl_primitive_attr_t attr)
    : handle<dnnl_primitive_attr_t>(attr) {}
3218
3219 /// Returns the fpmath mode
3220 fpmath_mode get_fpmath_mode() const {
3221 dnnl_fpmath_mode_t result;
3222 error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result),
3223 "could not get fpmath mode primitive attribute");
3224 return fpmath_mode(result);
3225 }
3226
3227 /// Sets fpmath mode.
3228 ///
3229 /// @param mode Specified fpmath mode.
3230 void set_fpmath_mode(fpmath_mode mode) {
3231 error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode(
3232 get(), dnnl::convert_to_c(mode)),
3233 "could not set fpmath mode primitive attribute");
3234 }
3235
3236 /// Returns the scratchpad mode.
3237 scratchpad_mode get_scratchpad_mode() const {
3238 dnnl_scratchpad_mode_t result;
3239 error::wrap_c_api(
3240 dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
3241 "could not get scratchpad mode primitive attribute");
3242 return scratchpad_mode(result);
3243 }
3244
3245 /// Sets scratchpad mode.
3246 ///
3247 /// @param mode Specified scratchpad mode.
3248 void set_scratchpad_mode(scratchpad_mode mode) {
3249 error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
3250 get(), dnnl::convert_to_c(mode)),
3251 "could not set scratchpad mode primitive attribute");
3252 }
3253
3254 /// Sets scaling factors for primitive operations for a given memory
3255 /// argument. The scaling factors must be passed at execution time
3256 /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
3257 ///
3258 /// @sa dnnl_primitive_attr_set_scales_mask
3259 ///
3260 /// @param arg Parameter argument index as passed to the
3261 /// primitive::execute() call.
3262 /// @param mask Scaling factors correspondence mask that defines the
3263 /// correspondence between the tensor dimensions and the @p scales
3264 /// vector. The set i-th bit indicates that a dedicated scaling factor
3265 /// is used for each index along that dimension. Set the mask to 0 to
3266 /// use a common scaling factor for the whole output tensor.
3267 void set_scales_mask(int arg, int mask) {
3268 error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
3269 "could not set scales primitive attribute");
3270 }
3271
3272 /// Sets zero points for primitive operations for a given memory argument.
3273 /// The zero points must be passed at execution time as an argument with
3274 /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
3275 ///
3276 /// @sa dnnl_primitive_attr_set_zero_points_mask
3277 ///
3278 /// @param arg Parameter argument index as passed to the
3279 /// primitive::execute() call.
3280 /// @param mask Zero point correspondence mask that defines the
3281 /// correspondence between the tensor dimensions and the @p
3282 /// zero_points vector. The set i-th bit indicates that a dedicated
3283 /// zero point is used for each index along that dimension. Set the
3284 /// mask to 0 to use a common zero point for the whole output tensor.
3285 void set_zero_points_mask(int arg, int mask) {
3286 error::wrap_c_api(
3287 dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
3288 "could not set zero points primitive attribute");
3289 }
3290
3291 /// Returns post-ops previously set via set_post_ops().
3292 ///
3293 /// @returns Post-ops.
3294 const post_ops get_post_ops() const {
3295 const_dnnl_post_ops_t const_c_post_ops;
3296 error::wrap_c_api(
3297 dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops),
3298 "could not get post-ops primitive attribute");
3299 dnnl_post_ops_t c_post_ops;
3300 error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops),
3301 "could not clone post-ops primitive attribute");
3302 return post_ops(c_post_ops);
3303 }
3304
3305 /// Sets post-ops.
3306 ///
3307 /// @note
3308 /// There is no way to check whether the post-ops would be supported
3309 /// by the target primitive. Any error will be reported
3310 /// by the respective primitive descriptor constructor.
3311 ///
3312 /// @param ops Post-ops object to copy post-ops from.
3313 void set_post_ops(const post_ops ops) {
3314 error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
3315 "could not set post-ops primitive attribute");
3316 }
3317
3318 /// Sets quantization scale and shift parameters for RNN data tensors.
3319 ///
3320 /// For performance reasons, the low-precision configuration of the RNN
3321 /// primitives expect input activations to have the unsigned 8-bit integer
3322 /// data type. The scale and shift parameters are used to quantize
3323 /// floating-point data to unsigned integer and must be passed to the RNN
3324 /// primitive using attributes.
3325 ///
3326 /// The quantization formula is `scale * data + shift`.
3327 ///
3328 /// Example usage:
3329 /// @code
3330 /// // RNN parameters
3331 /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
3332 /// // Activations quantization parameters
3333 /// float scale = 63.f, shift = 64.f;
3334 ///
3335 /// primitive_attr attr;
3336 ///
3337 /// // Set scale and shift for int8 quantization of activation
3338 /// attr.set_rnn_data_qparams(scale, shift);
3339 ///
3340 /// // Create an RNN primitive descriptor.
3341 /// vanilla_rnn_forward::primitive_desc rnn_d(
3342 /// engine, /* arguments */, attr);
3343 /// @endcode
3344 ///
3345 /// @note
3346 /// Quantization scale and shift are common for src_layer, src_iter,
3347 /// dst_iter, and dst_layer.
3348 ///
3349 /// @param scale The value to scale the data by.
3350 /// @param shift The value to shift the data by.
3351 void set_rnn_data_qparams(float scale, float shift) {
3352 error::wrap_c_api(
3353 dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
3354 "could not set RNN data quantization parameters primitive "
3355 "attribute");
3356 }
3357
3358 /// Returns the quantization scale and shift parameters for RNN data
3359 /// tensors.
3360 ///
3361 /// @note
3362 /// Quantization scale and shift are common for src_layer, src_iter,
3363 /// dst_iter, and dst_layer.
3364 ///
/// @param scale Output: the value the data is scaled by.
/// @param shift Output: the value the data is shifted by.
3367 void get_rnn_data_qparams(float &scale, float &shift) {
3368 float c_scale, c_shift;
3369 error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams(
3370 get(), &c_scale, &c_shift),
3371 "could not set RNN data quantization parameters primitive "
3372 "attribute");
3373 scale = c_scale;
3374 shift = c_shift;
3375 }
3376
3377 /// Sets quantization scaling factors for RNN weights tensors. The
3378 /// low-precision configuration of the RNN primitives expect input weights
3379 /// to use the signed 8-bit integer data type. The scaling factors are
3380 /// used to quantize floating-point data to signed integer and must be
3381 /// passed to RNN primitives using attributes.
3382 ///
3383 /// @note
3384 /// The dimension order is always native and does not depend on the
3385 /// actual layout used. For example, five-dimensional weights always
3386 /// have (l, d, i, g, o) logical dimension ordering.
3387 ///
3388 /// @note
3389 /// Quantization scales are common for weights_layer and
3390 /// weights_iteration
3391 ///
3392 /// @param mask Scaling factors correspondence mask that defines the
3393 /// correspondence between the output tensor dimensions and the @p
3394 /// scales vector. The set i-th bit indicates that a dedicated scaling
3395 /// factor should be used each index along that dimension. Set the
3396 /// mask to 0 to use a common scaling factor for the whole output
3397 /// tensor.
3398 /// @param scales Constant vector of output scaling factors. The following
3399 /// equality must hold:
3400 /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
3401 /// Violations can only be detected when the attributes are used to
3402 /// create a primitive descriptor.
3403 void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
3404 error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
3405 (int)scales.size(), mask, scales.data()),
3406 "could not set RNN weights quantization parameters primitive "
3407 "attribute");
3408 }
3409
/// Returns the quantization scaling factors for RNN weights
/// tensors.
3412 ///
3413 /// @note
3414 /// The dimension order is always native and does not depend on the
3415 /// actual layout used. For example, five-dimensional weights always
3416 /// have (l, d, i, g, o) logical dimension ordering.
3417 ///
3418 /// @param mask Scaling factors correspondence mask that defines the
3419 /// correspondence between the output tensor dimensions and the @p
3420 /// scales vector. The set i-th bit indicates that a dedicated scaling
3421 /// factor should be used each index along that dimension. Set the
3422 /// mask to 0 to use a common scaling factor for the whole output
3423 /// tensor.
3424 /// @param scales Constant vector of output scaling factors. The following
3425 /// equality must hold:
3426 /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
3427 /// Violations can only be detected when the attributes are used to
3428 /// create a primitive descriptor.
3429 void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
3430 dnnl_dim_t count;
3431 int c_mask;
3432 const float *c_scales;
3433 error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams(
3434 get(), &count, &c_mask, &c_scales),
3435 "could not get primitive RNN weights quantization "
3436 "parameters attributes");
3437 scales.resize(count);
3438
3439 mask = c_mask;
3440 for (dnnl_dim_t c = 0; c < count; c++)
3441 scales[c] = c_scales[c];
3442 }
3443
3444 /// Sets quantization scaling factors for RNN projection weights tensors.
/// The low-precision configuration of the RNN primitives expect input
/// weights to use the signed 8-bit integer data type. The scaling factors
/// are used to quantize floating-point data to signed integer and must be
3448 /// passed to RNN primitives using attributes.
3449 ///
3450 /// @note
3451 /// The dimension order is always native and does not depend on the
3452 /// actual layout used. For example, five-dimensional weights always
3453 /// have (l, d, i, g, o) logical dimension ordering.
3454 ///
3455 /// @note
3456 /// Quantization scales are common for weights_layer and
3457 /// weights_iteration
3458 ///
3459 /// @param mask Scaling factors correspondence mask that defines the
3460 /// correspondence between the output tensor dimensions and the @p
3461 /// scales vector. The set i-th bit indicates that a dedicated scaling
3462 /// factor should be used each index along that dimension. Set the
3463 /// mask to 0 to use a common scaling factor for the whole output
3464 /// tensor.
3465 /// @param scales Constant vector of output scaling factors. The following
3466 /// equality must hold:
3467 /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
3468 /// Violations can only be detected when the attributes are used to
3469 /// create a primitive descriptor.
3470 void set_rnn_weights_projection_qparams(
3471 int mask, const std::vector<float> &scales) {
3472 error::wrap_c_api(
3473 dnnl_primitive_attr_set_rnn_weights_projection_qparams(
3474 get(), (int)scales.size(), mask, scales.data()),
3475 "could not set primitive RNN weights projection quantization "
3476 "parameters attributes");
3477 }
3478
3479 /// Returns the quantization scaling factors for RNN projection weights
3480 /// tensors.
3481 ///
3482 /// @note
3483 /// The dimension order is always native and does not depend on the
3484 /// actual layout used. For example, five-dimensional weights always
3485 /// have (l, d, i, g, o) logical dimension ordering.
3486 ///
3487 /// @param mask Scaling factors correspondence mask that defines the
3488 /// correspondence between the output tensor dimensions and the @p
3489 /// scales vector. The set i-th bit indicates that a dedicated scaling
3490 /// factor should be used each index along that dimension. Set the
3491 /// mask to 0 to use a common scaling factor for the whole output
3492 /// tensor.
3493 /// @param scales Constant vector of output scaling factors. The following
3494 /// equality must hold:
3495 /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
3496 /// Violations can only be detected when the attributes are used to
3497 /// create a primitive descriptor.
3498 void get_rnn_weights_projection_qparams(
3499 int &mask, std::vector<float> &scales) {
3500 dnnl_dim_t count;
3501 int c_mask;
3502 const float *c_scales;
3503 error::wrap_c_api(
3504 dnnl_primitive_attr_get_rnn_weights_projection_qparams(
3505 get(), &count, &c_mask, &c_scales),
3506 "could not get primitive RNN weights projection quantization "
3507 "parameters attributes");
3508 scales.resize(count);
3509
3510 mask = c_mask;
3511 for (dnnl_dim_t c = 0; c < count; c++)
3512 scales[c] = c_scales[c];
3513 }
3514};
3515
3516/// @} dnnl_api_attributes
3517
3518/// @addtogroup dnnl_api_primitives_common
3519/// @{
3520
3521/// Base class for all primitive descriptors.
3522struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
3523 using handle<dnnl_primitive_desc_t>::handle;
3524
3525 /// Default constructor. Produces an empty object.
3526 primitive_desc_base() = default;
3527
3528 /// Returns the engine of the primitive descriptor.
3529 /// @returns The engine of the primitive descriptor.
3530 engine get_engine() const { return query_engine(query::engine); }
3531
3532 /// Returns implementation name.
3533 /// @returns The implementation name.
3534 const char *impl_info_str() const {
3535 const char *res;
3536 error::wrap_c_api(dnnl_primitive_desc_query(
3537 get(), dnnl_query_impl_info_str, 0, &res),
3538 "could not retrieve implementation info string from a "
3539 "primitive descriptor");
3540 return res;
3541 }
3542
3543 /// Returns a memory::dim value (same as int64_t).
3544 /// @param what The value to query.
3545 /// @returns The result of the query.
3546 memory::dim query_s64(query what) const {
3547 memory::dim res;
3548 dnnl_status_t status = dnnl_primitive_desc_query(
3549 get(), dnnl::convert_to_c(what), 0, &res);
3550 return status == dnnl_success ? res : 0;
3551 }
3552
3553 /// Returns strides.
3554 /// @returns Strides.
3555 /// @returns An empty #dnnl::memory::dims if the primitive does not have
3556 /// a strides parameter.
3557 memory::dims get_strides() const { return query_dims(query::strides); }
3558
3559 /// Returns dilations.
3560 /// @returns Dilations.
3561 /// @returns An empty #dnnl::memory::dims if the primitive does not have
3562 /// a dilations parameter.
3563 memory::dims get_dilations() const { return query_dims(query::dilations); }
3564
3565 /// Returns a left padding.
3566 /// @returns A left padding.
3567 /// @returns An empty #dnnl::memory::dims if the primitive does not have
3568 /// a left padding parameter.
3569 memory::dims get_padding_l() const { return query_dims(query::padding_l); }
3570
3571 /// Returns a right padding.
3572 /// @returns A right padding.
3573 /// @returns An empty #dnnl::memory::dims if the primitive does not have
3574 /// a right padding parameter.
3575 memory::dims get_padding_r() const { return query_dims(query::padding_r); }
3576
3577 /// Returns an epsilon.
3578 /// @returns An epsilon.
3579 /// @returns Zero if the primitive does not have an epsilon parameter.
3580 float get_epsilon() const { return query_f32(query::epsilon_f32); }
3581
3582 /// Returns flags.
3583 /// @tparam T Flags enumeration type.
3584 /// @returns Flags.
3585 /// @returns Zero if the primitive does not have a flags parameter.
3586 template <typename T = unsigned>
3587 T get_flags() const {
3588 unsigned res;
3589 dnnl_status_t status
3590 = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res);
3591 return static_cast<T>(status == dnnl_success ? res : 0x0U);
3592 }
3593
3594 /// Returns an algorithm kind.
3595 /// @returns An algorithm kind.
3596 /// @returns #dnnl::algorithm::undef if the primitive does not have an
3597 /// algorithm parameter.
3598 dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); }
3599
3600 /// Returns an alpha.
3601 /// @returns An alpha.
3602 /// @returns Zero if the primitive does not have an alpha parameter.
3603 float get_alpha() const { return query_f32(query::alpha_f32); }
3604
3605 /// Returns a beta.
3606 /// @returns A beta.
3607 /// @returns Zero if the primitive does not have a beta parameter.
3608 float get_beta() const { return query_f32(query::beta_f32); }
3609
3610 /// Returns an axis.
3611 /// @returns An axis.
3612 /// @returns A negative number if the primitive does not have an axis
3613 /// parameter.
3614 int get_axis() const {
3615 int res;
3616 dnnl_status_t status = dnnl_primitive_desc_query(
3617 get(), dnnl_query_axis_s32, 0, &res);
3618 return status == dnnl_success ? res : -1;
3619 }
3620
3621 /// Returns an LRN local size parameter.
3622 /// @returns An LRN local size parameter.
3623 /// @returns Zero if the primitive does not have an LRN local size
3624 /// parameter.
3625 memory::dim get_local_size() const {
3626 return query_s64(query::local_size_s64);
3627 }
3628
3629 /// Returns an LRN K parameter.
3630 /// @returns An LRN K parameter.
3631 /// @returns Zero if the primitive does not have an LRN K parameter.
3632 float get_k() const { return query_f32(query::k_f32); }
3633
3634 /// Returns a reduction P parameter.
3635 /// @returns A reduction P parameter.
3636 /// @returns Zero if the primitive does not have a reduction P parameter.
3637 float get_p() const { return query_f32(query::p_f32); }
3638
/// Returns the resampling factors parameter.
3640 /// @returns A vector of factors.
3641 /// @returns An empty vector if the primitive does not have a resampling
3642 /// factors parameter.
3643 std::vector<float> get_factors() const {
3644 float *factors;
3645 dnnl_status_t status = dnnl_primitive_desc_query(
3646 get(), dnnl_query_factors, 0, &factors);
3647
3648 const bool is_backward = get_prop_kind() != prop_kind::forward_training
3649 && get_prop_kind() != prop_kind::forward_inference;
3650 const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
3651 is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
3652
3653 int ndims;
3654 error::wrap_c_api(
3655 dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
3656 "could not query ndims from a memory descriptor");
3657
3658 return status == dnnl_success
3659 ? std::vector<float>(factors, factors + (ndims - 2))
3660 : std::vector<float> {};
3661 }
3662
3663 /// Returns an RNN cell kind parameter.
3664 /// @returns An RNN cell kind parameter.
3665 /// @returns #dnnl::algorithm::undef if the primitive does not have an
3666 /// RNN cell kind parameter.
3667 dnnl::algorithm get_cell_kind() const {
3668 return query_alg(query::cell_kind);
3669 }
3670
3671 /// Returns an RNN direction parameter.
3672 /// @returns An RNN direction parameter.
3673 /// @returns #dnnl::rnn_direction::undef if the primitive does not have
3674 /// an RNN direction parameter.
3675 dnnl::rnn_direction get_direction() const {
3676 dnnl_rnn_direction_t direction;
3677 dnnl_status_t status = dnnl_primitive_desc_query(
3678 get(), dnnl_query_direction, 0, &direction);
3679 return status == dnnl_success
3680 ? static_cast<dnnl::rnn_direction>(direction)
3681 : dnnl::rnn_direction::undef;
3682 }
3683
3684 /// Returns an RNN activation kind parameter.
3685 /// @returns An RNN activation kind parameter.
3686 /// @returns #dnnl::algorithm::undef if the primitive does not have an
3687 /// RNN activation kind parameter.
3688 dnnl::algorithm get_activation_kind() const {
3689 return query_alg(query::activation_kind);
3690 }
3691
3692 /// Returns a pooling kernel parameter.
3693 /// @returns A pooling kernel parameter.
3694 /// @returns An empty #dnnl::memory::dims if the primitive does not have
3695 /// a pooling kernel parameter.
3696 memory::dims get_kernel() const { return query_dims(query::kernel); }
3697
3698 /// Returns a shuffle group size parameter.
3699 /// @returns A shuffle group size parameter.
3700 /// @returns Zero if the primitive does not have a shuffle group size
3701 /// parameter.
3702 memory::dim get_group_size() const {
3703 return query_s64(query::group_size_s64);
3704 }
3705
3706 /// Returns a propagation kind.
3707 /// @returns A propagation kind.
3708 /// @returns #dnnl::prop_kind::undef if the primitive does not have
3709 /// a propagation parameter.
3710 dnnl::prop_kind get_prop_kind() const {
3711 dnnl_prop_kind_t prop_kind;
3712 dnnl_status_t status = dnnl_primitive_desc_query(
3713 get(), dnnl_query_prop_kind, 0, &prop_kind);
3714 return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind)
3715 : dnnl::prop_kind::undef;
3716 }
3717
3718 /// Returns a memory descriptor.
3719 ///
3720 /// @note
3721 /// There are also convenience methods
3722 /// #dnnl::primitive_desc_base::src_desc(),
3723 /// #dnnl::primitive_desc_base::dst_desc(), and others.
3724 ///
3725 /// @param what The kind of parameter to query; can be
3726 /// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
3727 /// @param idx Index of the parameter. For example, convolution bias can
3728 /// be queried with what = #dnnl::query::weights_md and idx = 1.
3729 /// @returns The requested memory descriptor.
3730 /// @returns A zero memory descriptor if the primitive does not have a
3731 /// parameter of the specified kind or index.
3732 memory::desc query_md(query what, int idx = 0) const {
3733 std::vector<query> valid_q {query::src_md, query::diff_src_md,
3734 query::weights_md, query::diff_weights_md, query::dst_md,
3735 query::diff_dst_md, query::workspace_md, query::scratchpad_md,
3736 query::exec_arg_md};
3737 if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
3738 [=](query q) { return what == q; }))
3739 DNNL_THROW_ERROR(dnnl_invalid_arguments,
3740 "memory descriptor query is invalid");
3741
3742 const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
3743 get(), dnnl::convert_to_c(what), idx);
3744 if (!cdesc) return memory::desc();
3745
3746 dnnl_memory_desc_t cloned_md = nullptr;
3747 error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
3748 "could not clone a memory descriptor");
3749
3750 return memory::desc(cloned_md);
3751 }
3752
3753 /// Returns a source memory descriptor.
3754 /// @param idx Source index.
3755 /// @returns Source memory descriptor.
3756 /// @returns A zero memory descriptor if the primitive does not have a
3757 /// source parameter with index @p idx.
3758 memory::desc src_desc(int idx) const {
3759 return query_md(query::src_md, idx);
3760 }
3761
3762 /// Returns a destination memory descriptor.
3763 /// @param idx Destination index.
3764 /// @returns Destination memory descriptor.
3765 /// @returns A zero memory descriptor if the primitive does not have a
3766 /// destination parameter with index @p idx.
3767 memory::desc dst_desc(int idx) const {
3768 return query_md(query::dst_md, idx);
3769 }
3770
3771 /// Returns a weights memory descriptor.
3772 /// @param idx Weights index.
3773 /// @returns Weights memory descriptor.
3774 /// @returns A zero memory descriptor if the primitive does not have a
3775 /// weights parameter with index @p idx.
3776 memory::desc weights_desc(int idx) const {
3777 return query_md(query::weights_md, idx);
3778 }
3779
3780 /// Returns a diff source memory descriptor.
3781 /// @param idx Diff source index.
3782 /// @returns Diff source memory descriptor.
3783 /// @returns A zero memory descriptor if the primitive does not have a
3784 /// diff source parameter with index @p idx.
3785 memory::desc diff_src_desc(int idx) const {
3786 return query_md(query::diff_src_md, idx);
3787 }
3788
3789 /// Returns a diff destination memory descriptor.
3790 /// @param idx Diff destination index.
3791 /// @returns Diff destination memory descriptor.
3792 /// @returns A zero memory descriptor if the primitive does not have a
3793 /// diff destination parameter with index @p idx.
3794 memory::desc diff_dst_desc(int idx) const {
3795 return query_md(query::diff_dst_md, idx);
3796 }
3797
3798 /// Returns a diff weights memory descriptor.
3799 /// @param idx Diff weights index.
3800 /// @returns Diff weights memory descriptor.
3801 /// @returns A zero memory descriptor if the primitive does not have a
3802 /// diff weights parameter with index @p idx.
3803 memory::desc diff_weights_desc(int idx) const {
3804 return query_md(query::diff_weights_md, idx);
3805 }
3806
3807 // Separate versions without the index argument for documentation
3808 // purposes.
3809
3810 /// Returns a source memory descriptor.
3811 /// @returns Source memory descriptor.
3812 /// @returns A zero memory descriptor if the primitive does not have a
3813 /// source parameter.
3814 memory::desc src_desc() const { return src_desc(0); }
3815
3816 /// Returns a destination memory descriptor.
3817 /// @returns Destination memory descriptor.
3818 /// @returns A zero memory descriptor if the primitive does not have a
3819 /// destination parameter.
3820 memory::desc dst_desc() const { return dst_desc(0); }
3821
3822 /// Returns a weights memory descriptor.
3823 /// @returns Weights memory descriptor.
3824 /// @returns A zero memory descriptor if the primitive does not have a
3825 /// weights parameter.
3826 memory::desc weights_desc() const { return weights_desc(0); }
3827
3828 /// Returns a diff source memory descriptor.
3829 /// @returns Diff source memory descriptor.
3830 /// @returns A zero memory descriptor if the primitive does not have a
/// diff source parameter.
3832 memory::desc diff_src_desc() const { return diff_src_desc(0); }
3833
3834 /// Returns a diff destination memory descriptor.
3835 /// @returns Diff destination memory descriptor.
3836 /// @returns A zero memory descriptor if the primitive does not have a
3837 /// diff destination parameter.
3838 memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
3839
3840 /// Returns a diff weights memory descriptor.
3841 /// @returns Diff weights memory descriptor.
3842 /// @returns A zero memory descriptor if the primitive does not have a
3843 /// diff weights parameter.
3844 memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
3845
3846 /// Returns the workspace memory descriptor.
3847 /// @returns Workspace memory descriptor.
3848 /// @returns A zero memory descriptor if the primitive does not require
3849 /// workspace parameter.
3850 memory::desc workspace_desc() const {
3851 return query_md(query::workspace_md, 0);
3852 }
3853
3854 /// Returns the scratchpad memory descriptor.
3855 /// @returns scratchpad memory descriptor.
3856 /// @returns A zero memory descriptor if the primitive does not require
3857 /// scratchpad parameter.
3858 /// @sa @ref dev_guide_attributes_scratchpad
3859 memory::desc scratchpad_desc() const {
3860 return query_md(query::scratchpad_md, 0);
3861 }
3862
3863 /// Returns the engine on which the scratchpad memory is located.
3864 /// @returns The engine on which the scratchpad memory is located.
3865 engine scratchpad_engine() const {
3866 dnnl_engine_t c_engine;
3867 error::wrap_c_api(dnnl_primitive_desc_query(get(),
3868 dnnl::convert_to_c(query::scratchpad_engine),
3869 0, &c_engine),
3870 "could not retrieve scratchpad engine from a primitive "
3871 "descriptor");
3872 return engine(c_engine, true);
3873 }
3874
3875 /// Returns the primitive attributes.
3876 /// @returns The primitive attributes.
3877 primitive_attr get_primitive_attr() const {
3878 const_dnnl_primitive_attr_t const_c_attr;
3879 error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
3880 "could not get attributes from a primitive descriptor");
3881 dnnl_primitive_attr_t c_attr;
3882 error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
3883 "could not clone primitive attributes");
3884 return primitive_attr(c_attr);
3885 }
3886
3887 /// Returns the kind of the primitive descriptor.
3888 /// @returns The kind of the primitive descriptor.
3889 dnnl::primitive::kind get_kind() const {
3890 dnnl_primitive_kind_t kind;
3891 error::wrap_c_api(dnnl_primitive_desc_query(get(),
3892 dnnl_query_primitive_kind, 0, (void *)&kind),
3893 "could not get primitive kind from a primitive descriptor");
3894 return static_cast<dnnl::primitive::kind>(kind);
3895 }
3896
3897 /// Returns the cache blob ID of the primitive descriptor.
3898 /// @returns The cache blob ID of the primitive descriptor.
3899 std::vector<uint8_t> get_cache_blob_id() const {
3900 dnnl_dim_t count;
3901 const uint8_t *c_id;
3902 error::wrap_c_api(
3903 dnnl_primitive_desc_query(get(),
3904 dnnl::convert_to_c(query::cache_blob_id_size_s64), 0,
3905 (void *)&count),
3906 "could not get size of cache blob ID from a primitive "
3907 "descriptor");
3908 error::wrap_c_api(dnnl_primitive_desc_query(get(),
3909 dnnl::convert_to_c(query::cache_blob_id), 0,
3910 (void **)&c_id),
3911 "could not get cache blob ID from a primitive descriptor");
3912 std::vector<uint8_t> id(c_id, c_id + count);
3913 return id;
3914 }
3915
3916protected:
3917 /// Returns a float value.
3918 /// @param what The value to query.
3919 /// @returns The result of the query.
3920 /// @returns Zero if the primitive doesn't support the query.
3921 float query_f32(query what) const {
3922 float res;
3923 dnnl_status_t status = dnnl_primitive_desc_query(
3924 get(), dnnl::convert_to_c(what), 0, &res);
3925 return status == dnnl_success ? res : 0.0f;
3926 }
3927
3928 /// Returns an #dnnl::algorithm value.
3929 /// @param what The value to query.
3930 /// @returns The result of the query.
3931 /// @returns #dnnl::algorithm::undef if the primitive doesn't support
3932 /// the query.
3933 algorithm query_alg(query what) const {
3934 dnnl_alg_kind_t res;
3935 dnnl_status_t status = dnnl_primitive_desc_query(
3936 get(), dnnl::convert_to_c(what), 0, &res);
3937 return status == dnnl_success ? static_cast<dnnl::algorithm>(res)
3938 : algorithm::undef;
3939 }
3940
3941 /// Returns a memory::dims value.
3942 /// @param what The value to query.
3943 /// @returns The result of the query.
3944 /// @returns An empty #dnnl::memory::dims if the primitive doesn't support
3945 /// the query.
3946 memory::dims query_dims(query what) const {
3947 const bool is_backward = get_prop_kind() != prop_kind::forward_training
3948 && get_prop_kind() != prop_kind::forward_inference;
3949 const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
3950 is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
3951
3952 int nspatial_dims = 0;
3953 if (md) {
3954 int ndims;
3955 error::wrap_c_api(
3956 dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
3957 "could not query ndims from a memory descriptor");
3958 nspatial_dims = ndims - 2;
3959 }
3960
3961 dnnl_dims_t *c_dims;
3962 dnnl_status_t status = dnnl_primitive_desc_query(
3963 get(), dnnl::convert_to_c(what), 0, &c_dims);
3964 return status == dnnl_success
3965 ? memory::dims(*c_dims, *c_dims + nspatial_dims)
3966 : memory::dims {};
3967 }
3968
3969 /// Returns an #dnnl::engine value.
3970 /// @param what The value to query.
3971 /// @returns The result of the query.
3972 /// @returns A weak handle to the engine that the primitive descriptor was
3973 /// created with.
3974 engine query_engine(query what) const {
3975 dnnl_engine_t c_engine;
3976 error::wrap_c_api(dnnl_primitive_desc_query(get(),
3977 dnnl::convert_to_c(what), 0, &c_engine),
3978 "could not get an engine from a primitive_desc");
3979 return engine(c_engine, true);
3980 }
3981
3982 /// Resets the value of the handle to a clone of a C API primitive
3983 /// descriptor.
3984 /// @param pd A C API primitive descriptor to clone.
3985 void reset_with_clone(const_dnnl_primitive_desc_t pd) {
3986 dnnl_primitive_desc_t new_pd;
3987 error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
3988 "could not clone a primitive descriptor");
3989 reset(new_pd);
3990 }
3991
3992 /// Constructs a primitive descriptor base object from a clone of a C API
3993 /// primitive descriptor after verifying that it is what the caller
3994 /// expects.
3995 ///
3996 /// @note
3997 /// The @p prim_kind should map to a primitive that does not have
3998 /// different values of propagation kind (e.g. #dnnl::binary).
3999 /// @note
4000 /// Primitive descriptor base constructed this way does not support
4001 /// next_impl() (will throw).
4002 ///
4003 /// @param pd C API primitive descriptor to clone.
4004 /// @param prim_kind Expected primitive kind.
4005 primitive_desc_base(
4006 dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
4007 : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
4008
4009 /// Constructs a primitive descriptor base object from a clone of a C API
4010 /// primitive descriptor after verifying that it is what the caller
4011 /// expects.
4012 ///
4013 /// @note
4014 /// Primitive descriptor base constructed this way does not support
4015 /// next_impl() (will throw).
4016 ///
4017 /// @param pd C API primitive descriptor to clone.
4018 /// @param prim_kind Expected primitive kind.
4019 /// @param aprop_kind Expected propagation kind.
4020 primitive_desc_base(dnnl_primitive_desc_t pd,
4021 dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
4022 : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
4023
4024 /// Constructs a primitive descriptor base object from a clone of a C API
4025 /// primitive descriptor after verifying that it is what the caller
4026 /// expects.
4027 ///
4028 /// @note
4029 /// Primitive descriptor base constructed this way does not support
4030 /// next_impl() (will throw).
4031 ///
4032 /// @param pd C API primitive descriptor to clone.
4033 /// @param prim_kind Expected primitive kind.
4034 /// @param prop_kind1 Expected propagation kind (option 1).
4035 /// @param prop_kind2 Expected propagation kind (option 2). This value is
4036 /// checked if the check with @p prop_kind1 fails.
4037 primitive_desc_base(dnnl_primitive_desc_t pd,
4038 dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
4039 dnnl::prop_kind prop_kind2) {
4040 // It is OK to pass an empty primitive descriptor
4041 if (pd == nullptr) return;
4042
4043 dnnl_status_t rc;
4044
4045 dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
4046 dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
4047 dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
4048
4049 // Check that primitive kind matches
4050 dnnl_primitive_kind_t pd_kind;
4051 rc = dnnl_primitive_desc_query(
4052 pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
4053 error::wrap_c_api(
4054 rc, "could not get primitive kind from a primitive descriptor");
4055 if (pd_kind != c_prim_kind)
4056 DNNL_THROW_ERROR(dnnl_invalid_arguments,
4057 "primitive descriptor operation kind mismatch");
4058
4059 // Check that propagation kind matches
4060 dnnl_prop_kind_t pd_prop_kind;
4061 rc = dnnl_primitive_desc_query(
4062 pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
4063
4064 // Something went wrong
4065 if (rc != dnnl_success && rc != dnnl_unimplemented)
4066 DNNL_THROW_ERROR(dnnl_invalid_arguments,
4067 "could not get propagation kind from the primitive "
4068 "descriptor");
4069
4070 // Everything is fine
4071 if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
4072 || (rc == dnnl_success
4073 && (pd_prop_kind == c_prop_kind1
4074 || pd_prop_kind == c_prop_kind2))) {
4075 reset_with_clone(pd);
4076 return;
4077 }
4078
4079 // We could get the propagation kind but there is a mismatch
4080 DNNL_THROW_ERROR(dnnl_invalid_arguments,
4081 "primitive descriptor propagation kind mismatch");
4082 }
4083
4084 /// Returns a constant reference to a static instance of default constructed
4085 /// primitive attributes
4086 static const primitive_attr &default_attr() {
4087 static const primitive_attr attr;
4088 return attr;
4089 }
4090
4091 const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
4092 return md ? md->get() : nullptr;
4093 }
4094
4095 const dnnl_dim_t *optional_arg(const memory::dims *dims) {
4096 return dims ? dims->data() : nullptr;
4097 }
4098
const float *optional_arg(const std::vector<float> *arg) {
    // A null pointer means "argument not provided" and maps to a null
    // data pointer.
    if (arg == nullptr) return nullptr;
    return arg->data();
}
4102
4103 using base = primitive_desc_base;
4104};
4105
4106/// @} dnnl_api_primitives_common
4107
4108/// @addtogroup dnnl_api_reorder Reorder
4109///
4110/// A primitive to copy data between two memory objects. This primitive is
4111/// typically used to change the way the data is laid out in memory.
4112///
4113/// @sa @ref dev_guide_reorder in developer guide
4114///
4115/// @{
4116
4117/// Reorder primitive.
4118struct reorder : public primitive {
4119 /// Primitive descriptor for a reorder primitive.
4120 struct primitive_desc : public primitive_desc_base {
4121 using primitive_desc_base::primitive_desc_base;
4122
4123 /// Default constructor. Produces an empty object.
4124 primitive_desc() = default;
4125
4126 /// Constructs a primitive descriptor for reorder primitive.
4127 ///
4128 /// @note
4129 /// If @p allow_empty is true, the constructor does not throw if a
4130 /// primitive descriptor cannot be created.
4131 ///
4132 /// @param src_engine Engine on which the source memory object will be
4133 /// located.
4134 /// @param src_md Source memory descriptor.
4135 /// @param dst_engine Engine on which the destination memory object
4136 /// will be located.
4137 /// @param dst_md Destination memory descriptor.
4138 /// @param attr Primitive attributes to use. Attributes are optional
4139 /// and default to empty attributes.
4140 /// @param allow_empty A flag signifying whether construction is allowed
4141 /// to fail without throwing an exception. In this case an empty
4142 /// object will be produced. This flag is optional and defaults to
4143 /// false.
4144 primitive_desc(const engine &src_engine, const memory::desc &src_md,
4145 const engine &dst_engine, const memory::desc &dst_md,
4146 const primitive_attr &attr = default_attr(),
4147 bool allow_empty = false) {
4148 dnnl_primitive_desc_t result;
4149 dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
4150 src_md.get(), src_engine.get(), dst_md.get(),
4151 dst_engine.get(), attr.get());
4152 if (!allow_empty)
4153 error::wrap_c_api(status,
4154 "could not create a primitive descriptor for a reorder "
4155 "primitive");
4156 reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
4157 }
4158
4159 /// Constructs a primitive descriptor for reorder primitive.
4160 ///
4161 /// @param src Source memory object. It is used to obtain the source
4162 /// memory descriptor and engine.
4163 /// @param dst Destination memory object. It is used to obtain the
4164 /// destination memory descriptor and engine.
4165 /// @param attr Primitive attributes to use. Attributes are optional
4166 /// and default to empty attributes.
4167 /// @param allow_empty A flag signifying whether construction is allowed
4168 /// to fail without throwing an exception. In this case an empty
4169 /// object will be produced. This flag is optional and defaults to
4170 /// false.
4171 primitive_desc(const memory &src, const memory &dst,
4172 const primitive_attr &attr = default_attr(),
4173 bool allow_empty = false) {
4174 dnnl_primitive_desc_t result;
4175 auto src_md = src.get_desc();
4176 auto dst_md = dst.get_desc();
4177 dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
4178 src_md.get(), src.get_engine().get(), dst_md.get(),
4179 dst.get_engine().get(), attr.get());
4180 if (!allow_empty)
4181 error::wrap_c_api(status,
4182 "could not create a primitive descriptor for a reorder "
4183 "primitive");
4184 reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
4185 }
4186
4187 /// Constructs a primitive descriptor for reorder primitive from a C
4188 /// API primitive descriptor which must have a matching kind.
4189 ///
4190 /// @param pd C API primitive descriptor for reorder primitive.
4191 primitive_desc(dnnl_primitive_desc_t pd)
4192 : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
4193
4194 /// Returns the engine on which the source memory is allocated.
4195 /// @returns The engine on which the source memory is allocated.
4196 engine get_src_engine() const {
4197 return query_engine(dnnl::query::reorder_src_engine);
4198 }
4199
4200 /// Returns the engine on which the destination memory is allocated.
4201 /// @returns The engine on which the destination memory is allocated.
4202 engine get_dst_engine() const {
4203 return query_engine(dnnl::query::reorder_dst_engine);
4204 }
4205
4206 /// @copydoc dnnl::primitive_desc_base::src_desc()const
4207 memory::desc src_desc() const { return base::src_desc(0); }
4208
4209 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
4210 memory::desc dst_desc() const { return base::dst_desc(0); }
4211 };
4212
4213 /// Default constructor. Produces an empty object.
4214 reorder() = default;
4215
4216 /// Constructs a reorder primitive.
4217 /// @param pd Primitive descriptor for reorder primitive.
4218 reorder(const primitive_desc &pd) : primitive(pd.get()) {}
4219
4220 /// Constructs a reorder primitive from a cache blob.
4221 /// @param pd Primitive descriptor for reorder primitive.
4222 /// @param cache_blob Cache blob.
4223 reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
4224 : primitive(pd.get(), cache_blob) {}
4225
4226 /// Constructs a reorder primitive that would reorder data between memory
4227 /// objects having the same memory descriptors as memory objects @p src and
4228 /// @p dst.
4229 ///
4230 /// @param src Source memory object.
4231 /// @param dst Destination memory object.
4232 /// @param attr Primitive attributes to use (optional).
4233 reorder(const memory &src, const memory &dst,
4234 const primitive_attr &attr = primitive_attr())
4235 : primitive(primitive_desc(src, dst, attr).get()) {}
4236
4237 using primitive::execute;
4238
4239 /// Executes the reorder primitive.
4240 ///
4241 /// @param astream Stream object. The stream must belong to the same engine
4242 /// as the primitive.
4243 /// @param src Source memory object.
4244 /// @param dst Destination memory object.
4245 void execute(const stream &astream, memory &src, memory &dst) const {
4246 primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
4247 }
4248};
4249
4250/// @} dnnl_api_reorder
4251
4252/// @addtogroup dnnl_api_concat Concat
4253///
4254/// A primitive to concatenate data by arbitrary dimension.
4255///
4256/// @sa @ref dev_guide_concat in developer guide
4257///
4258/// @{
4259
4260/// @cond DO_NOT_DOCUMENT_THIS
4261inline std::vector<const_dnnl_memory_desc_t> convert_to_c(
4262 const std::vector<memory::desc> &mds) {
4263 std::vector<const_dnnl_memory_desc_t> c_mds;
4264 c_mds.reserve(mds.size());
4265 for (const auto &md : mds)
4266 c_mds.push_back(md.get());
4267 return c_mds;
4268}
4269/// @endcond
4270
4271/// Tensor concatenation (concat) primitive.
4272struct concat : public primitive {
4273 /// Primitive descriptor for a concat primitive.
4274 struct primitive_desc : public primitive_desc_base {
4275 using primitive_desc_base::primitive_desc_base;
4276
4277 /// Default constructor. Produces an empty object.
4278 primitive_desc() = default;
4279
4280 /// Constructs a primitive descriptor for an out-of-place concatenation
4281 /// primitive.
4282 ///
4283 /// @param aengine Engine to perform the operation on.
4284 /// @param dst Destination memory descriptor.
4285 /// @param concat_dimension Source tensors will be concatenated over
4286 /// dimension with this index. Note that order of dimensions does
4287 /// not depend on memory format.
4288 /// @param srcs Vector of source memory descriptors.
4289 /// @param attr Primitive attributes to use. Attributes are optional
4290 /// and default to empty attributes.
4291 primitive_desc(const engine &aengine, const memory::desc &dst,
4292 int concat_dimension, const std::vector<memory::desc> &srcs,
4293 const primitive_attr &attr = default_attr()) {
4294 auto c_srcs = convert_to_c(srcs);
4295
4296 dnnl_primitive_desc_t result;
4297 error::wrap_c_api(
4298 dnnl_concat_primitive_desc_create(&result, aengine.get(),
4299 dst.get(), (int)c_srcs.size(), concat_dimension,
4300 c_srcs.data(), attr.get()),
4301 "could not create a primitive descriptor for a concat "
4302 "primitive");
4303 reset(result);
4304 }
4305
4306 /// Constructs a primitive descriptor for an out-of-place concatenation
4307 /// primitive.
4308 ///
4309 /// This version derives the destination memory descriptor
4310 /// automatically.
4311 ///
4312 /// @param aengine Engine to perform the operation on.
4313 /// @param concat_dimension Source tensors will be concatenated over
4314 /// dimension with this index. Note that order of dimensions does
4315 /// not depend on memory format.
4316 /// @param srcs Vector of source memory descriptors.
4317 /// @param attr Primitive attributes to use. Attributes are optional
4318 /// and default to empty attributes.
4319 primitive_desc(const engine &aengine, int concat_dimension,
4320 const std::vector<memory::desc> &srcs,
4321 const primitive_attr &attr = default_attr()) {
4322 auto c_api_srcs = convert_to_c(srcs);
4323
4324 dnnl_primitive_desc_t result;
4325 error::wrap_c_api(
4326 dnnl_concat_primitive_desc_create(&result, aengine.get(),
4327 nullptr, (int)c_api_srcs.size(), concat_dimension,
4328 c_api_srcs.data(), attr.get()),
4329 "could not create a primitive descriptor for a concat "
4330 "primitive");
4331 reset(result);
4332 }
4333
4334 /// Constructs a primitive descriptor for concat primitive from a C
4335 /// API primitive descriptor which must have a matching kind.
4336 ///
4337 /// @param pd C API primitive descriptor for concat primitive.
4338 primitive_desc(dnnl_primitive_desc_t pd)
4339 : primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
4340
4341 /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
4342 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
4343
4344 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
4345 memory::desc dst_desc() const { return base::dst_desc(0); }
4346 };
4347
4348 /// Default constructor. Produces an empty object.
4349 concat() = default;
4350
4351 /// Constructs a concatenation primitive.
4352 /// @param pd Primitive descriptor for concatenation primitive.
4353 concat(const primitive_desc &pd) : primitive(pd.get()) {}
4354
4355 /// Constructs a concatenation primitive from a cache blob.
4356 /// @param pd Primitive descriptor for concatenation primitive.
4357 /// @param cache_blob Cache blob.
4358 concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
4359 : primitive(pd.get(), cache_blob) {}
4360};
4361
4362/// @} dnnl_api_concat
4363
4364/// @addtogroup dnnl_api_sum Sum
4365///
4366/// A primitive to sum multiple tensors.
4367///
4368/// @sa @ref dev_guide_sum in developer guide
4369///
4370/// @{
4371
4372/// Out-of-place summation (sum) primitive.
4373struct sum : public primitive {
4374 /// Primitive descriptor for a sum primitive.
4375 struct primitive_desc : public primitive_desc_base {
4376 using primitive_desc_base::primitive_desc_base;
4377
4378 /// Default constructor. Produces an empty object.
4379 primitive_desc() = default;
4380
4381 /// Constructs a primitive descriptor for a sum primitive.
4382 ///
4383 /// @param aengine Engine to perform the operation on.
4384 /// @param dst Destination memory descriptor.
4385 /// @param scales Vector of scales to multiply data in each source
4386 /// memory by.
4387 /// @param srcs Vector of source memory descriptors.
4388 /// @param attr Primitive attributes to use. Attributes are optional
4389 /// and default to empty attributes.
4390 primitive_desc(const engine &aengine, const memory::desc &dst,
4391 const std::vector<float> &scales,
4392 const std::vector<memory::desc> &srcs,
4393 const primitive_attr &attr = default_attr()) {
4394 validate_container_size(scales,
4395 "counts of scales and sources are not equal",
4396 (int)srcs.size(), (int)srcs.size());
4397
4398 auto c_api_srcs = convert_to_c(srcs);
4399
4400 dnnl_primitive_desc_t result;
4401 error::wrap_c_api(
4402 dnnl_sum_primitive_desc_create(&result, aengine.get(),
4403 dst.get(), (int)c_api_srcs.size(), scales.data(),
4404 c_api_srcs.data(), attr.get()),
4405 "could not create a primitive descriptor for a sum "
4406 "primitive");
4407 reset(result);
4408 }
4409
4410 /// Constructs a primitive descriptor for a sum primitive.
4411 ///
4412 /// This version derives the destination memory descriptor
4413 /// automatically.
4414 ///
4415 /// @param aengine Engine on which to perform the operation.
4416 /// @param scales Vector of scales by which to multiply data in each
4417 /// source memory object.
4418 /// @param srcs Vector of source memory descriptors.
4419 /// @param attr Primitive attributes to use. Attributes are optional
4420 /// and default to empty attributes.
4421 primitive_desc(const engine &aengine, const std::vector<float> &scales,
4422 const std::vector<memory::desc> &srcs,
4423 const primitive_attr &attr = default_attr()) {
4424 validate_container_size(scales,
4425 "counts of scales and sources are not equal",
4426 (int)srcs.size(), (int)srcs.size());
4427
4428 auto c_api_srcs = convert_to_c(srcs);
4429 dnnl_primitive_desc_t result;
4430 error::wrap_c_api(
4431 dnnl_sum_primitive_desc_create(&result, aengine.get(),
4432 nullptr, (int)c_api_srcs.size(), scales.data(),
4433 c_api_srcs.data(), attr.get()),
4434 "could not create a primitive descriptor for a sum "
4435 "primitive");
4436 reset(result);
4437 }
4438
4439 /// Constructs a primitive descriptor for sum primitive from a C API
4440 /// primitive descriptor which must have a matching kind.
4441 ///
4442 /// @param pd C API primitive descriptor for reorder primitive.
4443 primitive_desc(dnnl_primitive_desc_t pd)
4444 : primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
4445
4446 /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
4447 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
4448
4449 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
4450 memory::desc dst_desc() const { return base::dst_desc(0); }
4451 };
4452
4453 /// Default constructor. Produces an empty object.
4454 sum() = default;
4455
4456 /// Constructs a sum primitive.
4457 /// @param pd Primitive descriptor for sum primitive.
4458 sum(const primitive_desc &pd) : primitive(pd.get()) {}
4459
4460 /// Constructs a sum primitive from a cache blob.
4461 /// @param pd Primitive descriptor for sum primitive.
4462 /// @param cache_blob Cache blob.
4463 sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
4464 : primitive(pd.get(), cache_blob) {}
4465};
4466
4467/// @} dnnl_api_sum
4468
4469/// @addtogroup dnnl_api_primitives_common
4470/// @{
4471
4472/// A base class for descriptors of all primitives that support iteration
4473/// over multiple implementations.
4474struct primitive_desc : public primitive_desc_base {
4475 using primitive_desc_base::primitive_desc_base;
4476
4477 primitive_desc() = default;
4478
4479 /// Changes the primitive descriptor to point to the next available
4480 /// implementation.
4481 ///
4482 /// @returns @c true on success and @c false if the last available
4483 /// implementation has already been reached. In the latter case, the
4484 /// primitive descriptor itself is kept unchanged.
4485 bool next_impl() {
4486 dnnl_status_t status = dnnl_primitive_desc_next_impl(get());
4487 if (status == dnnl_last_impl_reached) return false;
4488 error::wrap_c_api(status, "last available implementation is reached");
4489 return true;
4490 }
4491};
4492
4493/// @} dnnl_api_primitives_common
4494
4495/// @addtogroup dnnl_api_convolution Convolution
4496///
4497/// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
4498/// forward propagation, backward propagation, and weights gradient with or
4499/// without bias.
4500///
4501/// @sa @ref dev_guide_convolution in developer guide
4502///
4503/// @{
4504
4505/// Convolution forward propagation primitive.
4506struct convolution_forward : public primitive {
4507 /// Primitive descriptor for a convolution forward propagation primitive.
4508 struct primitive_desc : public dnnl::primitive_desc {
4509 /// Default constructor. Produces an empty object.
4510 primitive_desc() = default;
4511
4512 /// Constructs a primitive descriptor for a convolution forward
4513 /// propagation primitive with bias.
4514 ///
4515 /// @note
4516 /// All the memory descriptors may be initialized with the
4517 /// #dnnl::memory::format_tag::any value of @p format_tag.
4518 ///
4519 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
4520 /// for spatial dimensions only and hence must have the same number of
4521 /// elements as there are spatial dimensions. The order of values is
4522 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
4523 /// and 2D tensors), and width.
4524 ///
4525 /// @param aengine Engine to use.
4526 /// @param aprop_kind Propagation kind. Possible values are
4527 /// #dnnl::prop_kind::forward_training, and
4528 /// #dnnl::prop_kind::forward_inference.
4529 /// @param aalgorithm Convolution algorithm. Possible values are
4530 /// #dnnl::algorithm::convolution_direct,
4531 /// #dnnl::algorithm::convolution_winograd, and
4532 /// #dnnl::algorithm::convolution_auto.
4533 /// @param src_desc Source memory descriptor.
4534 /// @param weights_desc Weights memory descriptor.
4535 /// @param bias_desc Bias memory descriptor. Passing zero memory
4536 /// descriptor disables the bias term.
4537 /// @param dst_desc Destination memory descriptor.
4538 /// @param strides Strides for each spatial dimension.
4539 /// @param padding_l Vector of padding values for low indices for each
4540 /// spatial dimension `([[front,] top,] left)`.
4541 /// @param padding_r Vector of padding values for high indices for
4542 /// each spatial dimension `([[back,] bottom,] right)`.
4543 /// @param attr Primitive attributes to use. Attributes are optional
4544 /// and default to empty attributes.
4545 /// @param allow_empty A flag signifying whether construction is
4546 /// allowed to fail without throwing an exception. In this case an
4547 /// empty object will be produced. This flag is optional and
4548 /// defaults to false.
4549 primitive_desc(const engine &aengine, prop_kind aprop_kind,
4550 algorithm aalgorithm, const memory::desc &src_desc,
4551 const memory::desc &weights_desc, const memory::desc &bias_desc,
4552 const memory::desc &dst_desc, const memory::dims &strides,
4553 const memory::dims &padding_l, const memory::dims &padding_r,
4554 const primitive_attr &attr = default_attr(),
4555 bool allow_empty = false)
4556 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
4557 weights_desc, &bias_desc, dst_desc, strides, nullptr,
4558 padding_l, padding_r, attr, allow_empty) {}
4559
4560 /// Constructs a primitive descriptor for a convolution forward
4561 /// propagation primitive without bias.
4562 ///
4563 /// @note
4564 /// All the memory descriptors may be initialized with the
4565 /// #dnnl::memory::format_tag::any value of @p format_tag.
4566 ///
4567 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
4568 /// for spatial dimensions only and hence must have the same number of
4569 /// elements as there are spatial dimensions. The order of values is
4570 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
4571 /// and 2D tensors), and width.
4572 ///
4573 /// @param aengine Engine to use.
4574 /// @param aprop_kind Propagation kind. Possible values are
4575 /// #dnnl::prop_kind::forward_training, and
4576 /// #dnnl::prop_kind::forward_inference.
4577 /// @param aalgorithm Convolution algorithm. Possible values are
4578 /// #dnnl::algorithm::convolution_direct,
4579 /// #dnnl::algorithm::convolution_winograd, and
4580 /// #dnnl::algorithm::convolution_auto.
4581 /// @param src_desc Source memory descriptor.
4582 /// @param weights_desc Weights memory descriptor.
4583 /// @param dst_desc Destination memory descriptor.
4584 /// @param strides Strides for each spatial dimension.
4585 /// @param padding_l Vector of padding values for low indices for each
4586 /// spatial dimension `([[front,] top,] left)`.
4587 /// @param padding_r Vector of padding values for high indices for
4588 /// each spatial dimension `([[back,] bottom,] right)`.
4589 /// @param attr Primitive attributes to use. Attributes are optional
4590 /// and default to empty attributes.
4591 /// @param allow_empty A flag signifying whether construction is
4592 /// allowed to fail without throwing an exception. In this case an
4593 /// empty object will be produced. This flag is optional and
4594 /// defaults to false.
4595 primitive_desc(const engine &aengine, prop_kind aprop_kind,
4596 algorithm aalgorithm, const memory::desc &src_desc,
4597 const memory::desc &weights_desc, const memory::desc &dst_desc,
4598 const memory::dims &strides, const memory::dims &padding_l,
4599 const memory::dims &padding_r,
4600 const primitive_attr &attr = default_attr(),
4601 bool allow_empty = false)
4602 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
4603 weights_desc, nullptr, dst_desc, strides, nullptr,
4604 padding_l, padding_r, attr, allow_empty) {}
4605
4606 /// Constructs a primitive descriptor for a convolution forward
4607 /// propagation primitive with bias.
4608 ///
4609 /// @note
4610 /// All the memory descriptors may be initialized with the
4611 /// #dnnl::memory::format_tag::any value of @p format_tag.
4612 ///
4613 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
4614 /// contain values for spatial dimensions only and hence must have the
4615 /// same number of elements as there are spatial dimensions. The order
4616 /// of values is the same as in the tensor: depth (for 3D tensors),
4617 /// height (for 3D and 2D tensors), and width.
4618 ///
4619 /// @param aengine Engine to use.
4620 /// @param aprop_kind Propagation kind. Possible values are
4621 /// #dnnl::prop_kind::forward_training, and
4622 /// #dnnl::prop_kind::forward_inference.
4623 /// @param aalgorithm Convolution algorithm. Possible values are
4624 /// #dnnl::algorithm::convolution_direct,
4625 /// #dnnl::algorithm::convolution_winograd, and
4626 /// #dnnl::algorithm::convolution_auto.
4627 /// @param src_desc Source memory descriptor.
4628 /// @param weights_desc Weights memory descriptor.
4629 /// @param bias_desc Bias memory descriptor. Passing zero memory
4630 /// descriptor disables the bias term.
4631 /// @param dst_desc Destination memory descriptor.
4632 /// @param strides Strides for each spatial dimension.
4633 /// @param dilates Dilations for each spatial dimension. A zero value
4634 /// means no dilation in the corresponding dimension.
4635 /// @param padding_l Vector of padding values for low indices for each
4636 /// spatial dimension `([[front,] top,] left)`.
4637 /// @param padding_r Vector of padding values for high indices for
4638 /// each spatial dimension `([[back,] bottom,] right)`.
4639 /// @param attr Primitive attributes to use. Attributes are optional
4640 /// and default to empty attributes.
4641 /// @param allow_empty A flag signifying whether construction is
4642 /// allowed to fail without throwing an exception. In this case an
4643 /// empty object will be produced. This flag is optional and
4644 /// defaults to false.
4645 primitive_desc(const engine &aengine, prop_kind aprop_kind,
4646 algorithm aalgorithm, const memory::desc &src_desc,
4647 const memory::desc &weights_desc, const memory::desc &bias_desc,
4648 const memory::desc &dst_desc, const memory::dims &strides,
4649 const memory::dims &dilates, const memory::dims &padding_l,
4650 const memory::dims &padding_r,
4651 const primitive_attr &attr = default_attr(),
4652 bool allow_empty = false)
4653 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
4654 weights_desc, &bias_desc, dst_desc, strides, &dilates,
4655 padding_l, padding_r, attr, allow_empty) {}
4656
4657 /// Constructs a primitive descriptor for a convolution forward
4658 /// propagation primitive without bias.
4659 ///
4660 /// @note
4661 /// All the memory descriptors may be initialized with the
4662 /// #dnnl::memory::format_tag::any value of @p format_tag.
4663 ///
4664 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
4665 /// contain values for spatial dimensions only and hence must have the
4666 /// same number of elements as there are spatial dimensions. The order
4667 /// of values is the same as in the tensor: depth (for 3D tensors),
4668 /// height (for 3D and 2D tensors), and width.
4669 ///
4670 /// @param aengine Engine to use.
4671 /// @param aprop_kind Propagation kind. Possible values are
4672 /// #dnnl::prop_kind::forward_training, and
4673 /// #dnnl::prop_kind::forward_inference.
4674 /// @param aalgorithm Convolution algorithm. Possible values are
4675 /// #dnnl::algorithm::convolution_direct,
4676 /// #dnnl::algorithm::convolution_winograd, and
4677 /// #dnnl::algorithm::convolution_auto.
4678 /// @param src_desc Source memory descriptor.
4679 /// @param weights_desc Weights memory descriptor.
4680 /// @param dst_desc Destination memory descriptor.
4681 /// @param strides Strides for each spatial dimension.
4682 /// @param dilates Dilations for each spatial dimension. A zero value
4683 /// means no dilation in the corresponding dimension.
4684 /// @param padding_l Vector of padding values for low indices for each
4685 /// spatial dimension `([[front,] top,] left)`.
4686 /// @param padding_r Vector of padding values for high indices for
4687 /// each spatial dimension `([[back,] bottom,] right)`.
4688 /// @param attr Primitive attributes to use. Attributes are optional
4689 /// and default to empty attributes.
4690 /// @param allow_empty A flag signifying whether construction is
4691 /// allowed to fail without throwing an exception. In this case an
4692 /// empty object will be produced. This flag is optional and
4693 /// defaults to false.
4694 primitive_desc(const engine &aengine, prop_kind aprop_kind,
4695 algorithm aalgorithm, const memory::desc &src_desc,
4696 const memory::desc &weights_desc, const memory::desc &dst_desc,
4697 const memory::dims &strides, const memory::dims &dilates,
4698 const memory::dims &padding_l, const memory::dims &padding_r,
4699 const primitive_attr &attr = default_attr(),
4700 bool allow_empty = false)
4701 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
4702 weights_desc, nullptr, dst_desc, strides, &dilates,
4703 padding_l, padding_r, attr, allow_empty) {}
4704
4705 /// Constructs a primitive descriptor for a convolution forward
4706 /// propagation primitive from a C API primitive descriptor that must
4707 /// have a matching kind.
4708 ///
4709 /// @param pd C API primitive descriptor for a convolution forward
4710 /// propagation primitive.
4711 primitive_desc(dnnl_primitive_desc_t pd)
4712 : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4713 dnnl::prop_kind::forward_training,
4714 dnnl::prop_kind::forward_inference) {}
4715
4716 /// @copydoc dnnl::primitive_desc_base::src_desc()const
4717 memory::desc src_desc() const { return base::src_desc(0); }
4718
4719 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
4720 memory::desc weights_desc() const { return base::weights_desc(0); }
4721
4722 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
4723 memory::desc dst_desc() const { return base::dst_desc(0); }
4724
4725 /// Returns the bias memory descriptor.
4726 /// @returns The bias memory descriptor.
4727 /// @returns A zero memory descriptor of the primitive does not have a
4728 /// bias parameter.
4729 memory::desc bias_desc() const { return base::weights_desc(1); }
4730
4731 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
4732 algorithm get_algorithm() const { return base::get_algorithm(); }
4733
4734 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
4735 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
4736
4737 /// @copydoc dnnl::primitive_desc_base::get_strides()const
4738 memory::dims get_strides() const { return base::get_strides(); }
4739
4740 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
4741 memory::dims get_dilations() const { return base::get_dilations(); }
4742
4743 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
4744 memory::dims get_padding_l() const { return base::get_padding_l(); }
4745
4746 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
4747 memory::dims get_padding_r() const { return base::get_padding_r(); }
4748
4749 private:
4750 primitive_desc(const engine &aengine, prop_kind aprop_kind,
4751 algorithm aalgorithm, const memory::desc &src_desc,
4752 const memory::desc &weights_desc, const memory::desc *bias_desc,
4753 const memory::desc &dst_desc, const memory::dims &strides,
4754 const memory::dims *dilates, const memory::dims &padding_l,
4755 const memory::dims &padding_r, const primitive_attr &attr,
4756 bool allow_empty) {
4757
4758 memory::validate_dims(strides, src_desc.get_ndims() - 2);
4759 memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
4760 memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
4761
4762 if (dilates)
4763 memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
4764
4765 dnnl_primitive_desc_t pd = nullptr;
4766 dnnl_status_t status
4767 = dnnl_convolution_forward_primitive_desc_create(&pd,
4768 aengine.get(), dnnl::convert_to_c(aprop_kind),
4769 convert_to_c(aalgorithm), src_desc.get(),
4770 weights_desc.get(), optional_arg(bias_desc),
4771 dst_desc.get(), &strides[0], optional_arg(dilates),
4772 &padding_l[0], &padding_r[0], attr.get());
4773 if (!allow_empty)
4774 error::wrap_c_api(status,
4775 "could not create a primitive descriptor for a "
4776 "convolution forward propagation primitive");
4777 reset(pd);
4778 }
4779 };
4780
4781 /// Default constructor. Produces an empty object.
4782 convolution_forward() = default;
4783
4784 /// Constructs a convolution forward propagation primitive.
4785 /// @param pd Primitive descriptor for a convolution forward propagation
4786 /// primitive.
4787 convolution_forward(const primitive_desc &pd) : primitive(pd) {}
4788
4789 /// Constructs a convolution forward propagation primitive from a cache
4790 /// blob.
4791 /// @param pd Primitive descriptor for a convolution forward propagation
4792 /// primitive.
4793 /// @param cache_blob Cache blob.
4794 convolution_forward(
4795 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
4796 : primitive(pd, cache_blob) {}
4797};
4798
4799/// Convolution backward propagation primitive.
4800struct convolution_backward_data : public primitive {
4801 /// Primitive descriptor for a convolution backward propagation primitive.
4802 struct primitive_desc : public dnnl::primitive_desc {
4803 /// Default constructor. Produces an empty object.
4804 primitive_desc() = default;
4805
4806 /// Constructs a primitive descriptor for a convolution backward
4807 /// propagation primitive.
4808 ///
4809 /// @note
4810 /// All the memory descriptors may be initialized with the
4811 /// #dnnl::memory::format_tag::any value of @p format_tag.
4812 ///
4813 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
4814 /// for spatial dimensions only and hence must have the same number of
4815 /// elements as there are spatial dimensions. The order of values is
4816 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
4817 /// and 2D tensors), and width.
4818 ///
4819 /// @param aengine Engine to use.
4820 /// @param aalgorithm Convolution algorithm. Possible values are
4821 /// #dnnl::algorithm::convolution_direct,
4822 /// #dnnl::algorithm::convolution_winograd, and
4823 /// #dnnl::algorithm::convolution_auto.
4824 /// @param diff_src_desc Diff source memory descriptor.
4825 /// @param weights_desc Weights memory descriptor.
4826 /// @param diff_dst_desc Diff destination memory descriptor.
4827 /// @param strides Strides for each spatial dimension.
4828 /// @param padding_l Vector of padding values for low indices for each
4829 /// spatial dimension `([[front,] top,] left)`.
4830 /// @param padding_r Vector of padding values for high indices for
4831 /// each spatial dimension `([[back,] bottom,] right)`.
4832 /// @param hint_fwd_pd Primitive descriptor for a convolution
4833 /// forward propagation primitive. It is used as a hint for
4834 /// deciding which memory format to use.
4835 /// @param attr Primitive attributes to use. Attributes are optional
4836 /// and default to empty attributes.
4837 /// @param allow_empty A flag signifying whether construction is
4838 /// allowed to fail without throwing an exception. In this case an
4839 /// empty object will be produced. This flag is optional and
4840 /// defaults to false.
4841 primitive_desc(const engine &aengine, algorithm aalgorithm,
4842 const memory::desc &diff_src_desc,
4843 const memory::desc &weights_desc,
4844 const memory::desc &diff_dst_desc, const memory::dims &strides,
4845 const memory::dims &padding_l, const memory::dims &padding_r,
4846 const convolution_forward::primitive_desc &hint_fwd_pd,
4847 const primitive_attr &attr = default_attr(),
4848 bool allow_empty = false)
4849 : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
4850 diff_dst_desc, strides, nullptr, padding_l, padding_r,
4851 hint_fwd_pd, attr, allow_empty) {}
4852
4853 /// Constructs a primitive descriptor for a convolution backward
4854 /// propagation primitive.
4855 ///
4856 /// @note
4857 /// All the memory descriptors may be initialized with the
4858 /// #dnnl::memory::format_tag::any value of @p format_tag.
4859 ///
4860 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
4861 /// contain values for spatial dimensions only and hence must have the
4862 /// same number of elements as there are spatial dimensions. The order
4863 /// of values is the same as in the tensor: depth (for 3D tensors),
4864 /// height (for 3D and 2D tensors), and width.
4865 ///
4866 /// @param aengine Engine to use.
4867 /// @param aalgorithm Convolution algorithm. Possible values are
4868 /// #dnnl::algorithm::convolution_direct,
4869 /// #dnnl::algorithm::convolution_winograd, and
4870 /// #dnnl::algorithm::convolution_auto.
4871 /// @param diff_src_desc Diff source memory descriptor.
4872 /// @param weights_desc Weights memory descriptor.
4873 /// @param diff_dst_desc Diff destination memory descriptor.
4874 /// @param strides Strides for each spatial dimension.
4875 /// @param dilates Dilations for each spatial dimension. A zero value
4876 /// means no dilation in the corresponding dimension.
4877 /// @param padding_l Vector of padding values for low indices for each
4878 /// spatial dimension `([[front,] top,] left)`.
4879 /// @param padding_r Vector of padding values for high indices for
4880 /// each spatial dimension `([[back,] bottom,] right)`.
4881 /// @param hint_fwd_pd Primitive descriptor for a convolution
4882 /// forward propagation primitive. It is used as a hint for
4883 /// deciding which memory format to use.
4884 /// @param attr Primitive attributes to use. Attributes are optional
4885 /// and default to empty attributes.
4886 /// @param allow_empty A flag signifying whether construction is
4887 /// allowed to fail without throwing an exception. In this case an
4888 /// empty object will be produced. This flag is optional and
4889 /// defaults to false.
4890 primitive_desc(const engine &aengine, algorithm aalgorithm,
4891 const memory::desc &diff_src_desc,
4892 const memory::desc &weights_desc,
4893 const memory::desc &diff_dst_desc, const memory::dims &strides,
4894 const memory::dims &dilates, const memory::dims &padding_l,
4895 const memory::dims &padding_r,
4896 const convolution_forward::primitive_desc &hint_fwd_pd,
4897 const primitive_attr &attr = default_attr(),
4898 bool allow_empty = false)
4899 : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
4900 diff_dst_desc, strides, &dilates, padding_l, padding_r,
4901 hint_fwd_pd, attr, allow_empty) {}
4902
4903 /// Constructs a primitive descriptor for a convolution backward
4904 /// propagation primitive from a C API primitive descriptor that must
4905 /// have a matching kind.
4906 ///
4907 /// @param pd C API primitive descriptor for a convolution backward
4908 /// propagation primitive.
4909 primitive_desc(dnnl_primitive_desc_t pd)
4910 : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4911 dnnl::prop_kind::backward_data) {}
4912
4913 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
4914 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
4915
4916 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
4917 memory::desc weights_desc() const { return base::weights_desc(0); }
4918
4919 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
4920 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
4921
4922 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
4923 algorithm get_algorithm() const { return base::get_algorithm(); }
4924
4925 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
4926 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
4927
4928 /// @copydoc dnnl::primitive_desc_base::get_strides()const
4929 memory::dims get_strides() const { return base::get_strides(); }
4930
4931 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
4932 memory::dims get_dilations() const { return base::get_dilations(); }
4933
4934 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
4935 memory::dims get_padding_l() const { return base::get_padding_l(); }
4936
4937 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
4938 memory::dims get_padding_r() const { return base::get_padding_r(); }
4939
4940 private:
4941 primitive_desc(const engine &aengine, algorithm aalgorithm,
4942 const memory::desc &diff_src_desc,
4943 const memory::desc &weights_desc,
4944 const memory::desc &diff_dst_desc, const memory::dims &strides,
4945 const memory::dims *dilates, const memory::dims &padding_l,
4946 const memory::dims &padding_r,
4947 const convolution_forward::primitive_desc &hint_fwd_pd,
4948 const primitive_attr &attr, bool allow_empty) {
4949
4950 memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
4951 memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
4952 memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
4953
4954 if (dilates)
4955 memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
4956
4957 dnnl_primitive_desc_t pd = nullptr;
4958 dnnl_status_t status
4959 = dnnl_convolution_backward_data_primitive_desc_create(&pd,
4960 aengine.get(), convert_to_c(aalgorithm),
4961 diff_src_desc.get(), weights_desc.get(),
4962 diff_dst_desc.get(), &strides[0],
4963 optional_arg(dilates), &padding_l[0], &padding_r[0],
4964 hint_fwd_pd.get(), attr.get());
4965 if (!allow_empty)
4966 error::wrap_c_api(status,
4967 "could not create a primitive descriptor for a "
4968 "convolution backward propagation primitive");
4969 reset(pd);
4970 }
4971 };
4972
    /// Default constructor. Produces an empty object.
    ///
    /// NOTE(review): presumably an empty object holds no underlying C API
    /// primitive and must be assigned from a constructed primitive before
    /// use -- confirm against the base class.
    convolution_backward_data() = default;
4975
4976 /// Constructs a convolution backward propagation primitive.
4977 /// @param pd Primitive descriptor for a convolution backward propagation
4978 /// primitive.
4979 convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
4980
4981 /// Constructs a convolution backward propagation primitive from a cache
4982 /// blob.
4983 /// @param pd Primitive descriptor for a convolution backward propagation
4984 /// primitive.
4985 /// @param cache_blob Cache blob.
4986 convolution_backward_data(
4987 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
4988 : primitive(pd, cache_blob) {}
4989};
4990
4991/// Convolution weights gradient primitive.
4992struct convolution_backward_weights : public primitive {
4993 /// Primitive descriptor for a convolution weights gradient primitive.
4994 struct primitive_desc : public dnnl::primitive_desc {
4995 /// Default constructor. Produces an empty object.
4996 primitive_desc() = default;
4997
4998 /// Constructs a primitive descriptor for a convolution weights gradient
4999 /// primitive with bias.
5000 ///
5001 /// @note
5002 /// All the memory descriptors may be initialized with the
5003 /// #dnnl::memory::format_tag::any value of @p format_tag.
5004 ///
5005 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5006 /// for spatial dimensions only and hence must have the same number of
5007 /// elements as there are spatial dimensions. The order of values is
5008 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5009 /// and 2D tensors), and width.
5010 ///
5011 /// @param aengine Engine to use.
5012 /// @param aalgorithm Convolution algorithm. Possible values are
5013 /// #dnnl::algorithm::convolution_direct,
5014 /// #dnnl::algorithm::convolution_winograd, and
5015 /// #dnnl::algorithm::convolution_auto.
5016 /// @param src_desc Source memory descriptor.
5017 /// @param diff_weights_desc Diff weights memory descriptor.
5018 /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
5019 /// memory descriptor disables the bias term.
5020 /// @param diff_dst_desc Diff destination memory descriptor.
5021 /// @param strides Strides for each spatial dimension.
5022 /// @param padding_l Vector of padding values for low indices for each
5023 /// spatial dimension `([[front,] top,] left)`.
5024 /// @param padding_r Vector of padding values for high indices for
5025 /// each spatial dimension `([[back,] bottom,] right)`.
5026 /// @param hint_fwd_pd Primitive descriptor for a convolution
5027 /// forward propagation primitive. It is used as a hint for
5028 /// deciding which memory format to use.
5029 /// @param attr Primitive attributes to use. Attributes are optional
5030 /// and default to empty attributes.
5031 /// @param allow_empty A flag signifying whether construction is
5032 /// allowed to fail without throwing an exception. In this case an
5033 /// empty object will be produced. This flag is optional and
5034 /// defaults to false.
5035 primitive_desc(const engine &aengine, algorithm aalgorithm,
5036 const memory::desc &src_desc,
5037 const memory::desc &diff_weights_desc,
5038 const memory::desc &diff_bias_desc,
5039 const memory::desc &diff_dst_desc, const memory::dims &strides,
5040 const memory::dims &padding_l, const memory::dims &padding_r,
5041 const convolution_forward::primitive_desc &hint_fwd_pd,
5042 const primitive_attr &attr = default_attr(),
5043 bool allow_empty = false)
5044 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5045 &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
5046 padding_r, hint_fwd_pd, attr, allow_empty) {}
5047
5048 /// Constructs a primitive descriptor for a convolution weights gradient
5049 /// primitive without bias.
5050 ///
5051 /// @note
5052 /// All the memory descriptors may be initialized with the
5053 /// #dnnl::memory::format_tag::any value of @p format_tag.
5054 ///
5055 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5056 /// for spatial dimensions only and hence must have the same number of
5057 /// elements as there are spatial dimensions. The order of values is
5058 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5059 /// and 2D tensors), and width.
5060 ///
5061 /// @param aengine Engine to use.
5062 /// @param aalgorithm Convolution algorithm. Possible values are
5063 /// #dnnl::algorithm::convolution_direct,
5064 /// #dnnl::algorithm::convolution_winograd, and
5065 /// #dnnl::algorithm::convolution_auto.
5066 /// @param src_desc Source memory descriptor.
5067 /// @param diff_weights_desc Diff weights memory descriptor.
5068 /// @param diff_dst_desc Diff destination memory descriptor.
5069 /// @param strides Strides for each spatial dimension.
5070 /// @param padding_l Vector of padding values for low indices for each
5071 /// spatial dimension `([[front,] top,] left)`.
5072 /// @param padding_r Vector of padding values for high indices for
5073 /// each spatial dimension `([[back,] bottom,] right)`.
5074 /// @param hint_fwd_pd Primitive descriptor for a convolution
5075 /// forward propagation primitive. It is used as a hint for
5076 /// deciding which memory format to use.
5077 /// @param attr Primitive attributes to use. Attributes are optional
5078 /// and default to empty attributes.
5079 /// @param allow_empty A flag signifying whether construction is
5080 /// allowed to fail without throwing an exception. In this case an
5081 /// empty object will be produced. This flag is optional and
5082 /// defaults to false.
5083 primitive_desc(const engine &aengine, algorithm aalgorithm,
5084 const memory::desc &src_desc,
5085 const memory::desc &diff_weights_desc,
5086 const memory::desc &diff_dst_desc, const memory::dims &strides,
5087 const memory::dims &padding_l, const memory::dims &padding_r,
5088 const convolution_forward::primitive_desc &hint_fwd_pd,
5089 const primitive_attr &attr = default_attr(),
5090 bool allow_empty = false)
5091 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5092 nullptr, diff_dst_desc, strides, nullptr, padding_l,
5093 padding_r, hint_fwd_pd, attr, allow_empty) {}
5094
5095 /// Constructs a primitive descriptor for a convolution weights
5096 /// gradient primitive with bias.
5097 ///
5098 /// @note
5099 /// All the memory descriptors may be initialized with the
5100 /// #dnnl::memory::format_tag::any value of @p format_tag.
5101 ///
5102 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5103 /// contain values for spatial dimensions only and hence must have the
5104 /// same number of elements as there are spatial dimensions. The order
5105 /// of values is the same as in the tensor: depth (for 3D tensors),
5106 /// height (for 3D and 2D tensors), and width.
5107 ///
5108 /// @param aengine Engine to use.
5109 /// @param aalgorithm Convolution algorithm. Possible values are
5110 /// #dnnl::algorithm::convolution_direct,
5111 /// #dnnl::algorithm::convolution_winograd, and
5112 /// #dnnl::algorithm::convolution_auto.
5113 /// @param src_desc Source memory descriptor.
5114 /// @param diff_weights_desc Diff weights memory descriptor.
5115 /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
5116 /// memory descriptor disables the bias term.
5117 /// @param diff_dst_desc Diff destination memory descriptor.
5118 /// @param strides Strides for each spatial dimension.
5119 /// @param dilates Dilations for each spatial dimension. A zero value
5120 /// means no dilation in the corresponding dimension.
5121 /// @param padding_l Vector of padding values for low indices for each
5122 /// spatial dimension `([[front,] top,] left)`.
5123 /// @param padding_r Vector of padding values for high indices for
5124 /// each spatial dimension `([[back,] bottom,] right)`.
5125 /// @param hint_fwd_pd Primitive descriptor for a convolution
5126 /// forward propagation primitive. It is used as a hint for
5127 /// deciding which memory format to use.
5128 /// @param attr Primitive attributes to use. Attributes are optional
5129 /// and default to empty attributes.
5130 /// @param allow_empty A flag signifying whether construction is
5131 /// allowed to fail without throwing an exception. In this case an
5132 /// empty object will be produced. This flag is optional and
5133 /// defaults to false.
5134 primitive_desc(const engine &aengine, algorithm aalgorithm,
5135 const memory::desc &src_desc,
5136 const memory::desc &diff_weights_desc,
5137 const memory::desc &diff_bias_desc,
5138 const memory::desc &diff_dst_desc, const memory::dims &strides,
5139 const memory::dims &dilates, const memory::dims &padding_l,
5140 const memory::dims &padding_r,
5141 const convolution_forward::primitive_desc &hint_fwd_pd,
5142 const primitive_attr &attr = default_attr(),
5143 bool allow_empty = false)
5144 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5145 &diff_bias_desc, diff_dst_desc, strides, &dilates,
5146 padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
5147
5148 /// Constructs a primitive descriptor for a convolution weights
5149 /// gradient primitive without bias.
5150 ///
5151 /// @note
5152 /// All the memory descriptors may be initialized with the
5153 /// #dnnl::memory::format_tag::any value of @p format_tag.
5154 ///
5155 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5156 /// contain values for spatial dimensions only and hence must have the
5157 /// same number of elements as there are spatial dimensions. The order
5158 /// of values is the same as in the tensor: depth (for 3D tensors),
5159 /// height (for 3D and 2D tensors), and width.
5160 ///
5161 /// @param aengine Engine to use.
5162 /// @param aalgorithm Convolution algorithm. Possible values are
5163 /// #dnnl::algorithm::convolution_direct,
5164 /// #dnnl::algorithm::convolution_winograd, and
5165 /// #dnnl::algorithm::convolution_auto.
5166 /// @param src_desc Source memory descriptor.
5167 /// @param diff_weights_desc Diff weights memory descriptor.
5168 /// @param diff_dst_desc Diff destination memory descriptor.
5169 /// @param strides Strides for each spatial dimension.
5170 /// @param dilates Dilations for each spatial dimension. A zero value
5171 /// means no dilation in the corresponding dimension.
5172 /// @param padding_l Vector of padding values for low indices for each
5173 /// spatial dimension `([[front,] top,] left)`.
5174 /// @param padding_r Vector of padding values for high indices for
5175 /// each spatial dimension `([[back,] bottom,] right)`.
5176 /// @param hint_fwd_pd Primitive descriptor for a convolution
5177 /// forward propagation primitive. It is used as a hint for
5178 /// deciding which memory format to use.
5179 /// @param attr Primitive attributes to use. Attributes are optional
5180 /// and default to empty attributes.
5181 /// @param allow_empty A flag signifying whether construction is
5182 /// allowed to fail without throwing an exception. In this case an
5183 /// empty object will be produced. This flag is optional and
5184 /// defaults to false.
5185 primitive_desc(const engine &aengine, algorithm aalgorithm,
5186 const memory::desc &src_desc,
5187 const memory::desc &diff_weights_desc,
5188 const memory::desc &diff_dst_desc, const memory::dims &strides,
5189 const memory::dims &dilates, const memory::dims &padding_l,
5190 const memory::dims &padding_r,
5191 const convolution_forward::primitive_desc &hint_fwd_pd,
5192 const primitive_attr &attr = default_attr(),
5193 bool allow_empty = false)
5194 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5195 nullptr, diff_dst_desc, strides, &dilates, padding_l,
5196 padding_r, hint_fwd_pd, attr, allow_empty) {}
5197
5198 /// Constructs a primitive descriptor for a convolution weights gradient
5199 /// primitive from a C API primitive descriptor that must have a
5200 /// matching kind.
5201 ///
5202 /// @param pd C API primitive descriptor for a convolution weights
5203 /// gradient primitive.
5204 primitive_desc(dnnl_primitive_desc_t pd)
5205 : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
5206 dnnl::prop_kind::backward_weights) {}
5207
5208 /// @copydoc dnnl::primitive_desc_base::src_desc()const
5209 memory::desc src_desc() const { return base::src_desc(0); }
5210
5211 /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
5212 memory::desc diff_weights_desc() const {
5213 return base::diff_weights_desc(0);
5214 }
5215
5216 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
5217 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
5218
5219 /// Returns the diff bias memory descriptor.
5220 /// @returns The diff bias memory descriptor.
5221 /// @returns A zero memory descriptor of the primitive does not have a
5222 /// diff bias parameter.
5223 memory::desc diff_bias_desc() const {
5224 return base::diff_weights_desc(1);
5225 }
5226
5227 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
5228 algorithm get_algorithm() const { return base::get_algorithm(); }
5229
5230 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
5231 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
5232
5233 /// @copydoc dnnl::primitive_desc_base::get_strides()const
5234 memory::dims get_strides() const { return base::get_strides(); }
5235
5236 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
5237 memory::dims get_dilations() const { return base::get_dilations(); }
5238
5239 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
5240 memory::dims get_padding_l() const { return base::get_padding_l(); }
5241
5242 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
5243 memory::dims get_padding_r() const { return base::get_padding_r(); }
5244
5245 private:
5246 primitive_desc(const engine &aengine, algorithm aalgorithm,
5247 const memory::desc &src_desc,
5248 const memory::desc &diff_weights_desc,
5249 const memory::desc *diff_bias_desc,
5250 const memory::desc &diff_dst_desc, const memory::dims &strides,
5251 const memory::dims *dilates, const memory::dims &padding_l,
5252 const memory::dims &padding_r,
5253 const convolution_forward::primitive_desc &hint_fwd_pd,
5254 const primitive_attr &attr, bool allow_empty) {
5255
5256 memory::validate_dims(strides, src_desc.get_ndims() - 2);
5257 memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
5258 memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
5259
5260 if (dilates)
5261 memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
5262
5263 dnnl_primitive_desc_t pd = nullptr;
5264 dnnl_status_t status
5265 = dnnl_convolution_backward_weights_primitive_desc_create(
5266 &pd, aengine.get(), convert_to_c(aalgorithm),
5267 src_desc.get(), diff_weights_desc.get(),
5268 optional_arg(diff_bias_desc), diff_dst_desc.get(),
5269 &strides[0], optional_arg(dilates), &padding_l[0],
5270 &padding_r[0], hint_fwd_pd.get(), attr.get());
5271 if (!allow_empty)
5272 error::wrap_c_api(status,
5273 "could not create a primitive descriptor for a "
5274 "convolution weights update primitive");
5275 reset(pd);
5276 }
5277 };
5278
5279 /// Default constructor. Produces an empty object.
5280 convolution_backward_weights() = default;
5281
5282 /// Constructs a convolution weights gradient primitive.
5283 /// @param pd Primitive descriptor for a convolution weights gradient
5284 /// primitive.
5285 convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
5286
5287 /// Constructs a convolution weights gradient primitive from a cache blob.
5288 /// @param pd Primitive descriptor for a convolution weights gradient
5289 /// primitive.
5290 /// @param cache_blob Cache blob.
5291 convolution_backward_weights(
5292 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
5293 : primitive(pd, cache_blob) {}
5294};
5295
5296/// @} dnnl_api_convolution
5297//
5298/// @addtogroup dnnl_api_deconvolution Deconvolution
5299///
5300/// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
5301/// forward propagation, backward propagation, and weights gradient with or
5302/// without bias.
5303///
5304/// @{
5305
5306/// Deconvolution forward propagation primitive.
5307struct deconvolution_forward : public primitive {
5308 /// Primitive descriptor for a deconvolution forward propagation primitive.
5309 struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        ///
        /// NOTE(review): presumably an empty object holds no underlying C API
        /// primitive descriptor and must be assigned from a constructed one
        /// before use -- confirm against the base class.
        primitive_desc() = default;
5312
5313 /// Constructs a primitive descriptor for a deconvolution forward
5314 /// propagation primitive with bias.
5315 ///
5316 /// @note
5317 /// All the memory descriptors may be initialized with the
5318 /// #dnnl::memory::format_tag::any value of @p format_tag.
5319 ///
5320 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5321 /// for spatial dimensions only and hence must have the same number of
5322 /// elements as there are spatial dimensions. The order of values is
5323 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5324 /// and 2D tensors), and width.
5325 ///
5326 /// @param aengine Engine to use.
5327 /// @param aprop_kind Propagation kind. Possible values are
5328 /// #dnnl::prop_kind::forward_training, and
5329 /// #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct and
        ///     #dnnl::algorithm::deconvolution_winograd.
5333 /// @param src_desc Source memory descriptor.
5334 /// @param weights_desc Weights memory descriptor.
5335 /// @param bias_desc Bias memory descriptor. Passing zero memory
5336 /// descriptor disables the bias term.
5337 /// @param dst_desc Destination memory descriptor.
        /// @param strides Strides for each spatial dimension.
5339 /// @param padding_l Vector of padding values for low indices for each
5340 /// spatial dimension `([[front,] top,] left)`.
5341 /// @param padding_r Vector of padding values for high indices for
5342 /// each spatial dimension `([[back,] bottom,] right)`.
5343 /// @param attr Primitive attributes to use. Attributes are optional
5344 /// and default to empty attributes.
5345 /// @param allow_empty A flag signifying whether construction is
5346 /// allowed to fail without throwing an exception. In this case an
5347 /// empty object will be produced. This flag is optional and
5348 /// defaults to false.
5349 primitive_desc(const engine &aengine, prop_kind aprop_kind,
5350 algorithm aalgorithm, const memory::desc &src_desc,
5351 const memory::desc &weights_desc, const memory::desc &bias_desc,
5352 const memory::desc &dst_desc, const memory::dims &strides,
5353 const memory::dims &padding_l, const memory::dims &padding_r,
5354 const primitive_attr &attr = default_attr(),
5355 bool allow_empty = false)
5356 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
5357 weights_desc, &bias_desc, dst_desc, strides, nullptr,
5358 padding_l, padding_r, attr, allow_empty) {}
5359
5360 /// Constructs a primitive descriptor for a deconvolution forward
5361 /// propagation primitive without bias.
5362 ///
5363 /// @note
5364 /// All the memory descriptors may be initialized with the
5365 /// #dnnl::memory::format_tag::any value of @p format_tag.
5366 ///
5367 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5368 /// for spatial dimensions only and hence must have the same number of
5369 /// elements as there are spatial dimensions. The order of values is
5370 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5371 /// and 2D tensors), and width.
5372 ///
5373 /// @param aengine Engine to use.
5374 /// @param aprop_kind Propagation kind. Possible values are
5375 /// #dnnl::prop_kind::forward_training, and
5376 /// #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct and
        ///     #dnnl::algorithm::deconvolution_winograd.
5380 /// @param src_desc Source memory descriptor.
5381 /// @param weights_desc Weights memory descriptor.
5382 /// @param dst_desc Destination memory descriptor.
        /// @param strides Strides for each spatial dimension.
5384 /// @param padding_l Vector of padding values for low indices for each
5385 /// spatial dimension `([[front,] top,] left)`.
5386 /// @param padding_r Vector of padding values for high indices for
5387 /// each spatial dimension `([[back,] bottom,] right)`.
5388 /// @param attr Primitive attributes to use. Attributes are optional
5389 /// and default to empty attributes.
5390 /// @param allow_empty A flag signifying whether construction is
5391 /// allowed to fail without throwing an exception. In this case an
5392 /// empty object will be produced. This flag is optional and
5393 /// defaults to false.
5394 primitive_desc(const engine &aengine, prop_kind aprop_kind,
5395 algorithm aalgorithm, const memory::desc &src_desc,
5396 const memory::desc &weights_desc, const memory::desc &dst_desc,
5397 const memory::dims &strides, const memory::dims &padding_l,
5398 const memory::dims &padding_r,
5399 const primitive_attr &attr = default_attr(),
5400 bool allow_empty = false)
5401 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
5402 weights_desc, nullptr, dst_desc, strides, nullptr,
5403 padding_l, padding_r, attr, allow_empty) {}
5404
5405 /// Constructs a primitive descriptor for a deconvolution forward
5406 /// propagation primitive with bias.
5407 ///
5408 /// @note
5409 /// All the memory descriptors may be initialized with the
5410 /// #dnnl::memory::format_tag::any value of @p format_tag.
5411 ///
5412 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5413 /// contain values for spatial dimensions only and hence must have the
5414 /// same number of elements as there are spatial dimensions. The order
5415 /// of values is the same as in the tensor: depth (for 3D tensors),
5416 /// height (for 3D and 2D tensors), and width.
5417 ///
5418 /// @param aengine Engine to use.
5419 /// @param aprop_kind Propagation kind. Possible values are
5420 /// #dnnl::prop_kind::forward_training, and
5421 /// #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct and
        ///     #dnnl::algorithm::deconvolution_winograd.
5425 /// @param src_desc Source memory descriptor.
5426 /// @param weights_desc Weights memory descriptor.
5427 /// @param bias_desc Bias memory descriptor. Passing zero memory
5428 /// descriptor disables the bias term.
5429 /// @param dst_desc Destination memory descriptor.
        /// @param strides Strides for each spatial dimension.
5431 /// @param dilates Dilations for each spatial dimension. A zero value
5432 /// means no dilation in the corresponding dimension.
5433 /// @param padding_l Vector of padding values for low indices for each
5434 /// spatial dimension `([[front,] top,] left)`.
5435 /// @param padding_r Vector of padding values for high indices for
5436 /// each spatial dimension `([[back,] bottom,] right)`.
5437 /// @param attr Primitive attributes to use. Attributes are optional
5438 /// and default to empty attributes.
5439 /// @param allow_empty A flag signifying whether construction is
5440 /// allowed to fail without throwing an exception. In this case an
5441 /// empty object will be produced. This flag is optional and
5442 /// defaults to false.
5443 primitive_desc(const engine &aengine, prop_kind aprop_kind,
5444 algorithm aalgorithm, const memory::desc &src_desc,
5445 const memory::desc &weights_desc, const memory::desc &bias_desc,
5446 const memory::desc &dst_desc, const memory::dims &strides,
5447 const memory::dims &dilates, const memory::dims &padding_l,
5448 const memory::dims &padding_r,
5449 const primitive_attr &attr = default_attr(),
5450 bool allow_empty = false)
5451 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
5452 weights_desc, &bias_desc, dst_desc, strides, &dilates,
5453 padding_l, padding_r, attr, allow_empty) {}
5454
5455 /// Constructs a primitive descriptor for a deconvolution forward
5456 /// propagation primitive without bias.
5457 ///
5458 /// @note
5459 /// All the memory descriptors may be initialized with the
5460 /// #dnnl::memory::format_tag::any value of @p format_tag.
5461 ///
5462 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5463 /// contain values for spatial dimensions only and hence must have the
5464 /// same number of elements as there are spatial dimensions. The order
5465 /// of values is the same as in the tensor: depth (for 3D tensors),
5466 /// height (for 3D and 2D tensors), and width.
5467 ///
5468 /// @param aengine Engine to use.
5469 /// @param aprop_kind Propagation kind. Possible values are
5470 /// #dnnl::prop_kind::forward_training, and
5471 /// #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct and
        ///     #dnnl::algorithm::deconvolution_winograd.
5475 /// @param src_desc Source memory descriptor.
5476 /// @param weights_desc Weights memory descriptor.
5477 /// @param dst_desc Destination memory descriptor.
        /// @param strides Strides for each spatial dimension.
5479 /// @param dilates Dilations for each spatial dimension. A zero value
5480 /// means no dilation in the corresponding dimension.
5481 /// @param padding_l Vector of padding values for low indices for each
5482 /// spatial dimension `([[front,] top,] left)`.
5483 /// @param padding_r Vector of padding values for high indices for
5484 /// each spatial dimension `([[back,] bottom,] right)`.
5485 /// @param attr Primitive attributes to use. Attributes are optional
5486 /// and default to empty attributes.
5487 /// @param allow_empty A flag signifying whether construction is
5488 /// allowed to fail without throwing an exception. In this case an
5489 /// empty object will be produced. This flag is optional and
5490 /// defaults to false.
5491 primitive_desc(const engine &aengine, prop_kind aprop_kind,
5492 algorithm aalgorithm, const memory::desc &src_desc,
5493 const memory::desc &weights_desc, const memory::desc &dst_desc,
5494 const memory::dims &strides, const memory::dims &dilates,
5495 const memory::dims &padding_l, const memory::dims &padding_r,
5496 const primitive_attr &attr = default_attr(),
5497 bool allow_empty = false)
5498 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
5499 weights_desc, nullptr, dst_desc, strides, &dilates,
5500 padding_l, padding_r, attr, allow_empty) {}
5501
5502 /// Constructs a primitive descriptor for a deconvolution forward
5503 /// propagation primitive from a C API primitive descriptor that must
5504 /// have a matching kind.
5505 ///
5506 /// @param pd C API primitive descriptor for a deconvolution forward
5507 /// propagation primitive.
5508 primitive_desc(dnnl_primitive_desc_t pd)
5509 : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5510 dnnl::prop_kind::forward_training,
5511 dnnl::prop_kind::forward_inference) {}
5512
5513 /// @copydoc dnnl::primitive_desc_base::src_desc()const
5514 memory::desc src_desc() const { return base::src_desc(0); }
5515
5516 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
5517 memory::desc weights_desc() const { return base::weights_desc(0); }
5518
5519 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
5520 memory::desc dst_desc() const { return base::dst_desc(0); }
5521
5522 /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
5523 memory::desc bias_desc() const { return base::weights_desc(1); }
5524
5525 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
5526 algorithm get_algorithm() const { return base::get_algorithm(); }
5527
5528 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
5529 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
5530
5531 /// @copydoc dnnl::primitive_desc_base::get_strides()const
5532 memory::dims get_strides() const { return base::get_strides(); }
5533
5534 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
5535 memory::dims get_dilations() const { return base::get_dilations(); }
5536
5537 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
5538 memory::dims get_padding_l() const { return base::get_padding_l(); }
5539
5540 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
5541 memory::dims get_padding_r() const { return base::get_padding_r(); }
5542
5543 private:
5544 primitive_desc(const engine &aengine, prop_kind aprop_kind,
5545 algorithm aalgorithm, const memory::desc &src_desc,
5546 const memory::desc &weights_desc, const memory::desc *bias_desc,
5547 const memory::desc &dst_desc, const memory::dims &strides,
5548 const memory::dims *dilates, const memory::dims &padding_l,
5549 const memory::dims &padding_r, const primitive_attr &attr,
5550 bool allow_empty) {
5551
5552 memory::validate_dims(strides, src_desc.get_ndims() - 2);
5553 memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
5554 memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
5555
5556 if (dilates)
5557 memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
5558
5559 dnnl_primitive_desc_t pd = nullptr;
5560 dnnl_status_t status
5561 = dnnl_deconvolution_forward_primitive_desc_create(&pd,
5562 aengine.get(), dnnl::convert_to_c(aprop_kind),
5563 convert_to_c(aalgorithm), src_desc.get(),
5564 weights_desc.get(), optional_arg(bias_desc),
5565 dst_desc.get(), &strides[0], optional_arg(dilates),
5566 &padding_l[0], &padding_r[0], attr.get());
5567 if (!allow_empty)
5568 error::wrap_c_api(status,
5569 "could not create a primitive descriptor for a "
5570 "deconvolution forward propagation primitive");
5571 reset(pd);
5572 }
5573 };
5574
5575 /// Default constructor. Produces an empty object.
5576 deconvolution_forward() = default;
5577
5578 /// Constructs a deconvolution forward propagation primitive.
5579 /// @param pd Primitive descriptor for a deconvolution forward propagation
5580 /// primitive.
5581 deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
5582
5583 /// Constructs a deconvolution forward propagation primitive from a cache
5584 /// blob.
5585 /// @param pd Primitive descriptor for a deconvolution forward propagation
5586 /// primitive.
5587 /// @param cache_blob Cache blob.
5588 deconvolution_forward(
5589 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
5590 : primitive(pd, cache_blob) {}
5591};
5592
5593/// Deconvolution backward propagation primitive.
5594struct deconvolution_backward_data : public primitive {
5595 /// Primitive descriptor for a deconvolution backward propagation primitive.
5596 struct primitive_desc : public dnnl::primitive_desc {
5597 /// Default constructor. Produces an empty object.
5598 primitive_desc() = default;
5599
5600 /// Constructs a primitive descriptor for a deconvolution backward
5601 /// propagation primitive.
5602 ///
5603 /// @note
5604 /// All the memory descriptors may be initialized with the
5605 /// #dnnl::memory::format_tag::any value of @p format_tag.
5606 ///
5607 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5608 /// for spatial dimensions only and hence must have the same number of
5609 /// elements as there are spatial dimensions. The order of values is
5610 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5611 /// and 2D tensors), and width.
5612 ///
5613 /// @param aengine Engine to use.
5614 /// @param aalgorithm Deconvolution algorithm
5615 /// (#dnnl::algorithm::convolution_direct,
5616 /// #dnnl::algorithm::convolution_winograd).
5617 /// @param diff_src_desc Diff source memory descriptor.
5618 /// @param weights_desc Weights memory descriptor.
5619 /// @param diff_dst_desc Diff destination memory descriptor.
5620 /// @param strides Strides for each spatial dimension.
5621 /// @param padding_l Vector of padding values for low indices for each
5622 /// spatial dimension `([[front,] top,] left)`.
5623 /// @param padding_r Vector of padding values for high indices for
5624 /// each spatial dimension `([[back,] bottom,] right)`.
5625 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5626 /// forward propagation primitive. It is used as a hint for
5627 /// deciding which memory format to use.
5628 /// @param attr Primitive attributes to use. Attributes are optional
5629 /// and default to empty attributes.
5630 /// @param allow_empty A flag signifying whether construction is
5631 /// allowed to fail without throwing an exception. In this case an
5632 /// empty object will be produced. This flag is optional and
5633 /// defaults to false.
5634 primitive_desc(const engine &aengine, algorithm aalgorithm,
5635 const memory::desc &diff_src_desc,
5636 const memory::desc &weights_desc,
5637 const memory::desc &diff_dst_desc, const memory::dims &strides,
5638 const memory::dims &padding_l, const memory::dims &padding_r,
5639 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5640 const primitive_attr &attr = default_attr(),
5641 bool allow_empty = false)
5642 : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
5643 diff_dst_desc, strides, nullptr, padding_l, padding_r,
5644 hint_fwd_pd, attr, allow_empty) {}
5645
5646 /// Constructs a primitive descriptor for a deconvolution backward
5647 /// propagation primitive.
5648 ///
5649 /// @note
5650 /// All the memory descriptors may be initialized with the
5651 /// #dnnl::memory::format_tag::any value of @p format_tag.
5652 ///
5653 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5654 /// contain values for spatial dimensions only and hence must have the
5655 /// same number of elements as there are spatial dimensions. The order
5656 /// of values is the same as in the tensor: depth (for 3D tensors),
5657 /// height (for 3D and 2D tensors), and width.
5658 ///
5659 /// @param aengine Engine to use.
5660 /// @param aalgorithm Deconvolution algorithm
5661 /// (#dnnl::algorithm::convolution_direct,
5662 /// #dnnl::algorithm::convolution_winograd).
5663 /// @param diff_src_desc Diff source memory descriptor.
5664 /// @param weights_desc Weights memory descriptor.
5665 /// @param diff_dst_desc Diff destination memory descriptor.
5666 /// @param strides Strides for each spatial dimension.
5667 /// @param dilates Dilations for each spatial dimension. A zero value
5668 /// means no dilation in the corresponding dimension.
5669 /// @param padding_l Vector of padding values for low indices for each
5670 /// spatial dimension `([[front,] top,] left)`.
5671 /// @param padding_r Vector of padding values for high indices for
5672 /// each spatial dimension `([[back,] bottom,] right)`.
5673 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5674 /// forward propagation primitive. It is used as a hint for
5675 /// deciding which memory format to use.
5676 /// @param attr Primitive attributes to use. Attributes are optional
5677 /// and default to empty attributes.
5678 /// @param allow_empty A flag signifying whether construction is
5679 /// allowed to fail without throwing an exception. In this case an
5680 /// empty object will be produced. This flag is optional and
5681 /// defaults to false.
5682 primitive_desc(const engine &aengine, algorithm aalgorithm,
5683 const memory::desc &diff_src_desc,
5684 const memory::desc &weights_desc,
5685 const memory::desc &diff_dst_desc, const memory::dims &strides,
5686 const memory::dims &dilates, const memory::dims &padding_l,
5687 const memory::dims &padding_r,
5688 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5689 const primitive_attr &attr = default_attr(),
5690 bool allow_empty = false)
5691 : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
5692 diff_dst_desc, strides, &dilates, padding_l, padding_r,
5693 hint_fwd_pd, attr, allow_empty) {}
5694
5695 /// Constructs a primitive descriptor for a deconvolution backward
5696 /// propagation primitive from a C API primitive descriptor that must
5697 /// have a matching kind.
5698 ///
5699 /// @param pd C API primitive descriptor for a deconvolution backward
5700 /// propagation primitive.
5701 primitive_desc(dnnl_primitive_desc_t pd)
5702 : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5703 dnnl::prop_kind::backward_data) {}
5704
5705 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
5706 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
5707
5708 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
5709 memory::desc weights_desc() const { return base::weights_desc(0); }
5710
5711 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
5712 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
5713
5714 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
5715 algorithm get_algorithm() const { return base::get_algorithm(); }
5716
5717 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
5718 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
5719
5720 /// @copydoc dnnl::primitive_desc_base::get_strides()const
5721 memory::dims get_strides() const { return base::get_strides(); }
5722
5723 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
5724 memory::dims get_dilations() const { return base::get_dilations(); }
5725
5726 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
5727 memory::dims get_padding_l() const { return base::get_padding_l(); }
5728
5729 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
5730 memory::dims get_padding_r() const { return base::get_padding_r(); }
5731
5732 private:
5733 primitive_desc(const engine &aengine, algorithm aalgorithm,
5734 const memory::desc &diff_src_desc,
5735 const memory::desc &weights_desc,
5736 const memory::desc &diff_dst_desc, const memory::dims &strides,
5737 const memory::dims *dilates, const memory::dims &padding_l,
5738 const memory::dims &padding_r,
5739 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5740 const primitive_attr &attr, bool allow_empty) {
5741
5742 memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
5743 memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
5744 memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
5745
5746 if (dilates)
5747 memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
5748
5749 dnnl_primitive_desc_t pd = nullptr;
5750 dnnl_status_t status
5751 = dnnl_deconvolution_backward_data_primitive_desc_create(
5752 &pd, aengine.get(), convert_to_c(aalgorithm),
5753 diff_src_desc.get(), weights_desc.get(),
5754 diff_dst_desc.get(), &strides[0],
5755 optional_arg(dilates), &padding_l[0], &padding_r[0],
5756 hint_fwd_pd.get(), attr.get());
5757 if (!allow_empty)
5758 error::wrap_c_api(status,
5759 "could not create a primitive descriptor for a "
5760 "deconvolution backward propagation primitive");
5761 reset(pd);
5762 }
5763 };
5764
5765 /// Default constructor. Produces an empty object.
5766 deconvolution_backward_data() = default;
5767
5768 /// Constructs a deconvolution backward propagation primitive.
5769 /// @param pd Primitive descriptor for a deconvolution backward propagation
5770 /// primitive.
5771 deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
5772
5773 /// Constructs a deconvolution backward propagation primitive from a cache
5774 /// blob.
5775 /// @param pd Primitive descriptor for a deconvolution backward propagation
5776 /// primitive.
5777 /// @param cache_blob Cache blob.
5778 deconvolution_backward_data(
5779 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
5780 : primitive(pd, cache_blob) {}
5781};
5782
5783/// Deconvolution weights gradient primitive.
5784struct deconvolution_backward_weights : public primitive {
5785 /// Primitive descriptor for a deconvolution weights gradient primitive.
5786 struct primitive_desc : public dnnl::primitive_desc {
5787 /// Default constructor. Produces an empty object.
5788 primitive_desc() = default;
5789
5790 /// Constructs a primitive descriptor for a deconvolution weights
5791 /// gradient primitive with bias.
5792 ///
5793 /// @note
5794 /// All the memory descriptors may be initialized with the
5795 /// #dnnl::memory::format_tag::any value of @p format_tag.
5796 ///
5797 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5798 /// for spatial dimensions only and hence must have the same number of
5799 /// elements as there are spatial dimensions. The order of values is
5800 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5801 /// and 2D tensors), and width.
5802 ///
5803 /// @param aengine Engine to use.
5804 /// @param aalgorithm Deconvolution algorithm. Possible values are
5805 /// #dnnl::algorithm::deconvolution_direct, and
5806 /// #dnnl::algorithm::deconvolution_winograd.
5807 /// @param src_desc Source memory descriptor.
5808 /// @param diff_weights_desc Diff weights memory descriptor.
5809 /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
5810 /// memory descriptor disables the bias term.
5811 /// @param diff_dst_desc Diff destination memory descriptor.
5812 /// @param strides Strides for each spatial dimension.
5813 /// @param padding_l Vector of padding values for low indices for each
5814 /// spatial dimension `([[front,] top,] left)`.
5815 /// @param padding_r Vector of padding values for high indices for
5816 /// each spatial dimension `([[back,] bottom,] right)`.
5817 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5818 /// forward propagation primitive. It is used as a hint for
5819 /// deciding which memory format to use.
5820 /// @param attr Primitive attributes to use. Attributes are optional
5821 /// and default to empty attributes.
5822 /// @param allow_empty A flag signifying whether construction is
5823 /// allowed to fail without throwing an exception. In this case an
5824 /// empty object will be produced. This flag is optional and
5825 /// defaults to false.
5826 primitive_desc(const engine &aengine, algorithm aalgorithm,
5827 const memory::desc &src_desc,
5828 const memory::desc &diff_weights_desc,
5829 const memory::desc &diff_bias_desc,
5830 const memory::desc &diff_dst_desc, const memory::dims &strides,
5831 const memory::dims &padding_l, const memory::dims &padding_r,
5832 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5833 const primitive_attr &attr = default_attr(),
5834 bool allow_empty = false)
5835 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5836 &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
5837 padding_r, hint_fwd_pd, attr, allow_empty) {}
5838
5839 /// Constructs a primitive descriptor for a deconvolution weights
5840 /// gradient primitive without bias.
5841 ///
5842 /// @note
5843 /// All the memory descriptors may be initialized with the
5844 /// #dnnl::memory::format_tag::any value of @p format_tag.
5845 ///
5846 /// Arrays @p strides, @p padding_l, and @p padding_r contain values
5847 /// for spatial dimensions only and hence must have the same number of
5848 /// elements as there are spatial dimensions. The order of values is
5849 /// the same as in the tensor: depth (for 3D tensors), height (for 3D
5850 /// and 2D tensors), and width.
5851 ///
5852 /// @param aengine Engine to use.
5853 /// @param aalgorithm Deconvolution algorithm. Possible values are
5854 /// #dnnl::algorithm::deconvolution_direct, and
5855 /// #dnnl::algorithm::deconvolution_winograd.
5856 /// @param src_desc Source memory descriptor.
5857 /// @param diff_weights_desc Diff weights memory descriptor.
5858 /// @param diff_dst_desc Diff destination memory descriptor.
5859 /// @param strides Strides for each spatial dimension.
5860 /// @param padding_l Vector of padding values for low indices for each
5861 /// spatial dimension `([[front,] top,] left)`.
5862 /// @param padding_r Vector of padding values for high indices for
5863 /// each spatial dimension `([[back,] bottom,] right)`.
5864 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5865 /// forward propagation primitive. It is used as a hint for
5866 /// deciding which memory format to use.
5867 /// @param attr Primitive attributes to use. Attributes are optional
5868 /// and default to empty attributes.
5869 /// @param allow_empty A flag signifying whether construction is
5870 /// allowed to fail without throwing an exception. In this case an
5871 /// empty object will be produced. This flag is optional and
5872 /// defaults to false.
5873 primitive_desc(const engine &aengine, algorithm aalgorithm,
5874 const memory::desc &src_desc,
5875 const memory::desc &diff_weights_desc,
5876 const memory::desc &diff_dst_desc, const memory::dims &strides,
5877 const memory::dims &padding_l, const memory::dims &padding_r,
5878 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5879 const primitive_attr &attr = default_attr(),
5880 bool allow_empty = false)
5881 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5882 nullptr, diff_dst_desc, strides, nullptr, padding_l,
5883 padding_r, hint_fwd_pd, attr, allow_empty) {}
5884
5885 /// Constructs a primitive descriptor for a deconvolution weights
5886 /// gradient primitive with bias.
5887 ///
5888 /// @note
5889 /// All the memory descriptors may be initialized with the
5890 /// #dnnl::memory::format_tag::any value of @p format_tag.
5891 ///
5892 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5893 /// contain values for spatial dimensions only and hence must have the
5894 /// same number of elements as there are spatial dimensions. The order
5895 /// of values is the same as in the tensor: depth (for 3D tensors),
5896 /// height (for 3D and 2D tensors), and width.
5897 ///
5898 /// @param aengine Engine to use.
5899 /// @param aalgorithm Deconvolution algorithm. Possible values are
5900 /// #dnnl::algorithm::deconvolution_direct, and
5901 /// #dnnl::algorithm::deconvolution_winograd.
5902 /// @param src_desc Source memory descriptor.
5903 /// @param diff_weights_desc Diff weights memory descriptor.
5904 /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
5905 /// memory descriptor disables the bias term.
5906 /// @param diff_dst_desc Diff destination memory descriptor.
5907 /// @param strides Strides for each spatial dimension.
5908 /// @param dilates Dilations for each spatial dimension. A zero value
5909 /// means no dilation in the corresponding dimension.
5910 /// @param padding_l Vector of padding values for low indices for each
5911 /// spatial dimension `([[front,] top,] left)`.
5912 /// @param padding_r Vector of padding values for high indices for
5913 /// each spatial dimension `([[back,] bottom,] right)`.
5914 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5915 /// forward propagation primitive. It is used as a hint for
5916 /// deciding which memory format to use.
5917 /// @param attr Primitive attributes to use. Attributes are optional
5918 /// and default to empty attributes.
5919 /// @param allow_empty A flag signifying whether construction is
5920 /// allowed to fail without throwing an exception. In this case an
5921 /// empty object will be produced. This flag is optional and
5922 /// defaults to false.
5923 primitive_desc(const engine &aengine, algorithm aalgorithm,
5924 const memory::desc &src_desc,
5925 const memory::desc &diff_weights_desc,
5926 const memory::desc &diff_bias_desc,
5927 const memory::desc &diff_dst_desc, const memory::dims &strides,
5928 const memory::dims &dilates, const memory::dims &padding_l,
5929 const memory::dims &padding_r,
5930 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5931 const primitive_attr &attr = default_attr(),
5932 bool allow_empty = false)
5933 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5934 &diff_bias_desc, diff_dst_desc, strides, &dilates,
5935 padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
5936
5937 /// Constructs a primitive descriptor for a deconvolution weights
5938 /// gradient primitive without bias.
5939 ///
5940 /// @note
5941 /// All the memory descriptors may be initialized with the
5942 /// #dnnl::memory::format_tag::any value of @p format_tag.
5943 ///
5944 /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
5945 /// contain values for spatial dimensions only and hence must have the
5946 /// same number of elements as there are spatial dimensions. The order
5947 /// of values is the same as in the tensor: depth (for 3D tensors),
5948 /// height (for 3D and 2D tensors), and width.
5949 ///
5950 /// @param aengine Engine to use.
5951 /// @param aalgorithm Deconvolution algorithm. Possible values are
5952 /// #dnnl::algorithm::deconvolution_direct, and
5953 /// #dnnl::algorithm::deconvolution_winograd.
5954 /// @param src_desc Source memory descriptor.
5955 /// @param diff_weights_desc Diff weights memory descriptor.
5956 /// @param diff_dst_desc Diff destination memory descriptor.
5957 /// @param strides Strides for each spatial dimension.
5958 /// @param dilates Dilations for each spatial dimension. A zero value
5959 /// means no dilation in the corresponding dimension.
5960 /// @param padding_l Vector of padding values for low indices for each
5961 /// spatial dimension `([[front,] top,] left)`.
5962 /// @param padding_r Vector of padding values for high indices for
5963 /// each spatial dimension `([[back,] bottom,] right)`.
5964 /// @param hint_fwd_pd Primitive descriptor for a deconvolution
5965 /// forward propagation primitive. It is used as a hint for
5966 /// deciding which memory format to use.
5967 /// @param attr Primitive attributes to use. Attributes are optional
5968 /// and default to empty attributes.
5969 /// @param allow_empty A flag signifying whether construction is
5970 /// allowed to fail without throwing an exception. In this case an
5971 /// empty object will be produced. This flag is optional and
5972 /// defaults to false.
5973 primitive_desc(const engine &aengine, algorithm aalgorithm,
5974 const memory::desc &src_desc,
5975 const memory::desc &diff_weights_desc,
5976 const memory::desc &diff_dst_desc, const memory::dims &strides,
5977 const memory::dims &dilates, const memory::dims &padding_l,
5978 const memory::dims &padding_r,
5979 const deconvolution_forward::primitive_desc &hint_fwd_pd,
5980 const primitive_attr &attr = default_attr(),
5981 bool allow_empty = false)
5982 : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
5983 nullptr, diff_dst_desc, strides, &dilates, padding_l,
5984 padding_r, hint_fwd_pd, attr, allow_empty) {}
5985
5986 /// Constructs a primitive descriptor for a deconvolution weights
5987 /// gradient primitive from a C API primitive descriptor that must
5988 /// have a matching kind.
5989 ///
5990 /// @param pd C API primitive descriptor for a deconvolution weights
5991 /// gradient primitive.
5992 primitive_desc(dnnl_primitive_desc_t pd)
5993 : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5994 dnnl::prop_kind::backward_weights) {}
5995
5996 /// @copydoc dnnl::primitive_desc_base::src_desc()const
5997 memory::desc src_desc() const { return base::src_desc(0); }
5998
5999 /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
6000 memory::desc diff_weights_desc() const {
6001 return base::diff_weights_desc(0);
6002 }
6003
6004 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
6005 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
6006
6007 /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
6008 memory::desc diff_bias_desc() const {
6009 return base::diff_weights_desc(1);
6010 }
6011
6012 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6013 algorithm get_algorithm() const { return base::get_algorithm(); }
6014
6015 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6016 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6017
6018 /// @copydoc dnnl::primitive_desc_base::get_strides()const
6019 memory::dims get_strides() const { return base::get_strides(); }
6020
6021 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
6022 memory::dims get_dilations() const { return base::get_dilations(); }
6023
6024 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
6025 memory::dims get_padding_l() const { return base::get_padding_l(); }
6026
6027 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
6028 memory::dims get_padding_r() const { return base::get_padding_r(); }
6029
6030 private:
6031 primitive_desc(const engine &aengine, algorithm aalgorithm,
6032 const memory::desc &src_desc,
6033 const memory::desc &diff_weights_desc,
6034 const memory::desc *diff_bias_desc,
6035 const memory::desc &diff_dst_desc, const memory::dims &strides,
6036 const memory::dims *dilates, const memory::dims &padding_l,
6037 const memory::dims &padding_r,
6038 const deconvolution_forward::primitive_desc &hint_fwd_pd,
6039 const primitive_attr &attr, bool allow_empty) {
6040
6041 memory::validate_dims(strides, src_desc.get_ndims() - 2);
6042 memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
6043 memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
6044
6045 if (dilates)
6046 memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
6047
6048 dnnl_primitive_desc_t pd = nullptr;
6049 dnnl_status_t status
6050 = dnnl_deconvolution_backward_weights_primitive_desc_create(
6051 &pd, aengine.get(), convert_to_c(aalgorithm),
6052 src_desc.get(), diff_weights_desc.get(),
6053 optional_arg(diff_bias_desc), diff_dst_desc.get(),
6054 &strides[0], optional_arg(dilates), &padding_l[0],
6055 &padding_r[0], hint_fwd_pd.get(), attr.get());
6056 if (!allow_empty)
6057 error::wrap_c_api(status,
6058 "could not create a primitive descriptor for a "
6059 "deconvolution weights update primitive");
6060 reset(pd);
6061 }
6062 };
6063
6064 /// Default constructor. Produces an empty object.
6065 deconvolution_backward_weights() = default;
6066
6067 /// Constructs a deconvolution weights gradient primitive.
6068 /// @param pd Primitive descriptor for a deconvolution weights gradient
6069 /// primitive.
6070 deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
6071
6072 /// Constructs a deconvolution weights gradient primitive from a cache
6073 /// blob.
6074 /// @param pd Primitive descriptor for a deconvolution weights gradient
6075 /// primitive.
6076 /// @param cache_blob Cache blob.
6077 deconvolution_backward_weights(
6078 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6079 : primitive(pd, cache_blob) {}
6080};
6081
6082/// @} dnnl_api_deconvolution
6083
6084/// @addtogroup dnnl_api_lrn LRN
6085///
6086/// A primitive to perform local response normalization (LRN) across or within
6087/// channels.
6088///
6089/// @sa @ref dev_guide_lrn in developer guide
6090///
6091/// @{
6092
6093/// Local response normalization (LRN) forward propagation primitive.
6094struct lrn_forward : public primitive {
6095 /// Primitive descriptor for an LRN forward propagation primitive.
6096 struct primitive_desc : public dnnl::primitive_desc {
6097 /// Default constructor. Produces an empty object.
6098 primitive_desc() = default;
6099
6100 /// Constructs a primitive descriptor for an LRN forward propagation
6101 /// primitive.
6102 ///
6103 /// @param aengine Engine to use.
6104 /// @param aprop_kind Propagation kind. Possible values are
6105 /// #dnnl::prop_kind::forward_training, and
6106 /// #dnnl::prop_kind::forward_inference.
6107 /// @param aalgorithm LRN algorithm kind: either
6108 /// #dnnl::algorithm::lrn_across_channels, or
6109 /// #dnnl::algorithm::lrn_within_channel.
6110 /// @param src_desc Source memory descriptor.
6111 /// @param dst_desc Destination memory descriptor.
6112 /// @param local_size Regularization local size.
6113 /// @param alpha The alpha regularization parameter.
6114 /// @param beta The beta regularization parameter.
6115 /// @param k The k regularization parameter.
6116 /// @param attr Primitive attributes to use. Attributes are optional
6117 /// and default to empty attributes.
6118 /// @param allow_empty A flag signifying whether construction is
6119 /// allowed to fail without throwing an exception. In this case an
6120 /// empty object will be produced. This flag is optional and
6121 /// defaults to false.
6122 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6123 algorithm aalgorithm, const memory::desc &src_desc,
6124 const memory::desc &dst_desc, memory::dim local_size,
6125 float alpha, float beta, float k,
6126 const primitive_attr &attr = default_attr(),
6127 bool allow_empty = false) {
6128
6129 dnnl_primitive_desc_t pd = nullptr;
6130 dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd,
6131 aengine.get(), dnnl::convert_to_c(aprop_kind),
6132 convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
6133 local_size, alpha, beta, k, attr.get());
6134
6135 if (!allow_empty)
6136 error::wrap_c_api(status,
6137 "could not create a primitive descriptor for a lrn "
6138 "forward propagation primitive");
6139 reset(pd);
6140 }
6141
6142 /// Constructs a primitive descriptor for an LRN forward propagation
6143 /// primitive from a C API primitive descriptor that must have a
6144 /// matching kind.
6145 ///
6146 /// @param pd C API primitive descriptor for an LRN forward
6147 /// propagation primitive.
6148 primitive_desc(dnnl_primitive_desc_t pd)
6149 : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
6150 dnnl::prop_kind::forward_training,
6151 dnnl::prop_kind::forward_inference) {}
6152
6153 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6154 memory::desc src_desc() const { return base::src_desc(0); }
6155
6156 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6157 memory::desc dst_desc() const { return base::dst_desc(0); }
6158
6159 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
6160 memory::desc workspace_desc() const { return base::workspace_desc(); }
6161
6162 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6163 algorithm get_algorithm() const { return base::get_algorithm(); }
6164
6165 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6166 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6167
6168 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
6169 float get_alpha() const { return base::get_alpha(); }
6170
6171 /// @copydoc dnnl::primitive_desc_base::get_beta()const
6172 float get_beta() const { return base::get_beta(); }
6173
6174 /// @copydoc dnnl::primitive_desc_base::get_local_size()const
6175 memory::dim get_local_size() const { return base::get_local_size(); }
6176
6177 /// @copydoc dnnl::primitive_desc_base::get_k()const
6178 float get_k() const { return base::get_k(); }
6179 };
6180
6181 /// Default constructor. Produces an empty object.
6182 lrn_forward() = default;
6183
6184 /// Constructs an LRN forward propagation primitive.
6185 /// @param pd Primitive descriptor for an LRN forward propagation
6186 /// primitive.
6187 lrn_forward(const primitive_desc &pd) : primitive(pd) {}
6188
6189 /// Constructs an LRN forward propagation primitive from a cache blob.
6190 /// @param pd Primitive descriptor for an LRN forward propagation
6191 /// primitive.
6192 /// @param cache_blob Cache blob.
6193 lrn_forward(
6194 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6195 : primitive(pd, cache_blob) {}
6196};
6197
6198/// Local response normalization (LRN) backward propagation primitive.
6199struct lrn_backward : public primitive {
6200 /// Primitive descriptor for an LRN backward propagation primitive.
6201 struct primitive_desc : public dnnl::primitive_desc {
6202 /// Default constructor. Produces an empty object.
6203 primitive_desc() = default;
6204
6205 /// Constructs a primitive descriptor for an LRN backward propagation
6206 /// primitive.
6207 ///
6208 /// @param aengine Engine to use.
6209 /// @param aalgorithm LRN algorithm kind: either
6210 /// #dnnl::algorithm::lrn_across_channels, or
6211 /// #dnnl::algorithm::lrn_within_channel.
6212 /// @param diff_src_desc Diff source memory descriptor.
6213 /// @param diff_dst_desc Diff destination memory descriptor.
6214 /// @param src_desc Source memory descriptor.
6215 /// @param local_size Regularization local size.
6216 /// @param alpha The alpha regularization parameter.
6217 /// @param beta The beta regularization parameter.
6218 /// @param k The k regularization parameter.
6219 /// @param hint_fwd_pd Primitive descriptor for an LRN forward
6220 /// propagation primitive. It is used as a hint for deciding which
6221 /// memory format to use.
6222 /// @param attr Primitive attributes to use. Attributes are optional
6223 /// and default to empty attributes.
6224 /// @param allow_empty A flag signifying whether construction is
6225 /// allowed to fail without throwing an exception. In this case an
6226 /// empty object will be produced. This flag is optional and
6227 /// defaults to false.
6228 primitive_desc(const engine &aengine, algorithm aalgorithm,
6229 const memory::desc &diff_src_desc,
6230 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
6231 memory::dim local_size, float alpha, float beta, float k,
6232 const lrn_forward::primitive_desc &hint_fwd_pd,
6233 const primitive_attr &attr = default_attr(),
6234 bool allow_empty = false) {
6235
6236 dnnl_primitive_desc_t pd = nullptr;
6237 dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd,
6238 aengine.get(), convert_to_c(aalgorithm),
6239 diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(),
6240 local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get());
6241
6242 if (!allow_empty)
6243 error::wrap_c_api(status,
6244 "could not create a primitive descriptor for a lrn "
6245 "backward propagation primitive");
6246 reset(pd);
6247 }
6248
6249 /// Constructs a primitive descriptor for an LRN backward propagation
6250 /// primitive from a C API primitive descriptor that must have a
6251 /// matching kind.
6252 ///
6253 /// @param pd C API primitive descriptor for an LRN backward
6254 /// propagation primitive.
6255 primitive_desc(dnnl_primitive_desc_t pd)
6256 : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
6257 dnnl::prop_kind::backward_data) {}
6258
6259 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6260 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
6261
6262 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
6263 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
6264
6265 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
6266 memory::desc workspace_desc() const { return base::workspace_desc(); }
6267
6268 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6269 algorithm get_algorithm() const { return base::get_algorithm(); }
6270
6271 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6272 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6273
6274 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
6275 float get_alpha() const { return base::get_alpha(); }
6276
6277 /// @copydoc dnnl::primitive_desc_base::get_beta()const
6278 float get_beta() const { return base::get_beta(); }
6279
6280 /// @copydoc dnnl::primitive_desc_base::get_local_size()const
6281 memory::dim get_local_size() const { return base::get_local_size(); }
6282
6283 /// @copydoc dnnl::primitive_desc_base::get_k()const
6284 float get_k() const { return base::get_k(); }
6285 };
6286
6287 /// Default constructor. Produces an empty object.
6288 lrn_backward() = default;
6289
6290 /// Constructs an LRN backward propagation primitive.
6291 /// @param pd Primitive descriptor for an LRN backward propagation
6292 /// primitive.
6293 lrn_backward(const primitive_desc &pd) : primitive(pd) {}
6294
6295 /// Constructs an LRN backward propagation primitive from a cache blob.
6296 /// @param pd Primitive descriptor for an LRN backward propagation
6297 /// primitive.
6298 /// @param cache_blob Cache blob.
6299 lrn_backward(
6300 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6301 : primitive(pd, cache_blob) {}
6302};
6303
6304/// @} dnnl_api_lrn
6305
6306/// @addtogroup dnnl_api_eltwise Eltwise
6307///
/// A primitive to perform elementwise operations such as the
/// rectified linear unit (ReLU).
6310///
6311/// Both forward and backward propagation primitives support in-place
6312/// operation; that is, src and dst can refer to the same memory for forward
6313/// propagation, and diff_dst and diff_src can refer to the same memory for
6314/// backward propagation.
6315///
6316/// @warning
6317/// Because the original source data is required for backward propagation,
6318/// in-place forward propagation is not generally supported in the
6319/// training mode. However, for algorithms supporting destination as input
6320/// memory, dst can be used for the backward propagation, which makes it
6321/// possible to get performance benefit even in the training mode.
6322///
6323/// @sa @ref dev_guide_eltwise in developer guide
6324///
6325/// @{
6326
6327/// Elementwise unary operation forward propagation primitive.
6328struct eltwise_forward : public primitive {
6329 /// Primitive descriptor for an elementwise forward propagation primitive.
6330 struct primitive_desc : public dnnl::primitive_desc {
6331 /// Default constructor. Produces an empty object.
6332 primitive_desc() = default;
6333
6334 /// Constructs a primitive descriptor for an elementwise forward
6335 /// propagation primitive.
6336 ///
6337 /// @param aengine Engine to use.
6338 /// @param aprop_kind Propagation kind. Possible values are
6339 /// #dnnl::prop_kind::forward_training, and
6340 /// #dnnl::prop_kind::forward_inference.
6341 /// @param aalgorithm Elementwise algorithm kind.
6342 /// @param src_desc Source memory descriptor.
6343 /// @param dst_desc Destination memory descriptor.
6344 /// @param attr Primitive attributes to use. Attributes are optional
6345 /// and default to empty attributes.
6346 /// @param allow_empty A flag signifying whether construction is
6347 /// allowed to fail without throwing an exception. In this case an
6348 /// empty object will be produced. This flag is optional and
6349 /// defaults to false.
6350 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6351 algorithm aalgorithm, const memory::desc &src_desc,
6352 const memory::desc &dst_desc,
6353 const primitive_attr &attr = default_attr(),
6354 bool allow_empty = false)
6355 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
6356 dst_desc, nullptr, nullptr, attr, allow_empty) {}
6357
6358 /// Constructs a primitive descriptor for an elementwise forward
6359 /// propagation primitive with an alpha parameter.
6360 ///
6361 /// @param aengine Engine to use.
6362 /// @param aprop_kind Propagation kind. Possible values are
6363 /// #dnnl::prop_kind::forward_training, and
6364 /// #dnnl::prop_kind::forward_inference.
6365 /// @param aalgorithm Elementwise algorithm kind.
6366 /// @param src_desc Source memory descriptor.
6367 /// @param dst_desc Destination memory descriptor.
6368 /// @param alpha The alpha parameter for the elementwise operation.
6369 /// Specific meaning depends on the algorithm.
6370 /// @param attr Primitive attributes to use. Attributes are optional
6371 /// and default to empty attributes.
6372 /// @param allow_empty A flag signifying whether construction is
6373 /// allowed to fail without throwing an exception. In this case an
6374 /// empty object will be produced. This flag is optional and
6375 /// defaults to false.
6376 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6377 algorithm aalgorithm, const memory::desc &src_desc,
6378 const memory::desc &dst_desc, float alpha,
6379 const primitive_attr &attr = default_attr(),
6380 bool allow_empty = false)
6381 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
6382 dst_desc, &alpha, nullptr, attr, allow_empty) {}
6383
6384 /// Constructs a primitive descriptor for an elementwise forward
6385 /// propagation primitive with an alpha and beta parameters.
6386 ///
6387 /// @param aengine Engine to use.
6388 /// @param aprop_kind Propagation kind. Possible values are
6389 /// #dnnl::prop_kind::forward_training, and
6390 /// #dnnl::prop_kind::forward_inference.
6391 /// @param aalgorithm Elementwise algorithm kind.
6392 /// @param src_desc Source memory descriptor.
6393 /// @param dst_desc Destination memory descriptor.
6394 /// @param alpha The alpha parameter for the elementwise operation.
6395 /// Specific meaning depends on the algorithm.
6396 /// @param beta The beta parameter for the elementwise operation.
6397 /// Specific meaning depends on the algorithm.
6398 /// @param attr Primitive attributes to use. Attributes are optional
6399 /// and default to empty attributes.
6400 /// @param allow_empty A flag signifying whether construction is
6401 /// allowed to fail without throwing an exception. In this case an
6402 /// empty object will be produced. This flag is optional and
6403 /// defaults to false.
6404 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6405 algorithm aalgorithm, const memory::desc &src_desc,
6406 const memory::desc &dst_desc, float alpha, float beta,
6407 const primitive_attr &attr = default_attr(),
6408 bool allow_empty = false)
6409 : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
6410 dst_desc, &alpha, &beta, attr, allow_empty) {}
6411
6412 /// Constructs a primitive descriptor for an eltwise forward
6413 /// propagation primitive from a C API primitive descriptor that must
6414 /// have a matching kind.
6415 ///
6416 /// @param pd C API primitive descriptor for an eltwise forward
6417 /// propagation primitive.
6418 primitive_desc(dnnl_primitive_desc_t pd)
6419 : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6420 dnnl::prop_kind::forward_training,
6421 dnnl::prop_kind::forward_inference) {}
6422
6423 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6424 memory::desc src_desc() const { return base::src_desc(0); }
6425
6426 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6427 memory::desc dst_desc() const { return base::dst_desc(0); }
6428
6429 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6430 dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
6431
6432 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6433 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6434
6435 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
6436 float get_alpha() const { return base::get_alpha(); }
6437
6438 /// @copydoc dnnl::primitive_desc_base::get_beta()const
6439 float get_beta() const { return base::get_beta(); }
6440
6441 private:
6442 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6443 algorithm aalgorithm, const memory::desc &src_desc,
6444 const memory::desc &dst_desc, const float *alpha,
6445 const float *beta, const primitive_attr &attr,
6446 bool allow_empty) {
6447
6448 dnnl_primitive_desc_t pd = nullptr;
6449 dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create(
6450 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
6451 dnnl::convert_to_c(aalgorithm), src_desc.get(),
6452 dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
6453 attr.get());
6454
6455 if (!allow_empty)
6456 error::wrap_c_api(status,
6457 "could not create a primitive descriptor for an "
6458 "eltwise forward propagation primitive");
6459 reset(pd);
6460 }
6461 };
6462
6463 /// Default constructor. Produces an empty object.
6464 eltwise_forward() = default;
6465
6466 /// Constructs an eltwise forward propagation primitive.
6467 /// @param pd Primitive descriptor for an eltwise forward propagation
6468 /// primitive.
6469 eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
6470
6471 /// Constructs an eltwise forward propagation primitive from a cache blob.
6472 /// @param pd Primitive descriptor for an eltwise forward propagation
6473 /// primitive.
6474 /// @param cache_blob Cache blob.
6475 eltwise_forward(
6476 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6477 : primitive(pd, cache_blob) {}
6478};
6479
6480/// Elementwise unary operation backward propagation primitive.
6481struct eltwise_backward : public primitive {
6482 /// Primitive descriptor for eltwise backward propagation.
6483 struct primitive_desc : public dnnl::primitive_desc {
6484 /// Default constructor. Produces an empty object.
6485 primitive_desc() = default;
6486
6487 /// Constructs a primitive descriptor for an elementwise backward
6488 /// propagation primitive with an alpha parameter.
6489 ///
6490 /// @param aengine Engine to use.
6491 /// @param aalgorithm Elementwise algorithm kind.
6492 /// @param diff_src_desc Diff source memory descriptor.
6493 /// @param diff_dst_desc Diff destination memory descriptor.
6494 /// @param data_desc Destination memory descriptor if one of the
6495 /// "use_dst_for_bwd" algorithms are used (such as
6496 /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
6497 /// otherwise.
6498 /// @param hint_fwd_pd Primitive descriptor for an elementwise
6499 /// forward propagation primitive. It is used as a hint for
6500 /// deciding which memory format to use.
6501 /// @param attr Primitive attributes to use. Attributes are optional
6502 /// and default to empty attributes.
6503 /// @param allow_empty A flag signifying whether construction is
6504 /// allowed to fail without throwing an exception. In this case an
6505 /// empty object will be produced. This flag is optional and
6506 /// defaults to false.
6507 primitive_desc(const engine &aengine, algorithm aalgorithm,
6508 const memory::desc &diff_src_desc,
6509 const memory::desc &diff_dst_desc,
6510 const memory::desc &data_desc,
6511 const eltwise_forward::primitive_desc &hint_fwd_pd,
6512 const primitive_attr &attr = default_attr(),
6513 bool allow_empty = false)
6514 : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
6515 data_desc, nullptr, nullptr, hint_fwd_pd, attr,
6516 allow_empty) {}
6517
6518 /// Constructs a primitive descriptor for an elementwise backward
6519 /// propagation primitive with an alpha parameter.
6520 ///
6521 /// @param aengine Engine to use.
6522 /// @param aalgorithm Elementwise algorithm kind.
6523 /// @param diff_src_desc Diff source memory descriptor.
6524 /// @param diff_dst_desc Diff destination memory descriptor.
6525 /// @param data_desc Destination memory descriptor if one of the
6526 /// "use_dst_for_bwd" algorithms are used (such as
6527 /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
6528 /// otherwise.
6529 /// @param alpha The alpha parameter for the elementwise operation.
6530 /// Specific meaning depends on the algorithm.
6531 /// @param hint_fwd_pd Primitive descriptor for an elementwise
6532 /// forward propagation primitive. It is used as a hint for
6533 /// deciding which memory format to use.
6534 /// @param attr Primitive attributes to use. Attributes are optional
6535 /// and default to empty attributes.
6536 /// @param allow_empty A flag signifying whether construction is
6537 /// allowed to fail without throwing an exception. In this case an
6538 /// empty object will be produced. This flag is optional and
6539 /// defaults to false.
6540 primitive_desc(const engine &aengine, algorithm aalgorithm,
6541 const memory::desc &diff_src_desc,
6542 const memory::desc &diff_dst_desc,
6543 const memory::desc &data_desc, float alpha,
6544 const eltwise_forward::primitive_desc &hint_fwd_pd,
6545 const primitive_attr &attr = default_attr(),
6546 bool allow_empty = false)
6547 : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
6548 data_desc, &alpha, nullptr, hint_fwd_pd, attr,
6549 allow_empty) {}
6550
6551 /// Constructs a primitive descriptor for an elementwise backward
6552 /// propagation primitive with an alpha and beta parameters.
6553 ///
6554 /// @param aengine Engine to use.
6555 /// @param aalgorithm Elementwise algorithm kind.
6556 /// @param diff_src_desc Diff source memory descriptor.
6557 /// @param diff_dst_desc Diff destination memory descriptor.
6558 /// @param data_desc Destination memory descriptor if one of the
6559 /// "use_dst_for_bwd" algorithms are used (such as
6560 /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
6561 /// otherwise.
6562 /// @param alpha The alpha parameter for the elementwise operation.
6563 /// Specific meaning depends on the algorithm.
6564 /// @param beta The beta parameter for the elementwise operation.
6565 /// Specific meaning depends on the algorithm.
6566 /// @param hint_fwd_pd Primitive descriptor for an elementwise
6567 /// forward propagation primitive. It is used as a hint for
6568 /// deciding which memory format to use.
6569 /// @param attr Primitive attributes to use. Attributes are optional
6570 /// and default to empty attributes.
6571 /// @param allow_empty A flag signifying whether construction is
6572 /// allowed to fail without throwing an exception. In this case an
6573 /// empty object will be produced. This flag is optional and
6574 /// defaults to false.
6575 primitive_desc(const engine &aengine, algorithm aalgorithm,
6576 const memory::desc &diff_src_desc,
6577 const memory::desc &diff_dst_desc,
6578 const memory::desc &data_desc, float alpha, float beta,
6579 const eltwise_forward::primitive_desc &hint_fwd_pd,
6580 const primitive_attr &attr = default_attr(),
6581 bool allow_empty = false)
6582 : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
6583 data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {}
6584
6585 /// Constructs a primitive descriptor for an eltwise backward
6586 /// propagation primitive from a C API primitive descriptor that must
6587 /// have a matching kind.
6588 ///
6589 /// @param pd C API primitive descriptor for an eltwise backward
6590 /// propagation primitive.
6591 primitive_desc(dnnl_primitive_desc_t pd)
6592 : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6593 dnnl::prop_kind::backward_data) {}
6594
6595 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6596 memory::desc src_desc() const { return base::src_desc(0); }
6597
6598 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
6599 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
6600
6601 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
6602 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
6603
6604 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6605 dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
6606
6607 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6608 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6609
6610 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
6611 float get_alpha() const { return base::get_alpha(); }
6612
6613 /// @copydoc dnnl::primitive_desc_base::get_beta()const
6614 float get_beta() const { return base::get_beta(); }
6615
6616 private:
6617 primitive_desc(const engine &aengine, algorithm aalgorithm,
6618 const memory::desc &diff_src_desc,
6619 const memory::desc &diff_dst_desc,
6620 const memory::desc &data_desc, const float *alpha,
6621 const float *beta,
6622 const eltwise_forward::primitive_desc &hint_fwd_pd,
6623 const primitive_attr &attr, bool allow_empty) {
6624
6625 dnnl_primitive_desc_t pd = nullptr;
6626 dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create(
6627 &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
6628 diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(),
6629 alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
6630 hint_fwd_pd.get(), attr.get());
6631
6632 if (!allow_empty)
6633 error::wrap_c_api(status,
6634 "could not create a primitive descriptor for an "
6635 "eltwise backward propagation primitive");
6636 reset(pd);
6637 }
6638 };
6639
6640 /// Default constructor. Produces an empty object.
6641 eltwise_backward() = default;
6642
6643 /// Constructs an eltwise backward propagation primitive.
6644 /// @param pd Primitive descriptor for an eltwise backward propagation
6645 /// primitive.
6646 eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
6647
6648 /// Constructs an eltwise backward propagation primitive from a cache blob.
6649 /// @param pd Primitive descriptor for an eltwise backward propagation
6650 /// primitive.
6651 /// @param cache_blob Cache blob.
6652 eltwise_backward(
6653 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6654 : primitive(pd, cache_blob) {}
6655};
6656
6657/// @} dnnl_api_eltwise
6658
6659/// @addtogroup dnnl_api_softmax Softmax
6660///
6661/// A primitive to perform softmax.
6662///
6663/// @sa @ref dev_guide_softmax in developer guide
6664///
6665/// @{
6666
6667/// Softmax forward propagation primitive.
6668struct softmax_forward : public primitive {
6669 /// Primitive descriptor for a softmax forward propagation primitive.
6670 struct primitive_desc : public dnnl::primitive_desc {
6671 /// Default constructor. Produces an empty object.
6672 primitive_desc() = default;
6673
6674 /// Constructs a primitive descriptor for a softmax forward propagation
6675 /// primitive.
6676 ///
6677 /// @param aengine Engine to use.
6678 /// @param aprop_kind Propagation kind. Possible values are
6679 /// #dnnl::prop_kind::forward_training, and
6680 /// #dnnl::prop_kind::forward_inference.
6681 /// @param aalgorithm Softmax algorithm kind: either
6682 /// #dnnl::algorithm::softmax_accurate,
6683 /// or #dnnl::algorithm::softmax_log.
6684 /// @param src_desc Source memory descriptor.
6685 /// @param dst_desc Destination memory descriptor.
6686 /// @param axis Axis over which softmax is computed.
6687 /// @param attr Primitive attributes to use. Attributes are optional
6688 /// and default to empty attributes.
6689 /// @param allow_empty A flag signifying whether construction is
6690 /// allowed to fail without throwing an exception. In this case an
6691 /// empty object will be produced. This flag is optional and
6692 /// defaults to false.
6693 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6694 algorithm aalgorithm, const memory::desc &src_desc,
6695 const memory::desc &dst_desc, int axis,
6696 const primitive_attr &attr = default_attr(),
6697 bool allow_empty = false) {
6698
6699 dnnl_primitive_desc_t pd = nullptr;
6700 dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create(
6701 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
6702 dnnl::convert_to_c(aalgorithm), src_desc.get(),
6703 dst_desc.get(), axis, attr.get());
6704
6705 if (!allow_empty)
6706 error::wrap_c_api(status,
6707 "could not create a primitive descriptor for a softmax "
6708 "forward propagation primitive");
6709 reset(pd);
6710 }
6711
6712 /// Constructs a primitive descriptor for a softmax forward
6713 /// propagation primitive from a C API primitive descriptor that must
6714 /// have a matching kind.
6715 ///
6716 /// @param pd C API primitive descriptor for a softmax forward
6717 /// propagation primitive.
6718 primitive_desc(dnnl_primitive_desc_t pd)
6719 : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6720 dnnl::prop_kind::forward_training,
6721 dnnl::prop_kind::forward_inference) {}
6722
6723 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6724 memory::desc src_desc() const { return base::src_desc(0); }
6725
6726 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6727 memory::desc dst_desc() const { return base::dst_desc(0); }
6728
6729 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6730 dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
6731
6732 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6733 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6734
6735 /// @copydoc dnnl::primitive_desc_base::get_axis()const
6736 int get_axis() const { return base::get_axis(); }
6737 };
6738
6739 /// Default constructor. Produces an empty object.
6740 softmax_forward() = default;
6741
6742 /// Constructs a softmax forward propagation primitive.
6743 /// @param pd Primitive descriptor for a softmax forward propagation
6744 /// primitive.
6745 softmax_forward(const primitive_desc &pd) : primitive(pd) {}
6746
6747 /// Constructs a softmax forward propagation primitive from a cache blob.
6748 /// @param pd Primitive descriptor for a softmax forward propagation
6749 /// primitive.
6750 /// @param cache_blob Cache blob.
6751 softmax_forward(
6752 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6753 : primitive(pd, cache_blob) {}
6754};
6755
6756/// Softmax backward propagation primitive.
6757struct softmax_backward : public primitive {
6758 /// Primitive descriptor for a softmax backward propagation primitive.
6759 struct primitive_desc : public dnnl::primitive_desc {
6760 /// Default constructor. Produces an empty object.
6761 primitive_desc() = default;
6762
6763 /// Constructs a primitive descriptor for a softmax backward propagation
6764 /// primitive.
6765 ///
6766 /// @param aengine Engine to use.
6767 /// @param aalgorithm Softmax algorithm kind: either
6768 /// #dnnl::algorithm::softmax_accurate,
6769 /// or #dnnl::algorithm::softmax_log.
6770 /// @param diff_src_desc Diff source memory descriptor.
6771 /// @param diff_dst_desc Diff destination memory descriptor.
6772 /// @param dst_desc Destination memory descriptor.
6773 /// @param axis Axis over which softmax is computed.
6774 /// @param hint_fwd_pd Primitive descriptor for a softmax
6775 /// forward propagation primitive. It is used as a hint for
6776 /// deciding which memory format to use.
6777 /// @param attr Primitive attributes to use. Attributes are optional
6778 /// and default to empty attributes.
6779 /// @param allow_empty A flag signifying whether construction is
6780 /// allowed to fail without throwing an exception. In this case an
6781 /// empty object will be produced. This flag is optional and
6782 /// defaults to false.
6783 primitive_desc(const engine &aengine, algorithm aalgorithm,
6784 const memory::desc &diff_src_desc,
6785 const memory::desc &diff_dst_desc, const memory::desc &dst_desc,
6786 int axis, const softmax_forward::primitive_desc &hint_fwd_pd,
6787 const primitive_attr &attr = default_attr(),
6788 bool allow_empty = false) {
6789
6790 dnnl_primitive_desc_t pd = nullptr;
6791 dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create(
6792 &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
6793 diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(),
6794 axis, hint_fwd_pd.get(), attr.get());
6795
6796 if (!allow_empty)
6797 error::wrap_c_api(status,
6798 "could not create a primitive descriptor for a softmax "
6799 "backward propagation primitive");
6800 reset(pd);
6801 }
6802
6803 /// Constructs a primitive descriptor for a softmax backward
6804 /// propagation primitive from a C API primitive descriptor that must
6805 /// have a matching kind.
6806 ///
6807 /// @param pd C API primitive descriptor for a softmax backward
6808 /// propagation primitive.
6809 primitive_desc(dnnl_primitive_desc_t pd)
6810 : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6811 dnnl::prop_kind::backward_data) {}
6812
6813 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6814 memory::desc dst_desc() const { return base::dst_desc(0); }
6815
6816 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
6817 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
6818
6819 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6820 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
6821
6822 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
6823 dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
6824
6825 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6826 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6827
6828 /// @copydoc dnnl::primitive_desc_base::get_axis()const
6829 int get_axis() const { return base::get_axis(); }
6830 };
6831
6832 /// Default constructor. Produces an empty object.
6833 softmax_backward() = default;
6834
6835 /// Constructs a softmax backward propagation primitive.
6836 /// @param pd Primitive descriptor for a softmax backward propagation
6837 /// primitive.
6838 softmax_backward(const primitive_desc &pd) : primitive(pd) {}
6839
6840 /// Constructs a softmax backward propagation primitive from a cache blob.
6841 /// @param pd Primitive descriptor for a softmax backward propagation
6842 /// primitive.
6843 /// @param cache_blob Cache blob.
6844 softmax_backward(
6845 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6846 : primitive(pd, cache_blob) {}
6847};
6848
6849/// @} dnnl_api_softmax
6850
6851/// @addtogroup dnnl_api_batch_normalization Batch Normalization
6852///
6853/// A primitive to perform batch normalization.
6854///
6855/// Both forward and backward propagation primitives support in-place
6856/// operation; that is, src and dst can refer to the same memory for forward
6857/// propagation, and diff_dst and diff_src can refer to the same memory for
6858/// backward propagation.
6859///
/// The batch normalization primitives' computations can be controlled by
6861/// specifying different @ref dnnl::normalization_flags values. For example,
6862/// batch normalization forward propagation can be configured to either
6863/// compute the mean and variance or take them as arguments. It can either
6864/// perform scaling and shifting using gamma and beta parameters or not.
6865/// Optionally, it can also perform a fused ReLU, which in case of training
6866/// would also require a workspace.
6867///
6868/// @sa @ref dev_guide_batch_normalization in developer guide
6869///
6870/// @{
6871
6872/// Batch normalization forward propagation primitive.
6873struct batch_normalization_forward : public primitive {
6874 /// Primitive descriptor for a batch normalization forward propagation
6875 /// primitive.
6876 struct primitive_desc : public dnnl::primitive_desc {
6877 /// Default constructor. Produces an empty object.
6878 primitive_desc() = default;
6879
6880 /// Constructs a primitive descriptor for a batch normalization forward
6881 /// propagation primitive.
6882 ///
6883 /// @note
6884 /// In-place operation is supported: the dst can refer to the same
6885 /// memory as the src.
6886 ///
6887 /// @param aengine Engine to use.
6888 /// @param aprop_kind Propagation kind. Possible values are
6889 /// #dnnl::prop_kind::forward_training and
6890 /// #dnnl::prop_kind::forward_inference.
6891 /// @param src_desc Source memory descriptor.
6892 /// @param dst_desc Destination memory descriptor.
6893 /// @param epsilon Batch normalization epsilon parameter.
6894 /// @param flags Batch normalization flags (@ref
6895 /// dnnl::normalization_flags).
6896 /// @param attr Primitive attributes to use. Attributes are optional
6897 /// and default to empty attributes.
6898 /// @param allow_empty A flag signifying whether construction is
6899 /// allowed to fail without throwing an exception. In this case an
6900 /// empty object will be produced. This flag is optional and
6901 /// defaults to false.
6902 primitive_desc(const engine &aengine, prop_kind aprop_kind,
6903 const memory::desc &src_desc, const memory::desc &dst_desc,
6904 float epsilon, normalization_flags flags,
6905 const primitive_attr &attr = default_attr(),
6906 bool allow_empty = false) {
6907 dnnl_primitive_desc_t pd = nullptr;
6908 dnnl_status_t status
6909 = dnnl_batch_normalization_forward_primitive_desc_create(
6910 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
6911 src_desc.get(), dst_desc.get(), epsilon,
6912 convert_to_c(flags), attr.get());
6913
6914 if (!allow_empty)
6915 error::wrap_c_api(status,
6916 "could not create a primitive descriptor for a batch "
6917 "normalization forward propagation primitive");
6918 reset(pd);
6919 }
6920
6921 /// Constructs a primitive descriptor for a batch normalization
6922 /// forward propagation primitive from a C API primitive descriptor
6923 /// that must have a matching kind.
6924 ///
6925 /// @param pd C API primitive descriptor for a batch normalization
6926 /// forward propagation primitive.
6927 primitive_desc(dnnl_primitive_desc_t pd)
6928 : dnnl::primitive_desc(pd,
6929 dnnl::primitive::kind::batch_normalization,
6930 dnnl::prop_kind::forward_training,
6931 dnnl::prop_kind::forward_inference) {}
6932
6933 /// @copydoc dnnl::primitive_desc_base::src_desc()const
6934 memory::desc src_desc() const { return base::src_desc(0); }
6935
6936 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
6937 memory::desc dst_desc() const { return base::dst_desc(0); }
6938
6939 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
6940 memory::desc weights_desc() const { return base::weights_desc(0); }
6941
6942 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
6943 memory::desc workspace_desc() const { return base::workspace_desc(); }
6944
6945 /// Returns memory descriptor for mean.
6946 /// @returns Memory descriptor for mean.
6947 memory::desc mean_desc() const { return stat_desc(mean); }
6948
6949 /// Returns memory descriptor for variance.
6950 /// @returns Memory descriptor for variance.
6951 memory::desc variance_desc() const { return stat_desc(var); }
6952
6953 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
6954 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
6955
6956 /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
6957 float get_epsilon() const { return base::get_epsilon(); }
6958
6959 /// Returns normalization flags.
6960 /// @return Normalization flags.
6961 normalization_flags get_flags() const {
6962 return base::get_flags<normalization_flags>();
6963 }
6964
6965 private:
6966 enum {
6967 mean = 1,
6968 var = 2,
6969 };
6970 memory::desc stat_desc(int kind) const {
6971 const bool use_global_stats
6972 = (get_flags() & normalization_flags::use_global_stats)
6973 != normalization_flags::none;
6974 return query_md(
6975 use_global_stats ? query::src_md : query::dst_md, kind);
6976 }
6977 };
6978
6979 /// Default constructor. Produces an empty object.
6980 batch_normalization_forward() = default;
6981
6982 /// Constructs a batch normalization forward propagation primitive.
6983 /// @param pd Primitive descriptor for a batch normalization forward
6984 /// propagation primitive.
6985 batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
6986
6987 /// Constructs a batch normalization forward propagation primitive from
6988 /// a cache blob.
6989 /// @param pd Primitive descriptor for a batch normalization forward
6990 /// propagation primitive.
6991 /// @param cache_blob Cache blob.
6992 batch_normalization_forward(
6993 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
6994 : primitive(pd, cache_blob) {}
6995};
6996
6997/// Batch normalization backward propagation primitive.
6998struct batch_normalization_backward : public primitive {
6999 /// Primitive descriptor for a batch normalization backward propagation
7000 /// primitive.
7001 struct primitive_desc : public dnnl::primitive_desc {
7002 /// Default constructor. Produces an empty object.
7003 primitive_desc() = default;
7004
7005 /// Constructs a primitive descriptor for a batch normalization backward
7006 /// propagation primitive.
7007 ///
7008 /// @param aengine Engine to use.
7009 /// @param aprop_kind Propagation kind. Possible values are
7010 /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
7011 /// (diffs for all parameters are computed in this case).
7012 /// @param diff_src_desc Diff source memory descriptor.
7013 /// @param diff_dst_desc Diff destination memory descriptor.
7014 /// @param src_desc Source memory descriptor.
7015 /// @param epsilon Batch normalization epsilon parameter.
7016 /// @param flags Batch normalization flags (@ref
7017 /// dnnl::normalization_flags).
7018 /// @param hint_fwd_pd Primitive descriptor for a batch normalization
7019 /// forward propagation primitive. It is used as a hint for
7020 /// deciding which memory format to use.
7021 /// @param attr Primitive attributes to use. Attributes are optional
7022 /// and default to empty attributes.
7023 /// @param allow_empty A flag signifying whether construction is
7024 /// allowed to fail without throwing an exception. In this case an
7025 /// empty object will be produced. This flag is optional and
7026 /// defaults to false.
7027 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7028 const memory::desc &diff_src_desc,
7029 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
7030 float epsilon, normalization_flags flags,
7031 const batch_normalization_forward::primitive_desc &hint_fwd_pd,
7032 const primitive_attr &attr = default_attr(),
7033 bool allow_empty = false) {
7034 dnnl_primitive_desc_t pd = nullptr;
7035 dnnl_status_t status
7036 = dnnl_batch_normalization_backward_primitive_desc_create(
7037 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
7038 diff_src_desc.get(), diff_dst_desc.get(),
7039 src_desc.get(), epsilon, convert_to_c(flags),
7040 hint_fwd_pd.get(), attr.get());
7041
7042 if (!allow_empty)
7043 error::wrap_c_api(status,
7044 "could not create a primitive descriptor for a batch "
7045 "normalization backward propagation primitive");
7046 reset(pd);
7047 }
7048
7049 /// Constructs a primitive descriptor for a batch normalization
7050 /// backward propagation primitive from a C API primitive descriptor
7051 /// that must have a matching kind.
7052 ///
7053 /// @param pd C API primitive descriptor for a batch normalization
7054 /// backward propagation primitive.
7055 primitive_desc(dnnl_primitive_desc_t pd)
7056 : dnnl::primitive_desc(pd,
7057 dnnl::primitive::kind::batch_normalization,
7058 dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
7059 }
7060
7061 /// @copydoc dnnl::primitive_desc_base::src_desc()const
7062 memory::desc src_desc() const { return base::src_desc(0); }
7063
7064 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
7065 memory::desc weights_desc() const { return base::weights_desc(0); }
7066
7067 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
7068 memory::desc dst_desc() const { return base::dst_desc(0); }
7069
7070 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
7071 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
7072
7073 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
7074 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
7075
7076 /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
7077 memory::desc diff_weights_desc() const {
7078 return base::diff_weights_desc(0);
7079 }
7080
7081 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
7082 memory::desc mean_desc() const { return query_md(query::src_md, 1); }
7083
7084 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
7085 memory::desc variance_desc() const {
7086 return query_md(query::src_md, 2);
7087 }
7088
7089 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
7090 memory::desc workspace_desc() const { return base::workspace_desc(); }
7091
7092 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7093 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7094
7095 /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
7096 float get_epsilon() const { return base::get_epsilon(); }
7097
7098 /// Returns normalization flags.
7099 /// @return Normalization flags.
7100 normalization_flags get_flags() const {
7101 return base::get_flags<normalization_flags>();
7102 }
7103 };
7104
7105 /// Default constructor. Produces an empty object.
7106 batch_normalization_backward() = default;
7107
7108 /// Constructs a batch normalization backward propagation primitive.
7109 /// @param pd Primitive descriptor for a batch normalization backward
7110 /// propagation primitive.
7111 batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
7112
7113 /// Constructs a batch normalization backward propagation primitive from
7114 /// a cache blob.
7115 /// @param pd Primitive descriptor for a batch normalization backward
7116 /// propagation primitive.
7117 /// @param cache_blob Cache blob.
7118 batch_normalization_backward(
7119 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7120 : primitive(pd, cache_blob) {}
7121};
7122
7123/// @} dnnl_api_batch_normalization
7124
7125/// @addtogroup dnnl_api_layer_normalization Layer Normalization
7126///
/// A primitive to perform layer normalization. Normalization is performed
/// within the last logical dimension of the data tensor.
7129///
7130/// Both forward and backward propagation primitives support in-place
7131/// operation; that is, src and dst can refer to the same memory for forward
7132/// propagation, and diff_dst and diff_src can refer to the same memory for
7133/// backward propagation.
7134///
/// The computations of the layer normalization primitives can be controlled by
/// specifying different @ref dnnl::normalization_flags values. For example,
7137/// layer normalization forward propagation can be configured to either
7138/// compute the mean and variance or take them as arguments. It can either
7139/// perform scaling and shifting using gamma and beta parameters or not.
7140///
7141/// @sa @ref dev_guide_layer_normalization in developer guide
7142///
7143/// @{
7144
7145/// Layer normalization forward propagation primitive.
7146struct layer_normalization_forward : public primitive {
7147 /// Primitive descriptor for a layer normalization forward propagation
7148 /// primitive.
7149 struct primitive_desc : public dnnl::primitive_desc {
7150 /// Default constructor. Produces an empty object.
7151 primitive_desc() = default;
7152
7153 /// Constructs a primitive descriptor for a layer normalization forward
7154 /// propagation primitive.
7155 ///
7156 /// @param aengine Engine to use.
7157 /// @param aprop_kind Propagation kind. Possible values are
7158 /// #dnnl::prop_kind::forward_training, and
7159 /// #dnnl::prop_kind::forward_inference.
7160 /// @param src_desc Source memory descriptor.
7161 /// @param dst_desc Destination memory descriptor.
7162 /// @param stat_desc Statistics memory descriptors.
7163 /// @param epsilon Layer normalization epsilon parameter.
7164 /// @param flags Layer normalization flags (@ref
7165 /// dnnl::normalization_flags).
7166 /// @param attr Primitive attributes to use. Attributes are optional
7167 /// and default to empty attributes.
7168 /// @param allow_empty A flag signifying whether construction is
7169 /// allowed to fail without throwing an exception. In this case an
7170 /// empty object will be produced. This flag is optional and
7171 /// defaults to false.
7172 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7173 const memory::desc &src_desc, const memory::desc &dst_desc,
7174 const memory::desc &stat_desc, float epsilon,
7175 normalization_flags flags,
7176 const primitive_attr &attr = default_attr(),
7177 bool allow_empty = false)
7178 : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
7179 &stat_desc, epsilon, flags, attr, allow_empty) {}
7180
7181 /// Constructs a primitive descriptor for a layer normalization forward
7182 /// propagation primitive.
7183 ///
7184 /// @param aengine Engine to use.
7185 /// @param aprop_kind Propagation kind. Possible values are
7186 /// #dnnl::prop_kind::forward_training, and
7187 /// #dnnl::prop_kind::forward_inference.
7188 /// @param src_desc Source memory descriptor.
7189 /// @param dst_desc Destination memory descriptor.
7190 /// @param epsilon Layer normalization epsilon parameter.
7191 /// @param flags Layer normalization flags (@ref
7192 /// dnnl::normalization_flags).
7193 /// @param attr Primitive attributes to use. Attributes are optional
7194 /// and default to empty attributes.
7195 /// @param allow_empty A flag signifying whether construction is
7196 /// allowed to fail without throwing an exception. In this case an
7197 /// empty object will be produced. This flag is optional and
7198 /// defaults to false.
7199 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7200 const memory::desc &src_desc, const memory::desc &dst_desc,
7201 float epsilon, normalization_flags flags,
7202 const primitive_attr &attr = default_attr(),
7203 bool allow_empty = false)
7204 : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
7205 epsilon, flags, attr, allow_empty) {}
7206
7207 /// Constructs a primitive descriptor for a layer normalization
7208 /// forward propagation primitive from a C API primitive descriptor
7209 /// that must have a matching kind.
7210 ///
7211 /// @param pd C API primitive descriptor for a layer normalization
7212 /// forward propagation primitive.
7213 primitive_desc(dnnl_primitive_desc_t pd)
7214 : dnnl::primitive_desc(pd,
7215 dnnl::primitive::kind::layer_normalization,
7216 dnnl::prop_kind::forward_training,
7217 dnnl::prop_kind::forward_inference) {}
7218
7219 /// @copydoc dnnl::primitive_desc_base::src_desc()const
7220 memory::desc src_desc() const { return base::src_desc(0); }
7221
7222 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
7223 memory::desc dst_desc() const { return base::dst_desc(0); }
7224
7225 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
7226 memory::desc weights_desc() const { return base::weights_desc(0); }
7227
7228 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
7229 memory::desc workspace_desc() const { return base::workspace_desc(); }
7230
7231 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
7232 memory::desc mean_desc() const { return stat_desc(mean); }
7233
7234 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
7235 memory::desc variance_desc() const { return stat_desc(var); }
7236
7237 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7238 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7239
7240 /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
7241 float get_epsilon() const { return base::get_epsilon(); }
7242
7243 /// Returns normalization flags.
7244 /// @return Normalization flags.
7245 normalization_flags get_flags() const {
7246 return base::get_flags<normalization_flags>();
7247 }
7248
7249 private:
7250 enum {
7251 mean = 1,
7252 var = 2,
7253 };
7254 memory::desc stat_desc(int kind) const {
7255 const bool use_global_stats
7256 = (get_flags() & normalization_flags::use_global_stats)
7257 != normalization_flags::none;
7258 return query_md(
7259 use_global_stats ? query::src_md : query::dst_md, kind);
7260 }
7261
7262 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7263 const memory::desc &src_desc, const memory::desc &dst_desc,
7264 const memory::desc *stat_desc, float epsilon,
7265 normalization_flags flags, const primitive_attr &attr,
7266 bool allow_empty) {
7267
7268 dnnl_primitive_desc_t pd = nullptr;
7269 dnnl_status_t status
7270 = dnnl_layer_normalization_forward_primitive_desc_create(
7271 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
7272 src_desc.get(), dst_desc.get(),
7273 optional_arg(stat_desc), epsilon,
7274 convert_to_c(flags), attr.get());
7275
7276 if (!allow_empty)
7277 error::wrap_c_api(status,
7278 "could not create a primitive descriptor for a layer "
7279 "normalization forward propagation primitive");
7280 reset(pd);
7281 }
7282 };
7283
7284 /// Default constructor. Produces an empty object.
7285 layer_normalization_forward() = default;
7286
7287 /// Constructs a layer normalization forward propagation primitive.
7288 /// @param pd Primitive descriptor for a layer normalization forward
7289 /// propagation primitive.
7290 layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
7291
7292 /// Constructs a layer normalization forward propagation primitive from
7293 /// a cache blob.
7294 /// @param pd Primitive descriptor for a layer normalization forward
7295 /// propagation primitive.
7296 /// @param cache_blob Cache blob.
7297 layer_normalization_forward(
7298 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7299 : primitive(pd, cache_blob) {}
7300};
7301
7302/// Layer normalization backward propagation primitive.
7303struct layer_normalization_backward : public primitive {
7304 /// Primitive descriptor for a layer normalization backward propagation
7305 /// primitive.
7306 struct primitive_desc : public dnnl::primitive_desc {
7307 /// Default constructor. Produces an empty object.
7308 primitive_desc() = default;
7309
7310 /// Constructs a primitive descriptor for a layer normalization backward
7311 /// propagation primitive.
7312 ///
7313 /// @param aengine Engine to use.
7314 /// @param aprop_kind Propagation kind. Possible values are
7315 /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
7316 /// (diffs for all parameters are computed in this case).
7317 /// @param diff_src_desc Diff source memory descriptor.
7318 /// @param diff_dst_desc Diff destination memory descriptor.
7319 /// @param src_desc Source memory descriptor.
7320 /// @param stat_desc Statistics memory descriptors.
7321 /// @param epsilon Layer normalization epsilon parameter.
7322 /// @param flags Layer normalization flags (@ref
7323 /// dnnl::normalization_flags).
7324 /// @param attr Primitive attributes to use. Attributes are optional
7325 /// and default to empty attributes.
7326 /// @param hint_fwd_pd Primitive descriptor for a layer normalization
7327 /// forward propagation primitive. It is used as a hint for
7328 /// deciding which memory format to use.
7329 /// @param allow_empty A flag signifying whether construction is
7330 /// allowed to fail without throwing an exception. In this case an
7331 /// empty object will be produced. This flag is optional and
7332 /// defaults to false.
7333 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7334 const memory::desc &diff_src_desc,
7335 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
7336 const memory::desc &stat_desc, float epsilon,
7337 normalization_flags flags,
7338 const layer_normalization_forward::primitive_desc &hint_fwd_pd,
7339 const primitive_attr &attr = default_attr(),
7340 bool allow_empty = false)
7341 : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
7342 src_desc, &stat_desc, epsilon, flags, hint_fwd_pd, attr,
7343 allow_empty) {}
7344
7345 /// Constructs a primitive descriptor for a layer normalization backward
7346 /// propagation primitive.
7347 ///
7348 /// @param aengine Engine to use.
7349 /// @param aprop_kind Propagation kind. Possible values are
7350 /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
7351 /// (diffs for all parameters are computed in this case).
7352 /// @param diff_src_desc Diff source memory descriptor.
7353 /// @param diff_dst_desc Diff destination memory descriptor.
7354 /// @param src_desc Source memory descriptor.
7355 /// @param epsilon Layer normalization epsilon parameter.
7356 /// @param flags Layer normalization flags (@ref
7357 /// dnnl::normalization_flags).
7358 /// @param attr Primitive attributes to use. Attributes are optional
7359 /// and default to empty attributes.
7360 /// @param hint_fwd_pd Primitive descriptor for a layer normalization
7361 /// forward propagation primitive. It is used as a hint for
7362 /// deciding which memory format to use.
7363 /// @param allow_empty A flag signifying whether construction is
7364 /// allowed to fail without throwing an exception. In this case an
7365 /// empty object will be produced. This flag is optional and
7366 /// defaults to false.
7367 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7368 const memory::desc &diff_src_desc,
7369 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
7370 float epsilon, normalization_flags flags,
7371 const layer_normalization_forward::primitive_desc &hint_fwd_pd,
7372 const primitive_attr &attr = default_attr(),
7373 bool allow_empty = false)
7374 : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
7375 src_desc, nullptr, epsilon, flags, hint_fwd_pd, attr,
7376 allow_empty) {}
7377
7378 /// Constructs a primitive descriptor for a layer normalization
7379 /// backward propagation primitive from a C API primitive descriptor
7380 /// that must have a matching kind.
7381 ///
7382 /// @param pd C API primitive descriptor for a layer normalization
7383 /// backward propagation primitive.
7384 primitive_desc(dnnl_primitive_desc_t pd)
7385 : dnnl::primitive_desc(pd,
7386 dnnl::primitive::kind::layer_normalization,
7387 dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
7388 }
7389
7390 /// @copydoc dnnl::primitive_desc_base::src_desc()const
7391 memory::desc src_desc() const { return base::src_desc(0); }
7392
7393 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
7394 memory::desc weights_desc() const { return base::weights_desc(0); }
7395
7396 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
7397 memory::desc dst_desc() const { return base::dst_desc(0); }
7398
7399 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
7400 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
7401
7402 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
7403 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
7404
7405 /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
7406 memory::desc diff_weights_desc() const {
7407 return base::diff_weights_desc(0);
7408 }
7409
7410 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
7411 memory::desc mean_desc() const { return query_md(query::src_md, 1); }
7412
7413 /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
7414 memory::desc variance_desc() const {
7415 return query_md(query::src_md, 2);
7416 }
7417
7418 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
7419 memory::desc workspace_desc() const { return base::workspace_desc(); }
7420
7421 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7422 dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7423
7424 /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
7425 float get_epsilon() const { return base::get_epsilon(); }
7426
7427 /// Returns normalization flags.
7428 /// @return Normalization flags.
7429 normalization_flags get_flags() const {
7430 return base::get_flags<normalization_flags>();
7431 }
7432
7433 private:
7434 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7435 const memory::desc &diff_src_desc,
7436 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
7437 const memory::desc *stat_desc, float epsilon,
7438 normalization_flags flags,
7439 const layer_normalization_forward::primitive_desc &hint_fwd_pd,
7440 const primitive_attr &attr, bool allow_empty) {
7441
7442 dnnl_primitive_desc_t pd = nullptr;
7443 dnnl_status_t status
7444 = dnnl_layer_normalization_backward_primitive_desc_create(
7445 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
7446 diff_src_desc.get(), diff_dst_desc.get(),
7447 src_desc.get(), optional_arg(stat_desc), epsilon,
7448 convert_to_c(flags), hint_fwd_pd.get(), attr.get());
7449
7450 if (!allow_empty)
7451 error::wrap_c_api(status,
7452 "could not create a primitive descriptor for a layer "
7453 "normalization backward propagation primitive");
7454 reset(pd);
7455 }
7456 };
7457
7458 /// Default constructor. Produces an empty object.
7459 layer_normalization_backward() = default;
7460
7461 /// Constructs a layer normalization backward propagation primitive.
7462 /// @param pd Primitive descriptor for a layer normalization backward
7463 /// propagation primitive.
7464 layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
7465
7466 /// Constructs a layer normalization backward propagation primitive from
7467 /// a cache blob.
7468 /// @param pd Primitive descriptor for a layer normalization backward
7469 /// propagation primitive.
7470 /// @param cache_blob Cache blob.
7471 layer_normalization_backward(
7472 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7473 : primitive(pd, cache_blob) {}
7474};
7475
7476/// @} dnnl_api_layer_normalization
7477
7478/// @addtogroup dnnl_api_inner_product Inner Product
7479///
7480/// A primitive to compute an inner product.
7481///
7482/// @sa @ref dev_guide_inner_product in developer guide
7483///
7484/// @{
7485
7486/// Inner product forward propagation primitive.
7487struct inner_product_forward : public primitive {
7488 /// Primitive descriptor for an inner product forward propagation primitive.
7489 struct primitive_desc : public dnnl::primitive_desc {
7490 /// Default constructor. Produces an empty object.
7491 primitive_desc() = default;
7492
7493 /// Constructs a primitive descriptor for an inner product forward
7494 /// propagation primitive with bias.
7495 ///
7496 /// @note
7497 /// All the memory descriptors may be initialized with the
7498 /// #dnnl::memory::format_tag::any value of @p format_tag.
7499 ///
7500 /// @param aengine Engine to use.
7501 /// @param aprop_kind Propagation kind. Possible values are
7502 /// #dnnl::prop_kind::forward_training, and
7503 /// #dnnl::prop_kind::forward_inference.
7504 /// @param src_desc Memory descriptor for src.
7505 /// @param weights_desc Memory descriptor for weights.
7506 /// @param bias_desc Memory descriptor for bias.
7507 /// @param dst_desc Memory descriptor for dst.
7508 /// @param attr Primitive attributes to use. Attributes are optional
7509 /// and default to empty attributes.
7510 /// @param allow_empty A flag signifying whether construction is
7511 /// allowed to fail without throwing an exception. In this case an
7512 /// empty object will be produced. This flag is optional and
7513 /// defaults to false.
7514 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7515 const memory::desc &src_desc, const memory::desc &weights_desc,
7516 const memory::desc &bias_desc, const memory::desc &dst_desc,
7517 const primitive_attr &attr = default_attr(),
7518 bool allow_empty = false)
7519 : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
7520 &bias_desc, dst_desc, attr, allow_empty) {}
7521
7522 /// Constructs a primitive descriptor for an inner product forward
7523 /// propagation primitive.
7524 ///
7525 /// @note
7526 /// All the memory descriptors may be initialized with the
7527 /// #dnnl::memory::format_tag::any value of @p format_tag.
7528 ///
7529 /// @param aengine Engine to use.
7530 /// @param aprop_kind Propagation kind. Possible values are
7531 /// #dnnl::prop_kind::forward_training, and
7532 /// #dnnl::prop_kind::forward_inference.
7533 /// @param src_desc Memory descriptor for src.
7534 /// @param weights_desc Memory descriptor for weights.
7535 /// @param dst_desc Memory descriptor for dst.
7536 /// @param attr Primitive attributes to use. Attributes are optional
7537 /// and default to empty attributes.
7538 /// @param allow_empty A flag signifying whether construction is
7539 /// allowed to fail without throwing an exception. In this case an
7540 /// empty object will be produced. This flag is optional and
7541 /// defaults to false.
7542 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7543 const memory::desc &src_desc, const memory::desc &weights_desc,
7544 const memory::desc &dst_desc,
7545 const primitive_attr &attr = default_attr(),
7546 bool allow_empty = false)
7547 : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
7548 nullptr, dst_desc, attr, allow_empty) {}
7549
7550 /// Constructs a primitive descriptor for an inner product forward
7551 /// propagation primitive from a C API primitive descriptor that must
7552 /// have a matching kind.
7553 ///
7554 /// @param pd C API primitive descriptor for an inner product forward
7555 /// propagation primitive.
7556 primitive_desc(dnnl_primitive_desc_t pd)
7557 : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7558 dnnl::prop_kind::forward_training,
7559 dnnl::prop_kind::forward_inference) {}
7560
7561 /// @copydoc dnnl::primitive_desc_base::src_desc()const
7562 memory::desc src_desc() const { return base::src_desc(0); }
7563
7564 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
7565 memory::desc weights_desc() const { return base::weights_desc(0); }
7566
7567 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
7568 memory::desc dst_desc() const { return base::dst_desc(0); }
7569
7570 /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
7571 memory::desc bias_desc() const { return base::weights_desc(1); }
7572
7573 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7574 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7575
7576 private:
7577 primitive_desc(const engine &aengine, prop_kind aprop_kind,
7578 const memory::desc &src_desc, const memory::desc &weights_desc,
7579 const memory::desc *bias_desc, const memory::desc &dst_desc,
7580 const primitive_attr &attr, bool allow_empty) {
7581
7582 dnnl_primitive_desc_t pd = nullptr;
7583 dnnl_status_t status
7584 = dnnl_inner_product_forward_primitive_desc_create(&pd,
7585 aengine.get(), dnnl::convert_to_c(aprop_kind),
7586 src_desc.get(), weights_desc.get(),
7587 optional_arg(bias_desc), dst_desc.get(),
7588 attr.get());
7589
7590 if (!allow_empty)
7591 error::wrap_c_api(status,
7592 "could not create a primitive descriptor for an inner "
7593 "product forward propagation primitive");
7594 reset(pd);
7595 }
7596 };
7597
7598 /// Default constructor. Produces an empty object.
7599 inner_product_forward() = default;
7600
7601 /// Constructs an inner product forward propagation primitive.
7602 /// @param pd Primitive descriptor for an inner product forward
7603 /// propagation primitive.
7604 inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
7605
7606 /// Constructs an inner product forward propagation primitive from
7607 /// a cache blob.
7608 /// @param pd Primitive descriptor for an inner product forward
7609 /// propagation primitive.
7610 /// @param cache_blob Cache blob.
7611 inner_product_forward(
7612 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7613 : primitive(pd, cache_blob) {}
7614};
7615
7616/// Inner product backward propagation primitive.
7617struct inner_product_backward_data : public primitive {
7618 /// Primitive descriptor for an inner product backward propagation
7619 /// primitive.
7620 struct primitive_desc : public dnnl::primitive_desc {
7621 /// Default constructor. Produces an empty object.
7622 primitive_desc() = default;
7623
7624 /// Constructs a primitive descriptor for an inner product backward
7625 /// propagation primitive.
7626 ///
7627 /// @note
7628 /// All the memory descriptors may be initialized with the
7629 /// #dnnl::memory::format_tag::any value of @p format_tag.
7630 ///
7631 /// @param aengine Engine to use.
7632 /// @param diff_src_desc Memory descriptor for diff src.
7633 /// @param weights_desc Memory descriptor for weights.
7634 /// @param diff_dst_desc Memory descriptor for diff dst.
7635 /// @param hint_fwd_pd Primitive descriptor for an inner product
7636 /// forward propagation primitive. It is used as a hint for
7637 /// deciding which memory format to use.
7638 /// @param attr Primitive attributes to use. Attributes are optional
7639 /// and default to empty attributes.
7640 /// @param allow_empty A flag signifying whether construction is
7641 /// allowed to fail without throwing an exception. In this case an
7642 /// empty object will be produced. This flag is optional and
7643 /// defaults to false.
7644 primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
7645 const memory::desc &weights_desc,
7646 const memory::desc &diff_dst_desc,
7647 const inner_product_forward::primitive_desc &hint_fwd_pd,
7648 const primitive_attr &attr = default_attr(),
7649 bool allow_empty = false) {
7650 dnnl_primitive_desc_t pd = nullptr;
7651 dnnl_status_t status
7652 = dnnl_inner_product_backward_data_primitive_desc_create(
7653 &pd, aengine.get(), diff_src_desc.get(),
7654 weights_desc.get(), diff_dst_desc.get(),
7655 hint_fwd_pd.get(), attr.get());
7656
7657 if (!allow_empty)
7658 error::wrap_c_api(status,
7659 "could not create a primitive descriptor for an inner "
7660 "product backward propagation primitive");
7661 reset(pd);
7662 }
7663
7664 /// Constructs a primitive descriptor for an inner product backward
7665 /// propagation primitive from a C API primitive descriptor that must
7666 /// have a matching kind.
7667 ///
7668 /// @param pd C API primitive descriptor for an inner product backward
7669 /// propagation primitive.
7670 primitive_desc(dnnl_primitive_desc_t pd)
7671 : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7672 dnnl::prop_kind::backward_data) {}
7673
7674 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
7675 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
7676
7677 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
7678 memory::desc weights_desc() const { return base::weights_desc(0); }
7679
7680 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
7681 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
7682
7683 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7684 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7685 };
7686
7687 /// Default constructor. Produces an empty object.
7688 inner_product_backward_data() = default;
7689
7690 /// Constructs an inner product backward propagation primitive.
7691 /// @param pd Primitive descriptor for an inner product backward
7692 /// propagation primitive.
7693 inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
7694
7695 /// Constructs an inner product backward propagation primitive from
7696 /// a cache blob.
7697 /// @param pd Primitive descriptor for an inner product backward
7698 /// propagation primitive.
7699 /// @param cache_blob Cache blob.
7700 inner_product_backward_data(
7701 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7702 : primitive(pd, cache_blob) {}
7703};
7704
7705/// Inner product weights gradient primitive.
7706struct inner_product_backward_weights : public primitive {
7707 /// Primitive descriptor for an inner product weights gradient primitive.
7708 struct primitive_desc : public dnnl::primitive_desc {
7709 /// Default constructor. Produces an empty object.
7710 primitive_desc() = default;
7711
7712 /// Constructs a primitive descriptor for an inner product weights
7713 /// update primitive with bias.
7714 ///
7715 /// @note
7716 /// All the memory descriptors may be initialized with the
7717 /// #dnnl::memory::format_tag::any value of @p format_tag.
7718 ///
7719 /// @param aengine Engine to use.
7720 /// @param src_desc Memory descriptor for src.
7721 /// @param diff_weights_desc Memory descriptor for diff weights.
7722 /// @param diff_bias_desc Memory descriptor for diff bias.
7723 /// @param diff_dst_desc Memory descriptor for diff dst.
7724 /// @param hint_fwd_pd Primitive descriptor for an inner product
7725 /// forward propagation primitive. It is used as a hint for
7726 /// deciding which memory format to use.
7727 /// @param attr Primitive attributes to use. Attributes are optional
7728 /// and default to empty attributes.
7729 /// @param allow_empty A flag signifying whether construction is
7730 /// allowed to fail without throwing an exception. In this case an
7731 /// empty object will be produced. This flag is optional and
7732 /// defaults to false.
7733 primitive_desc(const engine &aengine, const memory::desc &src_desc,
7734 const memory::desc &diff_weights_desc,
7735 const memory::desc &diff_bias_desc,
7736 const memory::desc &diff_dst_desc,
7737 const inner_product_forward::primitive_desc &hint_fwd_pd,
7738 const primitive_attr &attr = default_attr(),
7739 bool allow_empty = false)
7740 : primitive_desc(aengine, src_desc, diff_weights_desc,
7741 &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
7742 allow_empty) {}
7743
7744 /// Constructs a primitive descriptor for an inner product weights
7745 /// update primitive.
7746 ///
7747 /// @note
7748 /// All the memory descriptors may be initialized with the
7749 /// #dnnl::memory::format_tag::any value of @p format_tag.
7750 ///
7751 /// @param aengine Engine to use.
7752 /// @param src_desc Memory descriptor for src.
7753 /// @param diff_weights_desc Memory descriptor for diff weights.
7754 /// @param diff_dst_desc Memory descriptor for diff dst.
7755 /// @param attr Primitive attributes to use. Attributes are optional
7756 /// and default to empty attributes.
7757 /// @param hint_fwd_pd Primitive descriptor for an inner product
7758 /// forward propagation primitive. It is used as a hint for
7759 /// deciding which memory format to use.
7760 /// @param allow_empty A flag signifying whether construction is
7761 /// allowed to fail without throwing an exception. In this case an
7762 /// empty object will be produced. This flag is optional and
7763 /// defaults to false.
7764 primitive_desc(const engine &aengine, const memory::desc &src_desc,
7765 const memory::desc &diff_weights_desc,
7766 const memory::desc &diff_dst_desc,
7767 const inner_product_forward::primitive_desc &hint_fwd_pd,
7768 const primitive_attr &attr = default_attr(),
7769 bool allow_empty = false)
7770 : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
7771 diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
7772
7773 /// Constructs a primitive descriptor for an inner product weights
7774 /// update primitive from a C API primitive descriptor that must
7775 /// have a matching kind.
7776 ///
7777 /// @param pd C API primitive descriptor for an inner product weights
7778 /// gradient primitive.
7779 primitive_desc(dnnl_primitive_desc_t pd)
7780 : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7781 dnnl::prop_kind::backward_weights) {}
7782
7783 /// @copydoc dnnl::primitive_desc_base::src_desc()const
7784 memory::desc src_desc() const { return base::src_desc(0); }
7785
7786 /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
7787 memory::desc diff_weights_desc() const {
7788 return base::diff_weights_desc(0);
7789 }
7790
7791 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
7792 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
7793
7794 /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
7795 memory::desc diff_bias_desc() const {
7796 return base::diff_weights_desc(1);
7797 }
7798
7799 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
7800 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
7801
7802 private:
7803 primitive_desc(const engine &aengine, const memory::desc &src_desc,
7804 const memory::desc &diff_weights_desc,
7805 const memory::desc *diff_bias_desc,
7806 const memory::desc &diff_dst_desc,
7807 const inner_product_forward::primitive_desc &hint_fwd_pd,
7808 const primitive_attr &attr, bool allow_empty) {
7809
7810 dnnl_primitive_desc_t pd = nullptr;
7811 dnnl_status_t status
7812 = dnnl_inner_product_backward_weights_primitive_desc_create(
7813 &pd, aengine.get(), src_desc.get(),
7814 diff_weights_desc.get(),
7815 optional_arg(diff_bias_desc), diff_dst_desc.get(),
7816 hint_fwd_pd.get(), attr.get());
7817
7818 if (!allow_empty)
7819 error::wrap_c_api(status,
7820 "could not create a primitive descriptor for an inner "
7821 "product weights gradient primitive");
7822 reset(pd);
7823 }
7824 };
7825
7826 /// Default constructor. Produces an empty object.
7827 inner_product_backward_weights() = default;
7828
7829 /// Constructs an inner product weights gradient primitive.
7830 /// @param pd Primitive descriptor for an inner product weights gradient
7831 /// primitive.
7832 inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
7833
7834 /// Constructs an inner product weights gradient primitive from a cache
7835 /// blob.
7836 /// @param pd Primitive descriptor for an inner product weights gradient
7837 /// primitive.
7838 /// @param cache_blob Cache blob.
7839 inner_product_backward_weights(
7840 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
7841 : primitive(pd, cache_blob) {}
7842};
7843
7844/// @} dnnl_api_inner_product
7845
7846/// @addtogroup dnnl_api_rnn RNN
7847///
7848/// A primitive to compute recurrent neural network layers.
7849///
7850/// @sa @ref dev_guide_rnn in developer guide
7851///
7852/// @{
7853
7854/// Base class for primitive descriptors for RNN primitives.
7855struct rnn_primitive_desc_base : public primitive_desc {
7856 using primitive_desc::primitive_desc;
7857
7858 /// Default constructor. Produces an empty object.
7859 rnn_primitive_desc_base() = default;
7860
7861 /// Constructs an RNN primitive descriptor base from a C API primitive
7862 /// descriptor while checking that it actually describes the expected
7863 /// primitive by comparing propagation and primitive kinds.
7864 ///
7865 /// @param pd C API primitive descriptor.
7866 /// @param aprop_kind Expected propagation kind.
7867 /// @param cell_kind Expected cell kind.
7868 rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
7869 dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
7870 : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
7871
7872 /// Returns source layer memory descriptor.
7873 /// @returns Source layer memory descriptor.
7874 memory::desc src_layer_desc() const {
7875 return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
7876 }
7877
7878 /// Returns AUGRU attention memory descriptor.
7879 /// @returns AUGRU attention memory descriptor.
7880 memory::desc augru_attention_desc() const {
7881 return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION);
7882 }
7883
7884 /// Returns source iteration memory descriptor.
7885 /// @returns Source iteration memory descriptor.
7886 /// @returns A zero memory descriptor if the primitive does not have a
7887 /// source iteration parameter.
7888 memory::desc src_iter_desc() const {
7889 return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
7890 }
7891
7892 /// Returns source recurrent cell state memory descriptor.
7893 /// @returns Source recurrent cell state memory descriptor.
7894 memory::desc src_iter_c_desc() const {
7895 return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
7896 }
7897
7898 /// Returns weights layer memory descriptor.
7899 /// @returns Weights layer memory descriptor.
7900 memory::desc weights_layer_desc() const {
7901 return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
7902 }
7903
7904 /// Returns weights iteration memory descriptor.
7905 /// @returns Weights iteration memory descriptor.
7906 memory::desc weights_iter_desc() const {
7907 return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
7908 }
7909
7910 /// Returns weights peephole memory descriptor.
7911 /// @returns Weights peephole memory descriptor.
7912 memory::desc weights_peephole_desc() const {
7913 return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
7914 }
7915
7916 /// Returns weights projection memory descriptor.
7917 /// @returns Weights projection memory descriptor.
7918 memory::desc weights_projection_desc() const {
7919 return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
7920 }
7921
7922 /// Returns bias memory descriptor.
7923 /// @returns Bias memory descriptor.
7924 /// @returns A zero memory descriptor if the primitive does not have a
7925 /// bias parameter.
7926 memory::desc bias_desc() const {
7927 return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
7928 }
7929
7930 /// Returns destination layer memory descriptor.
7931 /// @returns Destination layer memory descriptor.
7932 memory::desc dst_layer_desc() const {
7933 return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
7934 }
7935
7936 /// Returns destination iteration memory descriptor.
7937 /// @returns Destination iteration memory descriptor.
7938 /// @returns A zero memory descriptor if the primitive does not have a
7939 /// destination iteration parameter.
7940 memory::desc dst_iter_desc() const {
7941 return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
7942 }
7943
7944 /// Returns destination recurrent cell state memory descriptor.
7945 /// @returns Destination recurrent cell state memory descriptor.
7946 memory::desc dst_iter_c_desc() const {
7947 return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
7948 }
7949
7950 /// Returns diff source layer memory descriptor.
7951 /// @returns Diff source layer memory descriptor.
7952 memory::desc diff_src_layer_desc() const {
7953 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
7954 }
7955
7956 /// Returns diff AUGRU attention memory descriptor.
7957 /// @returns Diff AUGRU attention memory descriptor.
7958 memory::desc diff_augru_attention_desc() const {
7959 return base::query_md(
7960 query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION);
7961 }
7962
7963 /// Returns diff source iteration memory descriptor.
7964 /// @returns Diff source iteration memory descriptor.
7965 /// @returns A zero memory descriptor if the primitive does not have a
7966 /// diff source iteration parameter.
7967 memory::desc diff_src_iter_desc() const {
7968 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
7969 }
7970
7971 /// Returns diff source recurrent cell state memory descriptor.
7972 /// @returns Diff source recurrent cell state memory descriptor.
7973 memory::desc diff_src_iter_c_desc() const {
7974 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
7975 }
7976
7977 /// Returns diff weights layer memory descriptor.
7978 /// @returns Diff weights layer memory descriptor.
7979 memory::desc diff_weights_layer_desc() const {
7980 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
7981 }
7982
7983 /// Returns diff weights iteration memory descriptor.
7984 /// @returns Diff weights iteration memory descriptor.
7985 memory::desc diff_weights_iter_desc() const {
7986 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
7987 }
7988
7989 /// Returns diff weights peephole memory descriptor.
7990 /// @returns Diff weights peephole memory descriptor.
7991 memory::desc diff_weights_peephole_desc() const {
7992 return base::query_md(
7993 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
7994 }
7995
7996 /// Returns diff weights projection memory descriptor.
7997 /// @returns Diff weights projection memory descriptor.
7998 memory::desc diff_weights_projection_desc() const {
7999 return base::query_md(
8000 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
8001 }
8002
8003 /// Returns diff bias memory descriptor.
8004 /// @returns Diff bias memory descriptor.
8005 /// @returns A zero memory descriptor if the primitive does not have a
8006 /// diff bias parameter.
8007 memory::desc diff_bias_desc() const {
8008 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
8009 }
8010
8011 /// Returns diff destination layer memory descriptor.
8012 /// @returns Diff destination layer memory descriptor.
8013 memory::desc diff_dst_layer_desc() const {
8014 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
8015 }
8016
8017 /// Returns diff destination iteration memory descriptor.
8018 /// @returns Diff destination iteration memory descriptor.
8019 /// @returns A zero memory descriptor if the primitive does not have a
8020 /// diff destination iteration parameter.
8021 memory::desc diff_dst_iter_desc() const {
8022 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
8023 }
8024
8025 /// Returns diff destination recurrent cell state memory descriptor.
8026 /// @returns Diff destination recurrent cell state memory descriptor.
8027 memory::desc diff_dst_iter_c_desc() const {
8028 return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
8029 }
8030
8031protected:
8032 using rnn_base = rnn_primitive_desc_base;
8033
8034 // (Deliberately not using doxygen comments)
8035 //
8036 // Constructs an RNN primitive descriptor base from a C API primitive
8037 // descriptor while checking that it actually describes the expected
8038 // primitive by comparing propagation and primitive kinds. Caller can
8039 // pass two options propagation kinds. This is typically used to check
8040 // that propagation kind is inference or training forward propagation.
8041 //
8042 // @param pd C API primitive descriptor.
8043 // @param prop_kind1 Expected propagation kind.
8044 // @param prop_kind2 Expected propagation kind.
8045 // @param cell_kind Expected cell kind.
8046 rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
8047 dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
8048 dnnl::algorithm cell_kind) {
8049
8050 dnnl_status_t rc;
8051
8052 dnnl_primitive_kind_t q_primitive_kind;
8053 rc = dnnl_primitive_desc_query(
8054 pd, dnnl_query_primitive_kind, 0, &q_primitive_kind);
8055 error::wrap_c_api(rc,
8056 "could not retrieve a primitive kind from a primitive "
8057 "descriptor for an RNN primitive");
8058
8059 dnnl_prop_kind_t q_prop_kind;
8060 rc = dnnl_primitive_desc_query(
8061 pd, dnnl_query_prop_kind, 0, &q_prop_kind);
8062 error::wrap_c_api(rc,
8063 "could not retrieve a propagation kind from a primitive "
8064 "descriptor for an RNN primitive");
8065
8066 dnnl_alg_kind_t q_cell_kind;
8067 rc = dnnl_primitive_desc_query(
8068 pd, dnnl_query_cell_kind, 0, &q_cell_kind);
8069 error::wrap_c_api(rc,
8070 "could not retrieve a cell kind from a primitive descriptor "
8071 "for an RNN primitive");
8072
8073 dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
8074 dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
8075 dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
8076
8077 bool ok = q_primitive_kind == dnnl_rnn
8078 && (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2)
8079 && q_cell_kind == c_cell_kind;
8080
8081 if (!ok)
8082 DNNL_THROW_ERROR(dnnl_invalid_arguments,
8083 "mismatch between expected and provided descriptors for an "
8084 "RNN primitive");
8085
8086 reset_with_clone(pd);
8087 }
8088
8089 // Constructs an RNN forward propagation primitive descriptor base for
8090 // any cell kind.
8091 rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
8092 prop_kind aprop_kind, algorithm activation, rnn_direction direction,
8093 const memory::desc &src_layer_desc,
8094 const memory::desc &src_iter_desc,
8095 const memory::desc *src_iter_c_desc,
8096 const memory::desc *attention_desc,
8097 const memory::desc &weights_layer_desc,
8098 const memory::desc &weights_iter_desc,
8099 const memory::desc *weights_peephole_desc,
8100 const memory::desc *weights_projection_desc,
8101 const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
8102 const memory::desc &dst_iter_desc,
8103 const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
8104 float beta, const primitive_attr &attr, bool allow_empty) {
8105
8106 dnnl_status_t status = dnnl_success;
8107 const char *msg
8108 = "could not create a primitive descriptor for a requested "
8109 "cell kind";
8110
8111 dnnl_primitive_desc_t pd = nullptr;
8112 switch (cell_kind) {
8113 case algorithm::vanilla_rnn:
8114 status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
8115 aengine.get(), dnnl::convert_to_c(aprop_kind),
8116 dnnl::convert_to_c(activation),
8117 dnnl::convert_to_c(direction), src_layer_desc.get(),
8118 src_iter_desc.get(), weights_layer_desc.get(),
8119 weights_iter_desc.get(), bias_desc.get(),
8120 dst_layer_desc.get(), dst_iter_desc.get(),
8121 convert_to_c(flags), alpha, beta, attr.get());
8122 msg = "could not create a primitive descriptor for a vanilla "
8123 "RNN forward propagation primitive";
8124 break;
8125 case algorithm::vanilla_lstm:
8126 status = dnnl_lstm_forward_primitive_desc_create(&pd,
8127 aengine.get(), dnnl::convert_to_c(aprop_kind),
8128 dnnl::convert_to_c(direction), src_layer_desc.get(),
8129 src_iter_desc.get(), optional_arg(src_iter_c_desc),
8130 weights_layer_desc.get(), weights_iter_desc.get(),
8131 optional_arg(weights_peephole_desc),
8132 optional_arg(weights_projection_desc), bias_desc.get(),
8133 dst_layer_desc.get(), dst_iter_desc.get(),
8134 optional_arg(dst_iter_c_desc), convert_to_c(flags),
8135 attr.get());
8136 msg = "could not create a primitive descriptor for an LSTM "
8137 "forward propagation primitive";
8138 break;
8139 case algorithm::vanilla_gru:
8140 status = dnnl_gru_forward_primitive_desc_create(&pd,
8141 aengine.get(), dnnl::convert_to_c(aprop_kind),
8142 dnnl::convert_to_c(direction), src_layer_desc.get(),
8143 src_iter_desc.get(), weights_layer_desc.get(),
8144 weights_iter_desc.get(), bias_desc.get(),
8145 dst_layer_desc.get(), dst_iter_desc.get(),
8146 convert_to_c(flags), attr.get());
8147 msg = "could not create a primitive descriptor for a GRU "
8148 "forward propagation primitive";
8149 break;
8150 case algorithm::lbr_gru:
8151 status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
8152 aengine.get(), dnnl::convert_to_c(aprop_kind),
8153 dnnl::convert_to_c(direction), src_layer_desc.get(),
8154 src_iter_desc.get(), weights_layer_desc.get(),
8155 weights_iter_desc.get(), bias_desc.get(),
8156 dst_layer_desc.get(), dst_iter_desc.get(),
8157 convert_to_c(flags), attr.get());
8158 msg = "could not create a primitive descriptor for an LBR GRU "
8159 "forward propagation primitive";
8160 break;
8161 case algorithm::vanilla_augru:
8162 status = dnnl_augru_forward_primitive_desc_create(&pd,
8163 aengine.get(), dnnl::convert_to_c(aprop_kind),
8164 dnnl::convert_to_c(direction), src_layer_desc.get(),
8165 src_iter_desc.get(), optional_arg(attention_desc),
8166 weights_layer_desc.get(), weights_iter_desc.get(),
8167 bias_desc.get(), dst_layer_desc.get(),
8168 dst_iter_desc.get(), convert_to_c(flags), attr.get());
8169 msg = "could not create a primitive descriptor for an AUGRU "
8170 "forward propagation primitive";
8171 break;
8172 case algorithm::lbr_augru:
8173 status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
8174 aengine.get(), dnnl::convert_to_c(aprop_kind),
8175 dnnl::convert_to_c(direction), src_layer_desc.get(),
8176 src_iter_desc.get(), optional_arg(attention_desc),
8177 weights_layer_desc.get(), weights_iter_desc.get(),
8178 bias_desc.get(), dst_layer_desc.get(),
8179 dst_iter_desc.get(), convert_to_c(flags), attr.get());
8180 msg = "could not create a primitive descriptor for an LBR "
8181 "AUGRU forward propagation primitive";
8182 break;
8183 default: status = dnnl_unimplemented;
8184 }
8185
8186 if (!allow_empty) error::wrap_c_api(status, msg);
8187 reset(pd);
8188 }
8189
8190 // Constructs an RNN backward propagation primitive descriptor base for
8191 // any cell kind.
8192 rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
8193 prop_kind aprop_kind, algorithm activation, rnn_direction direction,
8194 const memory::desc &src_layer_desc,
8195 const memory::desc &src_iter_desc,
8196 const memory::desc *src_iter_c_desc,
8197 const memory::desc *attention_desc,
8198 const memory::desc &weights_layer_desc,
8199 const memory::desc &weights_iter_desc,
8200 const memory::desc *weights_peephole_desc,
8201 const memory::desc *weights_projection_desc,
8202 const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
8203 const memory::desc &dst_iter_desc,
8204 const memory::desc *dst_iter_c_desc,
8205 const memory::desc &diff_src_layer_desc,
8206 const memory::desc &diff_src_iter_desc,
8207 const memory::desc *diff_src_iter_c_desc,
8208 const memory::desc *diff_attention_desc,
8209 const memory::desc &diff_weights_layer_desc,
8210 const memory::desc &diff_weights_iter_desc,
8211 const memory::desc *diff_weights_peephole_desc,
8212 const memory::desc *diff_weights_projection_desc,
8213 const memory::desc &diff_bias_desc,
8214 const memory::desc &diff_dst_layer_desc,
8215 const memory::desc &diff_dst_iter_desc,
8216 const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
8217 float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
8218 const primitive_attr &attr, bool allow_empty) {
8219
8220 dnnl_status_t status = dnnl_success;
8221 const char *msg = "";
8222
8223 dnnl_primitive_desc_t pd = nullptr;
8224 switch (cell_kind) {
8225 case algorithm::vanilla_rnn:
8226 status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
8227 aengine.get(), dnnl::convert_to_c(aprop_kind),
8228 dnnl::convert_to_c(activation),
8229 dnnl::convert_to_c(direction), src_layer_desc.get(),
8230 src_iter_desc.get(), weights_layer_desc.get(),
8231 weights_iter_desc.get(), bias_desc.get(),
8232 dst_layer_desc.get(), dst_iter_desc.get(),
8233 diff_src_layer_desc.get(), diff_src_iter_desc.get(),
8234 diff_weights_layer_desc.get(),
8235 diff_weights_iter_desc.get(), diff_bias_desc.get(),
8236 diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
8237 convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
8238 attr.get());
8239 msg = "could not create a primitive descriptor for a vanilla "
8240 "RNN backward propagation primitive";
8241 break;
8242 case algorithm::vanilla_lstm:
8243 status = dnnl_lstm_backward_primitive_desc_create(&pd,
8244 aengine.get(), dnnl::convert_to_c(aprop_kind),
8245 dnnl::convert_to_c(direction), src_layer_desc.get(),
8246 src_iter_desc.get(), optional_arg(src_iter_c_desc),
8247 weights_layer_desc.get(), weights_iter_desc.get(),
8248 optional_arg(weights_peephole_desc),
8249 optional_arg(weights_projection_desc), bias_desc.get(),
8250 dst_layer_desc.get(), dst_iter_desc.get(),
8251 optional_arg(dst_iter_c_desc),
8252 diff_src_layer_desc.get(), diff_src_iter_desc.get(),
8253 optional_arg(diff_src_iter_c_desc),
8254 diff_weights_layer_desc.get(),
8255 diff_weights_iter_desc.get(),
8256 optional_arg(diff_weights_peephole_desc),
8257 optional_arg(diff_weights_projection_desc),
8258 diff_bias_desc.get(), diff_dst_layer_desc.get(),
8259 diff_dst_iter_desc.get(),
8260 optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
8261 hint_fwd_pd.get(), attr.get());
8262 msg = "could not create a primitive descriptor for an LSTM "
8263 "backward propagation primitive";
8264 break;
8265 case algorithm::vanilla_gru:
8266 status = dnnl_gru_backward_primitive_desc_create(&pd,
8267 aengine.get(), dnnl::convert_to_c(aprop_kind),
8268 dnnl::convert_to_c(direction), src_layer_desc.get(),
8269 src_iter_desc.get(), weights_layer_desc.get(),
8270 weights_iter_desc.get(), bias_desc.get(),
8271 dst_layer_desc.get(), dst_iter_desc.get(),
8272 diff_src_layer_desc.get(), diff_src_iter_desc.get(),
8273 diff_weights_layer_desc.get(),
8274 diff_weights_iter_desc.get(), diff_bias_desc.get(),
8275 diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
8276 convert_to_c(flags), hint_fwd_pd.get(), attr.get());
8277 msg = "could not create a primitive descriptor for a GRU "
8278 "backward propagation primitive";
8279 break;
8280 case algorithm::lbr_gru:
8281 status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
8282 aengine.get(), dnnl::convert_to_c(aprop_kind),
8283 dnnl::convert_to_c(direction), src_layer_desc.get(),
8284 src_iter_desc.get(), weights_layer_desc.get(),
8285 weights_iter_desc.get(), bias_desc.get(),
8286 dst_layer_desc.get(), dst_iter_desc.get(),
8287 diff_src_layer_desc.get(), diff_src_iter_desc.get(),
8288 diff_weights_layer_desc.get(),
8289 diff_weights_iter_desc.get(), diff_bias_desc.get(),
8290 diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
8291 convert_to_c(flags), hint_fwd_pd.get(), attr.get());
8292 msg = "could not create a primitive descriptor for an LBR GRU "
8293 "backward propagation primitive";
8294 break;
8295 case algorithm::vanilla_augru:
8296 status = dnnl_augru_backward_primitive_desc_create(&pd,
8297 aengine.get(), dnnl::convert_to_c(aprop_kind),
8298 dnnl::convert_to_c(direction), src_layer_desc.get(),
8299 src_iter_desc.get(), optional_arg(attention_desc),
8300 weights_layer_desc.get(), weights_iter_desc.get(),
8301 bias_desc.get(), dst_layer_desc.get(),
8302 dst_iter_desc.get(), diff_src_layer_desc.get(),
8303 diff_src_iter_desc.get(),
8304 optional_arg(diff_attention_desc),
8305 diff_weights_layer_desc.get(),
8306 diff_weights_iter_desc.get(), diff_bias_desc.get(),
8307 diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
8308 convert_to_c(flags), hint_fwd_pd.get(), attr.get());
8309 msg = "could not create a primitive descriptor for an AUGRU "
8310 "backward propagation primitive";
8311 break;
8312 case algorithm::lbr_augru:
8313 status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
8314 aengine.get(), dnnl::convert_to_c(aprop_kind),
8315 dnnl::convert_to_c(direction), src_layer_desc.get(),
8316 src_iter_desc.get(), optional_arg(attention_desc),
8317 weights_layer_desc.get(), weights_iter_desc.get(),
8318 bias_desc.get(), dst_layer_desc.get(),
8319 dst_iter_desc.get(), diff_src_layer_desc.get(),
8320 diff_src_iter_desc.get(),
8321 optional_arg(diff_attention_desc),
8322 diff_weights_layer_desc.get(),
8323 diff_weights_iter_desc.get(), diff_bias_desc.get(),
8324 diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
8325 convert_to_c(flags), hint_fwd_pd.get(), attr.get());
8326 msg = "could not create a primitive descriptor for an LBR "
8327 "AUGRU backward propagation primitive";
8328 break;
8329 default: status = dnnl_unimplemented;
8330 }
8331 if (!allow_empty) error::wrap_c_api(status, msg);
8332 reset(pd);
8333 }
8334};
8335
8336/// Vanilla RNN forward propagation primitive.
8337struct vanilla_rnn_forward : public primitive {
8338 /// Primitive descriptor for a vanilla RNN forward propagation primitive.
8339 struct primitive_desc : public rnn_primitive_desc_base {
8340 /// Default constructor. Produces an empty object.
8341 primitive_desc() = default;
8342
8343 /// Constructs a primitive descriptor for a vanilla RNN forward
8344 /// propagation primitive.
8345 ///
8346 /// The following arguments may point to a zero memory descriptor:
8347 /// - @p src_iter_desc,
8348 /// - @p bias_desc,
8349 /// - @p dst_iter_desc.
8350 ///
8351 /// This would then indicate that the RNN forward propagation primitive
8352 /// should not use them and should default to zero values instead.
8353 ///
8354 /// @note
8355 /// All memory descriptors except @p src_iter_desc can be
8356 /// initialized with an #dnnl::memory::format_tag::any value of @p
8357 /// format_tag.
8358 ///
8359 /// @param aengine Engine to use.
8360 /// @param aprop_kind Propagation kind. Possible values are
8361 /// #dnnl::prop_kind::forward_training, and
8362 /// #dnnl::prop_kind::forward_inference.
8363 /// @param activation Activation kind. Possible values are
8364 /// #dnnl::algorithm::eltwise_relu,
8365 /// #dnnl::algorithm::eltwise_tanh, or
8366 /// #dnnl::algorithm::eltwise_logistic.
8367 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8368 /// more info.
8369 /// @param src_layer_desc Memory descriptor for the input vector.
8370 /// @param src_iter_desc Memory descriptor for the input recurrent
8371 /// hidden state vector.
8372 /// @param weights_layer_desc Memory descriptor for the weights
8373 /// applied to the layer input.
8374 /// @param weights_iter_desc Memory descriptor for the weights applied
8375 /// to the recurrent input.
8376 /// @param bias_desc Bias memory descriptor.
8377 /// @param dst_layer_desc Memory descriptor for the output vector.
8378 /// @param dst_iter_desc Memory descriptor for the output recurrent
8379 /// hidden state vector.
8380 /// @param attr Primitive attributes to use. Attributes are optional
8381 /// and default to empty attributes.
8382 /// @param allow_empty A flag signifying whether construction is
8383 /// allowed to fail without throwing an exception. In this case an
8384 /// empty object will be produced. This flag is optional and
8385 /// defaults to false.
8386 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8387 algorithm activation, rnn_direction direction,
8388 const memory::desc &src_layer_desc,
8389 const memory::desc &src_iter_desc,
8390 const memory::desc &weights_layer_desc,
8391 const memory::desc &weights_iter_desc,
8392 const memory::desc &bias_desc,
8393 const memory::desc &dst_layer_desc,
8394 const memory::desc &dst_iter_desc,
8395 const primitive_attr &attr = default_attr(),
8396 bool allow_empty = false)
8397 : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
8398 aprop_kind, activation, direction, src_layer_desc,
8399 src_iter_desc, nullptr, nullptr, weights_layer_desc,
8400 weights_iter_desc, nullptr, nullptr, bias_desc,
8401 dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
8402 0.0f, 0.0f, attr, allow_empty) {}
8403
8404 /// Constructs a primitive descriptor for a vanilla RNN forward
8405 /// propagation primitive with alpha parameter.
8406 ///
8407 /// The following arguments may point to a zero memory descriptor:
8408 /// - @p src_iter_desc,
8409 /// - @p bias_desc,
8410 /// - @p dst_iter_desc.
8411 ///
8412 /// This would then indicate that the RNN forward propagation primitive
8413 /// should not use them and should default to zero values instead.
8414 ///
8415 /// @note
8416 /// All memory descriptors except @p src_iter_desc can be
8417 /// initialized with an #dnnl::memory::format_tag::any value of @p
8418 /// format_tag.
8419 ///
8420 /// @param aengine Engine to use.
8421 /// @param aprop_kind Propagation kind. Possible values are
8422 /// #dnnl::prop_kind::forward_training, and
8423 /// #dnnl::prop_kind::forward_inference.
8424 /// @param activation Activation kind. Possible values are
8425 /// #dnnl::algorithm::eltwise_relu,
8426 /// #dnnl::algorithm::eltwise_tanh, or
8427 /// #dnnl::algorithm::eltwise_logistic.
8428 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8429 /// more info.
8430 /// @param src_layer_desc Memory descriptor for the input vector.
8431 /// @param src_iter_desc Memory descriptor for the input recurrent
8432 /// hidden state vector.
8433 /// @param weights_layer_desc Memory descriptor for the weights
8434 /// applied to the layer input.
8435 /// @param weights_iter_desc Memory descriptor for the weights applied
8436 /// to the recurrent input.
8437 /// @param bias_desc Bias memory descriptor.
8438 /// @param dst_layer_desc Memory descriptor for the output vector.
8439 /// @param dst_iter_desc Memory descriptor for the output recurrent
8440 /// hidden state vector.
8441 /// @param alpha Negative slope if activation is
8442 /// #dnnl::algorithm::eltwise_relu.
8443 /// @param attr Primitive attributes to use. Attributes are optional
8444 /// and default to empty attributes.
8445 /// @param allow_empty A flag signifying whether construction is
8446 /// allowed to fail without throwing an exception. In this case an
8447 /// empty object will be produced. This flag is optional and
8448 /// defaults to false.
8449 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8450 algorithm activation, rnn_direction direction,
8451 const memory::desc &src_layer_desc,
8452 const memory::desc &src_iter_desc,
8453 const memory::desc &weights_layer_desc,
8454 const memory::desc &weights_iter_desc,
8455 const memory::desc &bias_desc,
8456 const memory::desc &dst_layer_desc,
8457 const memory::desc &dst_iter_desc, float alpha,
8458 const primitive_attr &attr = default_attr(),
8459 bool allow_empty = false)
8460 : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
8461 aprop_kind, activation, direction, src_layer_desc,
8462 src_iter_desc, nullptr, nullptr, weights_layer_desc,
8463 weights_iter_desc, nullptr, nullptr, bias_desc,
8464 dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
8465 alpha, 0.0f, attr, allow_empty) {}
8466
8467 /// Constructs a primitive descriptor for a vanilla RNN forward
8468 /// propagation primitive from a C API primitive descriptor that must
8469 /// have a matching kind.
8470 ///
8471 /// @param pd C API primitive descriptor for a vanilla RNN forward
8472 /// propagation primitive.
8473 primitive_desc(dnnl_primitive_desc_t pd)
8474 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
8475 dnnl::prop_kind::forward_inference,
8476 dnnl::algorithm::vanilla_rnn) {}
8477
8478 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
8479 memory::desc src_layer_desc() const {
8480 return rnn_base::src_layer_desc();
8481 }
8482
8483 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
8484 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
8485
8486 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
8487 memory::desc weights_layer_desc() const {
8488 return rnn_base::weights_layer_desc();
8489 }
8490
8491 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
8492 memory::desc weights_iter_desc() const {
8493 return rnn_base::weights_iter_desc();
8494 }
8495
8496 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
8497 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
8498
8499 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
8500 memory::desc dst_layer_desc() const {
8501 return rnn_base::dst_layer_desc();
8502 }
8503
8504 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
8505 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
8506
8507 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
8508 memory::desc workspace_desc() const {
8509 return rnn_base::workspace_desc();
8510 }
8511
8512 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
8513 algorithm get_cell_kind() const { return base::get_cell_kind(); }
8514
8515 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
8516 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
8517
8518 /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
8519 algorithm get_activation_kind() const {
8520 return base::get_activation_kind();
8521 }
8522
8523 /// @copydoc dnnl::primitive_desc_base::get_direction()const
8524 rnn_direction get_direction() const { return base::get_direction(); }
8525
8526 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
8527 float get_alpha() const { return base::get_alpha(); }
8528
8529 /// @copydoc dnnl::primitive_desc_base::get_beta()const
8530 float get_beta() const { return base::get_beta(); }
8531 };
8532
8533 /// Default constructor. Produces an empty object.
8534 vanilla_rnn_forward() = default;
8535
8536 /// Constructs a vanilla RNN forward propagation primitive.
8537 /// @param pd Primitive descriptor for a vanilla RNN forward
8538 /// propagation primitive.
8539 vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
8540
8541 /// Constructs a vanilla RNN forward propagation primitive from
8542 /// a cache blob.
8543 /// @param pd Primitive descriptor for a vanilla RNN forward
8544 /// propagation primitive.
8545 /// @param cache_blob Cache blob.
8546 vanilla_rnn_forward(
8547 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
8548 : primitive(pd, cache_blob) {}
8549};
8550
8551/// Vanilla RNN backward propagation primitive.
8552struct vanilla_rnn_backward : public primitive {
8553 /// Primitive descriptor for an RNN backward propagation primitive.
8554 struct primitive_desc : public rnn_primitive_desc_base {
8555 /// Default constructor. Produces an empty object.
8556 primitive_desc() = default;
8557
8558 /// Constructs a primitive descriptor for a vanilla RNN backward
8559 /// propagation primitive.
8560 ///
8561 /// The following arguments may point to a zero memory descriptor:
8562 /// - @p src_iter_desc together with @p diff_src_iter_desc,
8563 /// - @p bias_desc together with @p diff_bias_desc,
8564 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
8565 ///
8566 /// This would then indicate that the RNN backward propagation
8567 /// primitive should not use the respective data and should use zero
8568 /// values instead.
8569 ///
8570 /// @note
8571 /// All the memory descriptors may be initialized with the
8572 /// #dnnl::memory::format_tag::any value of @p format_tag.
8573 ///
8574 /// @param aengine Engine to use.
8575 /// @param aprop_kind Propagation kind. Must be
8576 /// #dnnl::prop_kind::backward.
8577 /// @param activation Activation kind. Possible values are
8578 /// #dnnl::algorithm::eltwise_relu,
8579 /// #dnnl::algorithm::eltwise_tanh, or
8580 /// #dnnl::algorithm::eltwise_logistic.
8581 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8582 /// more info.
8583 /// @param src_layer_desc Memory descriptor for the input vector.
8584 /// @param src_iter_desc Memory descriptor for the input recurrent
8585 /// hidden state vector.
8586 /// @param weights_layer_desc Memory descriptor for the weights
8587 /// applied to the layer input.
8588 /// @param weights_iter_desc Memory descriptor for the weights applied
8589 /// to the recurrent input.
8590 /// @param bias_desc Bias memory descriptor.
8591 /// @param dst_layer_desc Memory descriptor for the output vector.
8592 /// @param dst_iter_desc Memory descriptor for the output recurrent
8593 /// hidden state vector.
8594 /// @param diff_src_layer_desc Memory descriptor for the diff of input
8595 /// vector.
8596 /// @param diff_src_iter_desc Memory descriptor for the diff of input
8597 /// recurrent hidden state vector.
8598 /// @param diff_weights_layer_desc Memory descriptor for the diff of
8599 /// weights applied to the layer input.
8600 /// @param diff_weights_iter_desc Memory descriptor for the diff of
8601 /// weights applied to the recurrent input.
8602 /// @param diff_bias_desc Diff bias memory descriptor.
8603 /// @param diff_dst_layer_desc Memory descriptor for the diff of
8604 /// output vector.
8605 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
8606 /// recurrent hidden state vector.
8607 /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
8608 /// forward propagation primitive. It is used as a hint for
8609 /// deciding which memory format to use.
8610 /// @param attr Primitive attributes to use. Attributes are optional
8611 /// and default to empty attributes.
8612 /// @param allow_empty A flag signifying whether construction is
8613 /// allowed to fail without throwing an exception. In this case an
8614 /// empty object will be produced. This flag is optional and
8615 /// defaults to false.
8616 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8617 algorithm activation, rnn_direction direction,
8618 const memory::desc &src_layer_desc,
8619 const memory::desc &src_iter_desc,
8620 const memory::desc &weights_layer_desc,
8621 const memory::desc &weights_iter_desc,
8622 const memory::desc &bias_desc,
8623 const memory::desc &dst_layer_desc,
8624 const memory::desc &dst_iter_desc,
8625 const memory::desc &diff_src_layer_desc,
8626 const memory::desc &diff_src_iter_desc,
8627 const memory::desc &diff_weights_layer_desc,
8628 const memory::desc &diff_weights_iter_desc,
8629 const memory::desc &diff_bias_desc,
8630 const memory::desc &diff_dst_layer_desc,
8631 const memory::desc &diff_dst_iter_desc,
8632 const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
8633 const primitive_attr &attr = default_attr(),
8634 bool allow_empty = false)
8635 : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
8636 aprop_kind, activation, direction, src_layer_desc,
8637 src_iter_desc, nullptr, nullptr, weights_layer_desc,
8638 weights_iter_desc, nullptr, nullptr, bias_desc,
8639 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
8640 diff_src_iter_desc, nullptr, nullptr,
8641 diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
8642 nullptr, diff_bias_desc, diff_dst_layer_desc,
8643 diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
8644 hint_fwd_pd, attr, allow_empty) {}
8645
8646 /// Constructs a primitive descriptor for a vanilla RNN backward
8647 /// propagation primitive with an alpha parameter.
8648 ///
8649 /// The following arguments may point to a zero memory descriptor:
8650 /// - @p src_iter_desc together with @p diff_src_iter_desc,
8651 /// - @p bias_desc together with @p diff_bias_desc,
8652 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
8653 ///
8654 /// This would then indicate that the RNN backward propagation
8655 /// primitive should not use the respective data and should use zero
8656 /// values instead.
8657 ///
8658 /// @note
8659 /// All the memory descriptors may be initialized with the
8660 /// #dnnl::memory::format_tag::any value of @p format_tag.
8661 ///
8662 /// @param aengine Engine to use.
8663 /// @param aprop_kind Propagation kind. Must be
8664 /// #dnnl::prop_kind::backward.
8665 /// @param activation Activation kind. Possible values are
8666 /// #dnnl::algorithm::eltwise_relu,
8667 /// #dnnl::algorithm::eltwise_tanh, or
8668 /// #dnnl::algorithm::eltwise_logistic.
8669 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8670 /// more info.
8671 /// @param src_layer_desc Memory descriptor for the input vector.
8672 /// @param src_iter_desc Memory descriptor for the input recurrent
8673 /// hidden state vector.
8674 /// @param weights_layer_desc Memory descriptor for the weights
8675 /// applied to the layer input.
8676 /// @param weights_iter_desc Memory descriptor for the weights applied
8677 /// to the recurrent input.
8678 /// @param bias_desc Bias memory descriptor.
8679 /// @param dst_layer_desc Memory descriptor for the output vector.
8680 /// @param dst_iter_desc Memory descriptor for the output recurrent
8681 /// hidden state vector.
8682 /// @param diff_src_layer_desc Memory descriptor for the diff of input
8683 /// vector.
8684 /// @param diff_src_iter_desc Memory descriptor for the diff of input
8685 /// recurrent hidden state vector.
8686 /// @param diff_weights_layer_desc Memory descriptor for the diff of
8687 /// weights applied to the layer input.
8688 /// @param diff_weights_iter_desc Memory descriptor for the diff of
8689 /// weights applied to the recurrent input.
8690 /// @param diff_bias_desc Diff bias memory descriptor.
8691 /// @param diff_dst_layer_desc Memory descriptor for the diff of
8692 /// output vector.
8693 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
8694 /// recurrent hidden state vector.
8695 /// @param alpha Negative slope if activation is
8696 /// #dnnl::algorithm::eltwise_relu.
8697 /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
8698 /// forward propagation primitive. It is used as a hint for
8699 /// deciding which memory format to use.
8700 /// @param attr Primitive attributes to use. Attributes are optional
8701 /// and default to empty attributes.
8702 /// @param allow_empty A flag signifying whether construction is
8703 /// allowed to fail without throwing an exception. In this case an
8704 /// empty object will be produced. This flag is optional and
8705 /// defaults to false.
8706 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8707 algorithm activation, rnn_direction direction,
8708 const memory::desc &src_layer_desc,
8709 const memory::desc &src_iter_desc,
8710 const memory::desc &weights_layer_desc,
8711 const memory::desc &weights_iter_desc,
8712 const memory::desc &bias_desc,
8713 const memory::desc &dst_layer_desc,
8714 const memory::desc &dst_iter_desc,
8715 const memory::desc &diff_src_layer_desc,
8716 const memory::desc &diff_src_iter_desc,
8717 const memory::desc &diff_weights_layer_desc,
8718 const memory::desc &diff_weights_iter_desc,
8719 const memory::desc &diff_bias_desc,
8720 const memory::desc &diff_dst_layer_desc,
8721 const memory::desc &diff_dst_iter_desc, float alpha,
8722 const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
8723 const primitive_attr &attr = default_attr(),
8724 bool allow_empty = false)
8725 : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
8726 aprop_kind, activation, direction, src_layer_desc,
8727 src_iter_desc, nullptr, nullptr, weights_layer_desc,
8728 weights_iter_desc, nullptr, nullptr, bias_desc,
8729 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
8730 diff_src_iter_desc, nullptr, nullptr,
8731 diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
8732 nullptr, diff_bias_desc, diff_dst_layer_desc,
8733 diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f,
8734 hint_fwd_pd, attr, allow_empty) {}
8735
8736 /// Constructs a primitive descriptor for a vanilla RNN backward
8737 /// propagation primitive from a C API primitive descriptor that must
8738 /// have a matching kind.
8739 ///
8740 /// @param pd C API primitive descriptor for a vanilla RNN backward
8741 /// propagation primitive.
8742 primitive_desc(dnnl_primitive_desc_t pd)
8743 : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
8744 dnnl::algorithm::vanilla_rnn) {}
8745
8746 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
8747 memory::desc src_layer_desc() const {
8748 return rnn_base::src_layer_desc();
8749 }
8750
8751 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
8752 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
8753
8754 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
8755 memory::desc weights_layer_desc() const {
8756 return rnn_base::weights_layer_desc();
8757 }
8758
8759 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
8760 memory::desc weights_iter_desc() const {
8761 return rnn_base::weights_iter_desc();
8762 }
8763
8764 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
8765 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
8766
8767 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
8768 memory::desc dst_layer_desc() const {
8769 return rnn_base::dst_layer_desc();
8770 }
8771
8772 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
8773 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
8774
8775 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
8776 memory::desc workspace_desc() const {
8777 return rnn_base::workspace_desc();
8778 }
8779
8780 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
8781 memory::desc diff_src_layer_desc() const {
8782 return rnn_base::diff_src_layer_desc();
8783 }
8784
8785 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
8786 memory::desc diff_src_iter_desc() const {
8787 return rnn_base::diff_src_iter_desc();
8788 }
8789
8790 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
8791 memory::desc diff_weights_layer_desc() const {
8792 return rnn_base::diff_weights_layer_desc();
8793 }
8794
8795 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
8796 memory::desc diff_weights_iter_desc() const {
8797 return rnn_base::diff_weights_iter_desc();
8798 }
8799
8800 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
8801 memory::desc diff_bias_desc() const {
8802 return rnn_base::diff_bias_desc();
8803 }
8804
8805 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
8806 memory::desc diff_dst_layer_desc() const {
8807 return rnn_base::diff_dst_layer_desc();
8808 }
8809
8810 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
8811 memory::desc diff_dst_iter_desc() const {
8812 return rnn_base::diff_dst_iter_desc();
8813 }
8814
8815 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
8816 algorithm get_cell_kind() const { return base::get_cell_kind(); }
8817
8818 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
8819 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
8820
8821 /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
8822 algorithm get_activation_kind() const {
8823 return base::get_activation_kind();
8824 }
8825
8826 /// @copydoc dnnl::primitive_desc_base::get_direction()const
8827 rnn_direction get_direction() const { return base::get_direction(); }
8828
8829 /// @copydoc dnnl::primitive_desc_base::get_alpha()const
8830 float get_alpha() const { return base::get_alpha(); }
8831
8832 /// @copydoc dnnl::primitive_desc_base::get_beta()const
8833 float get_beta() const { return base::get_beta(); }
8834 };
8835
8836 /// Default constructor. Produces an empty object.
8837 vanilla_rnn_backward() = default;
8838
8839 /// Constructs a vanilla RNN backward propagation primitive.
8840 /// @param pd Primitive descriptor for a vanilla RNN backward
8841 /// propagation primitive.
8842 vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
8843
8844 /// Constructs a vanilla RNN backward propagation primitive from
8845 /// a cache blob.
8846 /// @param pd Primitive descriptor for a vanilla RNN backward
8847 /// propagation primitive.
8848 /// @param cache_blob Cache blob.
8849 vanilla_rnn_backward(
8850 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
8851 : primitive(pd, cache_blob) {}
8852};
8853
8854/// LSTM forward propagation primitive.
8855struct lstm_forward : public primitive {
8856 /// Primitive descriptor for an LSTM forward propagation primitive.
8857 struct primitive_desc : public rnn_primitive_desc_base {
8858 /// Default constructor. Produces an empty object.
8859 primitive_desc() = default;
8860
8861 /// Constructs a primitive descriptor for an LSTM (with or without
8862 /// peephole and with or without projection) forward propagation
8863 /// primitive.
8864 ///
8865 /// The following arguments may point to a zero memory descriptor:
8866 /// - @p src_iter_desc together with @p src_iter_c_desc,
8867 /// - @p weights_peephole_desc,
8868 /// - @p bias_desc,
8869 /// - @p dst_iter_desc together with @p dst_iter_c_desc.
8870 ///
8871 /// This would then indicate that the LSTM forward propagation
8872 /// primitive should not use them and should default to zero values
8873 /// instead.
8874 ///
8875 /// The @p weights_projection_desc may point to a zero memory
8876 /// descriptor. This would then indicate that the LSTM doesn't have
8877 /// recurrent projection layer.
8878 ///
8879 /// @note
8880 /// All memory descriptors can be initialized with an
8881 /// #dnnl::memory::format_tag::any value of @p format_tag.
8882 ///
8883 /// @param aengine Engine to use.
8884 /// @param aprop_kind Propagation kind. Possible values are
8885 /// #dnnl::prop_kind::forward_training, and
8886 /// #dnnl::prop_kind::forward_inference.
8887 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8888 /// more info.
8889 /// @param src_layer_desc Memory descriptor for the input vector.
8890 /// @param src_iter_desc Memory descriptor for the input recurrent
8891 /// hidden state vector.
8892 /// @param src_iter_c_desc Memory descriptor for the input recurrent
8893 /// cell state vector.
8894 /// @param weights_layer_desc Memory descriptor for the weights
8895 /// applied to the layer input.
8896 /// @param weights_iter_desc Memory descriptor for the weights applied
8897 /// to the recurrent input.
8898 /// @param weights_peephole_desc Memory descriptor for the weights
8899 /// applied to the cell states (according to the Peephole LSTM
8900 /// formula).
8901 /// @param weights_projection_desc Memory descriptor for the weights
8902 /// applied to the hidden states to get the recurrent projection
8903 /// (according to the Projection LSTM formula).
8904 /// @param bias_desc Bias memory descriptor.
8905 /// @param dst_layer_desc Memory descriptor for the output vector.
8906 /// @param dst_iter_desc Memory descriptor for the output recurrent
8907 /// hidden state vector.
8908 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
8909 /// cell state vector.
8910 /// @param attr Primitive attributes to use. Attributes are optional
8911 /// and default to empty attributes.
8912 /// @param allow_empty A flag signifying whether construction is
8913 /// allowed to fail without throwing an exception. In this case an
8914 /// empty object will be produced. This flag is optional and
8915 /// defaults to false.
8916 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8917 rnn_direction direction, const memory::desc &src_layer_desc,
8918 const memory::desc &src_iter_desc,
8919 const memory::desc &src_iter_c_desc,
8920 const memory::desc &weights_layer_desc,
8921 const memory::desc &weights_iter_desc,
8922 const memory::desc &weights_peephole_desc,
8923 const memory::desc &weights_projection_desc,
8924 const memory::desc &bias_desc,
8925 const memory::desc &dst_layer_desc,
8926 const memory::desc &dst_iter_desc,
8927 const memory::desc &dst_iter_c_desc,
8928 const primitive_attr &attr = default_attr(),
8929 bool allow_empty = false)
8930 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
8931 aprop_kind, algorithm::undef, direction, src_layer_desc,
8932 src_iter_desc, &src_iter_c_desc, nullptr,
8933 weights_layer_desc, weights_iter_desc,
8934 &weights_peephole_desc, &weights_projection_desc, bias_desc,
8935 dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
8936 rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
8937
8938 /// Constructs a primitive descriptor for an LSTM (with or without
8939 /// peephole) forward propagation primitive.
8940 ///
8941 /// The following arguments may point to a zero memory descriptor:
8942 /// - @p src_iter_desc together with @p src_iter_c_desc,
8943 /// - @p weights_peephole_desc,
8944 /// - @p bias_desc,
8945 /// - @p dst_iter_desc together with @p dst_iter_c_desc.
8946 ///
8947 /// This would then indicate that the LSTM forward propagation
8948 /// primitive should not use them and should default to zero values
8949 /// instead.
8950 ///
8951 /// @note
8952 /// All memory descriptors can be initialized with an
8953 /// #dnnl::memory::format_tag::any value of @p format_tag.
8954 ///
8955 /// @param aengine Engine to use.
8956 /// @param aprop_kind Propagation kind. Possible values are
8957 /// #dnnl::prop_kind::forward_training, and
8958 /// #dnnl::prop_kind::forward_inference.
8959 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
8960 /// more info.
8961 /// @param src_layer_desc Memory descriptor for the input vector.
8962 /// @param src_iter_desc Memory descriptor for the input recurrent
8963 /// hidden state vector.
8964 /// @param src_iter_c_desc Memory descriptor for the input recurrent
8965 /// cell state vector.
8966 /// @param weights_layer_desc Memory descriptor for the weights
8967 /// applied to the layer input.
8968 /// @param weights_iter_desc Memory descriptor for the weights applied
8969 /// to the recurrent input.
8970 /// @param weights_peephole_desc Memory descriptor for the weights
8971 /// applied to the cell states (according to the Peephole LSTM
8972 /// formula).
8973 /// @param bias_desc Bias memory descriptor.
8974 /// @param dst_layer_desc Memory descriptor for the output vector.
8975 /// @param dst_iter_desc Memory descriptor for the output recurrent
8976 /// hidden state vector.
8977 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
8978 /// cell state vector.
8979 /// @param attr Primitive attributes to use. Attributes are optional
8980 /// and default to empty attributes.
8981 /// @param allow_empty A flag signifying whether construction is
8982 /// allowed to fail without throwing an exception. In this case an
8983 /// empty object will be produced. This flag is optional and
8984 /// defaults to false.
8985 primitive_desc(const engine &aengine, prop_kind aprop_kind,
8986 rnn_direction direction, const memory::desc &src_layer_desc,
8987 const memory::desc &src_iter_desc,
8988 const memory::desc &src_iter_c_desc,
8989 const memory::desc &weights_layer_desc,
8990 const memory::desc &weights_iter_desc,
8991 const memory::desc &weights_peephole_desc,
8992 const memory::desc &bias_desc,
8993 const memory::desc &dst_layer_desc,
8994 const memory::desc &dst_iter_desc,
8995 const memory::desc &dst_iter_c_desc,
8996 const primitive_attr &attr = default_attr(),
8997 bool allow_empty = false)
8998 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
8999 aprop_kind, algorithm::undef, direction, src_layer_desc,
9000 src_iter_desc, &src_iter_c_desc, nullptr,
9001 weights_layer_desc, weights_iter_desc,
9002 &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
9003 dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f,
9004 0.0f, attr, allow_empty) {}
9005
9006 /// Constructs a primitive descriptor for an LSTM forward propagation
9007 /// primitive.
9008 ///
9009 /// The following arguments may point to a zero memory descriptor:
9010 /// - @p src_iter_desc together with @p src_iter_c_desc,
9011 /// - @p bias_desc,
9012 /// - @p dst_iter_desc together with @p dst_iter_c_desc.
9013 ///
9014 /// This would then indicate that the LSTM forward propagation
9015 /// primitive should not use them and should default to zero values
9016 /// instead.
9017 ///
9018 /// @note
9019 /// All memory descriptors can be initialized with an
9020 /// #dnnl::memory::format_tag::any value of @p format_tag.
9021 ///
9022 /// @param aengine Engine to use.
9023 /// @param aprop_kind Propagation kind. Possible values are
9024 /// #dnnl::prop_kind::forward_training, and
9025 /// #dnnl::prop_kind::forward_inference.
9026 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9027 /// more info.
9028 /// @param src_layer_desc Memory descriptor for the input vector.
9029 /// @param src_iter_desc Memory descriptor for the input recurrent
9030 /// hidden state vector.
9031 /// @param src_iter_c_desc Memory descriptor for the input recurrent
9032 /// cell state vector.
9033 /// @param weights_layer_desc Memory descriptor for the weights
9034 /// applied to the layer input.
9035 /// @param weights_iter_desc Memory descriptor for the weights applied
9036 /// to the recurrent input.
9037 /// @param bias_desc Bias memory descriptor.
9038 /// @param dst_layer_desc Memory descriptor for the output vector.
9039 /// @param dst_iter_desc Memory descriptor for the output recurrent
9040 /// hidden state vector.
9041 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
9042 /// cell state vector.
9043 /// @param attr Primitive attributes to use. Attributes are optional
9044 /// and default to empty attributes.
9045 /// @param allow_empty A flag signifying whether construction is
9046 /// allowed to fail without throwing an exception. In this case an
9047 /// empty object will be produced. This flag is optional and
9048 /// defaults to false.
9049 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9050 rnn_direction direction, const memory::desc &src_layer_desc,
9051 const memory::desc &src_iter_desc,
9052 const memory::desc &src_iter_c_desc,
9053 const memory::desc &weights_layer_desc,
9054 const memory::desc &weights_iter_desc,
9055 const memory::desc &bias_desc,
9056 const memory::desc &dst_layer_desc,
9057 const memory::desc &dst_iter_desc,
9058 const memory::desc &dst_iter_c_desc,
9059 const primitive_attr &attr = default_attr(),
9060 bool allow_empty = false)
9061 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
9062 aprop_kind, algorithm::undef, direction, src_layer_desc,
9063 src_iter_desc, &src_iter_c_desc, nullptr,
9064 weights_layer_desc, weights_iter_desc, nullptr, nullptr,
9065 bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
9066 rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
9067
9068 /// Constructs a primitive descriptor for an LSTM forward propagation
9069 /// primitive from a C API primitive descriptor that must have a
9070 /// matching kind.
9071 ///
9072 /// @param pd C API primitive descriptor for an LSTM forward
9073 /// propagation primitive.
9074 primitive_desc(dnnl_primitive_desc_t pd)
9075 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
9076 dnnl::prop_kind::forward_inference,
9077 dnnl::algorithm::vanilla_lstm) {}
9078
9079 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
9080 memory::desc src_layer_desc() const {
9081 return rnn_base::src_layer_desc();
9082 }
9083
9084 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9085 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
9086
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
9088 memory::desc src_iter_c_desc() const {
9089 return rnn_base::src_iter_c_desc();
9090 }
9091
9092 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
9093 memory::desc weights_layer_desc() const {
9094 return rnn_base::weights_layer_desc();
9095 }
9096
9097 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
9098 memory::desc weights_iter_desc() const {
9099 return rnn_base::weights_iter_desc();
9100 }
9101
9102 /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
9103 memory::desc weights_peephole_desc() const {
9104 return rnn_base::weights_peephole_desc();
9105 }
9106
9107 /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
9108 memory::desc weights_projection_desc() const {
9109 return rnn_base::weights_projection_desc();
9110 }
9111
9112 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
9113 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
9114
9115 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
9116 memory::desc dst_layer_desc() const {
9117 return rnn_base::dst_layer_desc();
9118 }
9119
9120 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
9121 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
9122
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
9124 memory::desc dst_iter_c_desc() const {
9125 return rnn_base::dst_iter_c_desc();
9126 }
9127
9128 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
9129 memory::desc workspace_desc() const {
9130 return rnn_base::workspace_desc();
9131 }
9132
9133 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
9134 algorithm get_cell_kind() const { return base::get_cell_kind(); }
9135
9136 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
9137 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
9138
9139 /// @copydoc dnnl::primitive_desc_base::get_direction()const
9140 rnn_direction get_direction() const { return base::get_direction(); }
9141 };
9142
9143 /// Default constructor. Produces an empty object.
9144 lstm_forward() = default;
9145
9146 /// Constructs an LSTM forward propagation primitive.
9147 /// @param pd Primitive descriptor for an LSTM forward propagation
9148 /// primitive.
9149 lstm_forward(const primitive_desc &pd) : primitive(pd) {}
9150
9151 /// Constructs an LSTM forward propagation primitive from a cache blob.
9152 /// @param pd Primitive descriptor for an LSTM forward propagation
9153 /// primitive.
9154 /// @param cache_blob Cache blob.
9155 lstm_forward(
9156 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
9157 : primitive(pd, cache_blob) {}
9158};
9159
9160/// LSTM backward propagation primitive.
9161struct lstm_backward : public primitive {
9162 /// Primitive descriptor for an LSTM backward propagation primitive.
9163 struct primitive_desc : public rnn_primitive_desc_base {
9164 /// Default constructor. Produces an empty object.
9165 primitive_desc() = default;
9166
9167 /// Constructs an LSTM (with or without peephole and with or without
9168 /// projection) primitive descriptor for backward propagation
9169 /// using @p prop_kind, @p direction, and memory descriptors.
9170 ///
9171 /// The following arguments may point to a zero memory descriptor:
9172 /// - @p src_iter_desc together with @p src_iter_c_desc,
9173 /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
9174 /// - @p weights_peephole_desc together with
9175 /// @p diff_weights_peephole_desc
9176 /// - @p bias_desc together with @p diff_bias_desc,
9177 /// - @p dst_iter_desc together with @p dst_iter_c_desc,
9178 /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
9179 ///
9180 /// This would then indicate that the LSTM backward propagation
9181 /// primitive should not use them and should default to zero values
9182 /// instead.
9183 ///
9184 /// The @p weights_projection_desc together with @p
9185 /// diff_weights_projection_desc may point to a zero memory descriptor.
9186 /// This would then indicate that the LSTM doesn't have recurrent
9187 /// projection layer.
9188 ///
9189 /// @note
9190 /// All memory descriptors can be initialized with
9191 /// #dnnl::memory::format_tag::any value of @p format_tag.
9192 ///
9193 /// @param aengine Engine to use.
9194 /// @param aprop_kind Propagation kind. Must be
9195 /// #dnnl::prop_kind::backward.
9196 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9197 /// more info.
9198 /// @param src_layer_desc Memory descriptor for the input vector.
9199 /// @param src_iter_desc Memory descriptor for the input recurrent
9200 /// hidden state vector.
9201 /// @param src_iter_c_desc Memory descriptor for the input recurrent
9202 /// cell state vector.
9203 /// @param weights_layer_desc Memory descriptor for the weights
9204 /// applied to the layer input.
9205 /// @param weights_iter_desc Memory descriptor for the weights applied
9206 /// to the recurrent input.
9207 /// @param weights_peephole_desc Memory descriptor for the weights
9208 /// applied to the cell states (according to the Peephole LSTM
9209 /// formula).
9210 /// @param weights_projection_desc Memory descriptor for the weights
9211 /// applied to the hidden states to get the recurrent projection
9212 /// (according to the Projection LSTM formula).
9213 /// @param bias_desc Bias memory descriptor.
9214 /// @param dst_layer_desc Memory descriptor for the output vector.
9215 /// @param dst_iter_desc Memory descriptor for the output recurrent
9216 /// hidden state vector.
9217 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
9218 /// cell state vector.
9219 /// @param diff_src_layer_desc Memory descriptor for the diff of input
9220 /// vector.
9221 /// @param diff_src_iter_desc Memory descriptor for the diff of input
9222 /// recurrent hidden state vector.
9223 /// @param diff_src_iter_c_desc Memory descriptor for the diff of
9224 /// input recurrent cell state vector.
9225 /// @param diff_weights_layer_desc Memory descriptor for the diff of
9226 /// weights applied to the layer input.
9227 /// @param diff_weights_iter_desc Memory descriptor for the diff of
9228 /// weights applied to the recurrent input.
9229 /// @param diff_weights_peephole_desc Memory descriptor for the diff of
9230 /// weights applied to the cell states (according to the Peephole
9231 /// LSTM formula).
9232 /// @param diff_weights_projection_desc Memory descriptor for the diff
9233 /// of weights applied to the hidden states to get the recurrent
9234 /// projection (according to the Projection LSTM formula).
9235 /// @param diff_bias_desc Diff bias memory descriptor.
9236 /// @param diff_dst_layer_desc Memory descriptor for the diff of
9237 /// output vector.
9238 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
9239 /// recurrent hidden state vector.
9240 /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
9241 /// output recurrent cell state vector.
9242 /// @param hint_fwd_pd Primitive descriptor for an LSTM
9243 /// forward propagation primitive. It is used as a hint for
9244 /// deciding which memory format to use.
9245 /// @param attr Primitive attributes to use. Attributes are optional
9246 /// and default to empty attributes.
9247 /// @param allow_empty A flag signifying whether construction is
9248 /// allowed to fail without throwing an exception. In this case an
9249 /// empty object will be produced. This flag is optional and
9250 /// defaults to false.
9251 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9252 rnn_direction direction, const memory::desc &src_layer_desc,
9253 const memory::desc &src_iter_desc,
9254 const memory::desc &src_iter_c_desc,
9255 const memory::desc &weights_layer_desc,
9256 const memory::desc &weights_iter_desc,
9257 const memory::desc &weights_peephole_desc,
9258 const memory::desc &weights_projection_desc,
9259 const memory::desc &bias_desc,
9260 const memory::desc &dst_layer_desc,
9261 const memory::desc &dst_iter_desc,
9262 const memory::desc &dst_iter_c_desc,
9263 const memory::desc &diff_src_layer_desc,
9264 const memory::desc &diff_src_iter_desc,
9265 const memory::desc &diff_src_iter_c_desc,
9266 const memory::desc &diff_weights_layer_desc,
9267 const memory::desc &diff_weights_iter_desc,
9268 const memory::desc &diff_weights_peephole_desc,
9269 const memory::desc &diff_weights_projection_desc,
9270 const memory::desc &diff_bias_desc,
9271 const memory::desc &diff_dst_layer_desc,
9272 const memory::desc &diff_dst_iter_desc,
9273 const memory::desc &diff_dst_iter_c_desc,
9274 const lstm_forward::primitive_desc &hint_fwd_pd,
9275 const primitive_attr &attr = default_attr(),
9276 bool allow_empty = false)
9277 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
9278 aprop_kind, algorithm::undef, direction, src_layer_desc,
9279 src_iter_desc, &src_iter_c_desc, nullptr,
9280 weights_layer_desc, weights_iter_desc,
9281 &weights_peephole_desc, &weights_projection_desc, bias_desc,
9282 dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
9283 diff_src_layer_desc, diff_src_iter_desc,
9284 &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
9285 diff_weights_iter_desc, &diff_weights_peephole_desc,
9286 &diff_weights_projection_desc, diff_bias_desc,
9287 diff_dst_layer_desc, diff_dst_iter_desc,
9288 &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
9289 hint_fwd_pd, attr, allow_empty) {}
9290
9291 /// Constructs an LSTM (with or without peephole) primitive descriptor
9292 /// for backward propagation using @p prop_kind, @p direction,
9293 /// and memory descriptors.
9294 ///
9295 /// The following arguments may point to a zero memory descriptor:
9296 /// - @p src_iter_desc together with @p src_iter_c_desc,
9297 /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
9298 /// - @p weights_peephole_desc together with
9299 /// @p diff_weights_peephole_desc
9300 /// - @p bias_desc together with @p diff_bias_desc,
9301 /// - @p dst_iter_desc together with @p dst_iter_c_desc,
9302 /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
9303 ///
9304 /// This would then indicate that the LSTM backward propagation
9305 /// primitive should not use them and should default to zero values
9306 /// instead.
9307 ///
9308 /// @note
9309 /// All memory descriptors may be initialized with
9310 /// #dnnl::memory::format_tag::any value of @p format_tag.
9311 ///
9312 /// @param aengine Engine to use.
9313 /// @param aprop_kind Propagation kind. Must be
9314 /// #dnnl::prop_kind::backward.
9315 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9316 /// more info.
9317 /// @param src_layer_desc Memory descriptor for the input vector.
9318 /// @param src_iter_desc Memory descriptor for the input recurrent
9319 /// hidden state vector.
9320 /// @param src_iter_c_desc Memory descriptor for the input recurrent
9321 /// cell state vector.
9322 /// @param weights_layer_desc Memory descriptor for the weights
9323 /// applied to the layer input.
9324 /// @param weights_iter_desc Memory descriptor for the weights applied
9325 /// to the recurrent input.
9326 /// @param weights_peephole_desc Memory descriptor for the weights
9327 /// applied to the cell states (according to the Peephole LSTM
9328 /// formula).
9329 /// @param bias_desc Bias memory descriptor.
9330 /// @param dst_layer_desc Memory descriptor for the output vector.
9331 /// @param dst_iter_desc Memory descriptor for the output recurrent
9332 /// hidden state vector.
9333 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
9334 /// cell state vector.
9335 /// @param diff_src_layer_desc Memory descriptor for the diff of input
9336 /// vector.
9337 /// @param diff_src_iter_desc Memory descriptor for the diff of input
9338 /// recurrent hidden state vector.
9339 /// @param diff_src_iter_c_desc Memory descriptor for the diff of
9340 /// input recurrent cell state vector.
9341 /// @param diff_weights_layer_desc Memory descriptor for the diff of
9342 /// weights applied to the layer input.
9343 /// @param diff_weights_iter_desc Memory descriptor for the diff of
9344 /// weights applied to the recurrent input.
9345 /// @param diff_weights_peephole_desc Memory descriptor for the diff of
9346 /// weights applied to the cell states (according to the Peephole
9347 /// LSTM formula).
9348 /// @param diff_bias_desc Diff bias memory descriptor.
9349 /// @param diff_dst_layer_desc Memory descriptor for the diff of
9350 /// output vector.
9351 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
9352 /// recurrent hidden state vector.
9353 /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
9354 /// output recurrent cell state vector.
9355 /// @param hint_fwd_pd Primitive descriptor for an LSTM
9356 /// forward propagation primitive. It is used as a hint for
9357 /// deciding which memory format to use.
9358 /// @param attr Primitive attributes to use. Attributes are optional
9359 /// and default to empty attributes.
9360 /// @param allow_empty A flag signifying whether construction is
9361 /// allowed to fail without throwing an exception. In this case an
9362 /// empty object will be produced. This flag is optional and
9363 /// defaults to false.
9364 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9365 rnn_direction direction, const memory::desc &src_layer_desc,
9366 const memory::desc &src_iter_desc,
9367 const memory::desc &src_iter_c_desc,
9368 const memory::desc &weights_layer_desc,
9369 const memory::desc &weights_iter_desc,
9370 const memory::desc &weights_peephole_desc,
9371 const memory::desc &bias_desc,
9372 const memory::desc &dst_layer_desc,
9373 const memory::desc &dst_iter_desc,
9374 const memory::desc &dst_iter_c_desc,
9375 const memory::desc &diff_src_layer_desc,
9376 const memory::desc &diff_src_iter_desc,
9377 const memory::desc &diff_src_iter_c_desc,
9378 const memory::desc &diff_weights_layer_desc,
9379 const memory::desc &diff_weights_iter_desc,
9380 const memory::desc &diff_weights_peephole_desc,
9381 const memory::desc &diff_bias_desc,
9382 const memory::desc &diff_dst_layer_desc,
9383 const memory::desc &diff_dst_iter_desc,
9384 const memory::desc &diff_dst_iter_c_desc,
9385 const lstm_forward::primitive_desc &hint_fwd_pd,
9386 const primitive_attr &attr = default_attr(),
9387 bool allow_empty = false)
9388 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
9389 aprop_kind, algorithm::undef, direction, src_layer_desc,
9390 src_iter_desc, &src_iter_c_desc, nullptr,
9391 weights_layer_desc, weights_iter_desc,
9392 &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
9393 dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc,
9394 diff_src_iter_desc, &diff_src_iter_c_desc, nullptr,
9395 diff_weights_layer_desc, diff_weights_iter_desc,
9396 &diff_weights_peephole_desc, nullptr, diff_bias_desc,
9397 diff_dst_layer_desc, diff_dst_iter_desc,
9398 &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
9399 hint_fwd_pd, attr, allow_empty) {}
9400
9401 /// Constructs an LSTM primitive descriptor for backward propagation
9402 /// using @p prop_kind, @p direction, and memory descriptors.
9403 ///
9404 /// The following arguments may point to a zero memory descriptor:
9405 /// - @p src_iter_desc together with @p src_iter_c_desc,
9406 /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
9407 /// - @p bias_desc together with @p diff_bias_desc,
9408 /// - @p dst_iter_desc together with @p dst_iter_c_desc,
9409 /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
9410 ///
9411 /// This would then indicate that the LSTM backward propagation
9412 /// primitive should not use them and should default to zero values
9413 /// instead.
9414 ///
9415 /// @note
9416 /// All memory descriptors may be initialized with
9417 /// #dnnl::memory::format_tag::any value of @p format_tag.
9418 ///
9419 /// @param aengine Engine to use.
9420 /// @param aprop_kind Propagation kind. Must be
9421 /// #dnnl::prop_kind::backward.
9422 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9423 /// more info.
9424 /// @param src_layer_desc Memory descriptor for the input vector.
9425 /// @param src_iter_desc Memory descriptor for the input recurrent
9426 /// hidden state vector.
9427 /// @param src_iter_c_desc Memory descriptor for the input recurrent
9428 /// cell state vector.
9429 /// @param weights_layer_desc Memory descriptor for the weights
9430 /// applied to the layer input.
9431 /// @param weights_iter_desc Memory descriptor for the weights applied
9432 /// to the recurrent input.
9433 /// @param bias_desc Bias memory descriptor.
9434 /// @param dst_layer_desc Memory descriptor for the output vector.
9435 /// @param dst_iter_desc Memory descriptor for the output recurrent
9436 /// hidden state vector.
9437 /// @param dst_iter_c_desc Memory descriptor for the output recurrent
9438 /// cell state vector.
9439 /// @param diff_src_layer_desc Memory descriptor for the diff of input
9440 /// vector.
9441 /// @param diff_src_iter_desc Memory descriptor for the diff of input
9442 /// recurrent hidden state vector.
9443 /// @param diff_src_iter_c_desc Memory descriptor for the diff of
9444 /// input recurrent cell state vector.
9445 /// @param diff_weights_layer_desc Memory descriptor for the diff of
9446 /// weights applied to the layer input.
9447 /// @param diff_weights_iter_desc Memory descriptor for the diff of
9448 /// weights applied to the recurrent input.
9449 /// @param diff_bias_desc Diff bias memory descriptor.
9450 /// @param diff_dst_layer_desc Memory descriptor for the diff of
9451 /// output vector.
9452 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
9453 /// recurrent hidden state vector.
9454 /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
9455 /// output recurrent cell state vector.
9456 /// @param hint_fwd_pd Primitive descriptor for a convolution
9457 /// forward propagation primitive. It is used as a hint for
9458 /// deciding which memory format to use.
9459 /// @param attr Primitive attributes to use. Attributes are optional
9460 /// and default to empty attributes.
9461 /// @param allow_empty A flag signifying whether construction is
9462 /// allowed to fail without throwing an exception. In this case an
9463 /// empty object will be produced. This flag is optional and
9464 /// defaults to false.
9465 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9466 rnn_direction direction, const memory::desc &src_layer_desc,
9467 const memory::desc &src_iter_desc,
9468 const memory::desc &src_iter_c_desc,
9469 const memory::desc &weights_layer_desc,
9470 const memory::desc &weights_iter_desc,
9471 const memory::desc &bias_desc,
9472 const memory::desc &dst_layer_desc,
9473 const memory::desc &dst_iter_desc,
9474 const memory::desc &dst_iter_c_desc,
9475 const memory::desc &diff_src_layer_desc,
9476 const memory::desc &diff_src_iter_desc,
9477 const memory::desc &diff_src_iter_c_desc,
9478 const memory::desc &diff_weights_layer_desc,
9479 const memory::desc &diff_weights_iter_desc,
9480 const memory::desc &diff_bias_desc,
9481 const memory::desc &diff_dst_layer_desc,
9482 const memory::desc &diff_dst_iter_desc,
9483 const memory::desc &diff_dst_iter_c_desc,
9484 const lstm_forward::primitive_desc &hint_fwd_pd,
9485 const primitive_attr &attr = default_attr(),
9486 bool allow_empty = false)
9487 : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
9488 aprop_kind, algorithm::undef, direction, src_layer_desc,
9489 src_iter_desc, &src_iter_c_desc, nullptr,
9490 weights_layer_desc, weights_iter_desc, nullptr, nullptr,
9491 bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
9492 diff_src_layer_desc, diff_src_iter_desc,
9493 &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
9494 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
9495 diff_dst_layer_desc, diff_dst_iter_desc,
9496 &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
9497 hint_fwd_pd, attr, allow_empty) {}
9498
9499 /// Constructs a primitive descriptor for an LSTM backward propagation
9500 /// primitive from a C API primitive descriptor that must have a
9501 /// matching kind.
9502 ///
9503 /// @param pd C API primitive descriptor for an LSTM backward
9504 /// propagation primitive.
9505 primitive_desc(dnnl_primitive_desc_t pd)
9506 : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
9507 dnnl::algorithm::vanilla_lstm) {}
9508
9509 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
9510 memory::desc src_layer_desc() const {
9511 return rnn_base::src_layer_desc();
9512 }
9513
9514 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9515 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
9516
9517 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9518 memory::desc src_iter_c_desc() const {
9519 return rnn_base::src_iter_c_desc();
9520 }
9521
9522 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
9523 memory::desc weights_layer_desc() const {
9524 return rnn_base::weights_layer_desc();
9525 }
9526
9527 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
9528 memory::desc weights_iter_desc() const {
9529 return rnn_base::weights_iter_desc();
9530 }
9531
9532 /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
9533 memory::desc weights_peephole_desc() const {
9534 return rnn_base::weights_peephole_desc();
9535 }
9536
9537 /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
9538 memory::desc weights_projection_desc() const {
9539 return rnn_base::weights_projection_desc();
9540 }
9541
9542 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
9543 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
9544
9545 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
9546 memory::desc dst_layer_desc() const {
9547 return rnn_base::dst_layer_desc();
9548 }
9549
9550 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
9551 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
9552
9553 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9554 memory::desc dst_iter_c_desc() const {
9555 return rnn_base::dst_iter_c_desc();
9556 }
9557
9558 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
9559 memory::desc workspace_desc() const {
9560 return rnn_base::workspace_desc();
9561 }
9562
9563 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
9564 memory::desc diff_src_layer_desc() const {
9565 return rnn_base::diff_src_layer_desc();
9566 }
9567
9568 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
9569 memory::desc diff_src_iter_desc() const {
9570 return rnn_base::diff_src_iter_desc();
9571 }
9572
9573 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
9574 memory::desc diff_src_iter_c_desc() const {
9575 return rnn_base::diff_src_iter_c_desc();
9576 }
9577
9578 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
9579 memory::desc diff_weights_layer_desc() const {
9580 return rnn_base::diff_weights_layer_desc();
9581 }
9582
9583 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
9584 memory::desc diff_weights_iter_desc() const {
9585 return rnn_base::diff_weights_iter_desc();
9586 }
9587
9588 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
9589 memory::desc diff_weights_peephole_desc() const {
9590 return rnn_base::diff_weights_peephole_desc();
9591 }
9592
9593 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
9594 memory::desc diff_weights_projection_desc() const {
9595 return rnn_base::diff_weights_projection_desc();
9596 }
9597
9598 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
9599 memory::desc diff_bias_desc() const {
9600 return rnn_base::diff_bias_desc();
9601 }
9602
9603 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
9604 memory::desc diff_dst_layer_desc() const {
9605 return rnn_base::diff_dst_layer_desc();
9606 }
9607
9608 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
9609 memory::desc diff_dst_iter_desc() const {
9610 return rnn_base::diff_dst_iter_desc();
9611 }
9612
9613 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
9614 memory::desc diff_dst_iter_c_desc() const {
9615 return rnn_base::diff_dst_iter_c_desc();
9616 }
9617
9618 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
9619 algorithm get_cell_kind() const { return base::get_cell_kind(); }
9620
9621 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
9622 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
9623
9624 /// @copydoc dnnl::primitive_desc_base::get_direction()const
9625 rnn_direction get_direction() const { return base::get_direction(); }
9626 };
9627
9628 /// Default constructor. Produces an empty object.
9629 lstm_backward() = default;
9630
9631 /// Constructs an LSTM backward propagation primitive.
9632 /// @param pd Primitive descriptor for an LSTM backward propagation
9633 /// primitive.
9634 lstm_backward(const primitive_desc &pd) : primitive(pd) {}
9635
9636 /// Constructs an LSTM backward propagation primitive from a cache blob.
9637 /// @param pd Primitive descriptor for an LSTM backward propagation
9638 /// primitive.
9639 /// @param cache_blob Cache blob.
9640 lstm_backward(
9641 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
9642 : primitive(pd, cache_blob) {}
9643};
9644
9645/// GRU forward propagation primitive.
9646struct gru_forward : public primitive {
9647 /// Primitive descriptor for a GRU forward propagation primitive.
9648 struct primitive_desc : public rnn_primitive_desc_base {
9649 /// Default constructor. Produces an empty object.
9650 primitive_desc() = default;
9651
9652 /// Constructs a primitive descriptor for a GRU forward propagation
9653 /// primitive.
9654 ///
9655 /// The following arguments may point to a zero memory descriptor:
9656 /// - @p src_iter_desc,
9657 /// - @p bias_desc,
9658 /// - @p dst_iter_desc.
9659 ///
9660 /// This would then indicate that the GRU forward propagation primitive
9661 /// should not use them and should default to zero values instead.
9662 ///
9663 /// @note
9664 /// All memory descriptors except @p src_iter_desc may be
9665 /// initialized with an #dnnl::memory::format_tag::any value of @p
9666 /// format_tag.
9667 ///
9668 /// @param aengine Engine to use.
9669 /// @param aprop_kind Propagation kind. Possible values are
9670 /// #dnnl::prop_kind::forward_training, and
9671 /// #dnnl::prop_kind::forward_inference.
9672 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9673 /// more info.
9674 /// @param src_layer_desc Memory descriptor for the input vector.
9675 /// @param src_iter_desc Memory descriptor for the input recurrent
9676 /// hidden state vector.
9677 /// @param weights_layer_desc Memory descriptor for the weights
9678 /// applied to the layer input.
9679 /// @param weights_iter_desc Memory descriptor for the weights applied
9680 /// to the recurrent input.
9681 /// @param bias_desc Bias memory descriptor.
9682 /// @param dst_layer_desc Memory descriptor for the output vector.
9683 /// @param dst_iter_desc Memory descriptor for the output recurrent
9684 /// hidden state vector.
9685 /// @param attr Primitive attributes to use. Attributes are optional
9686 /// and default to empty attributes.
9687 /// @param allow_empty A flag signifying whether construction is
9688 /// allowed to fail without throwing an exception. In this case an
9689 /// empty object will be produced. This flag is optional and
9690 /// defaults to false.
9691 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9692 rnn_direction direction, const memory::desc &src_layer_desc,
9693 const memory::desc &src_iter_desc,
9694 const memory::desc &weights_layer_desc,
9695 const memory::desc &weights_iter_desc,
9696 const memory::desc &bias_desc,
9697 const memory::desc &dst_layer_desc,
9698 const memory::desc &dst_iter_desc,
9699 const primitive_attr &attr = default_attr(),
9700 bool allow_empty = false)
9701 : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
9702 aprop_kind, algorithm::undef, direction, src_layer_desc,
9703 src_iter_desc, nullptr, nullptr, weights_layer_desc,
9704 weights_iter_desc, nullptr, nullptr, bias_desc,
9705 dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
9706 0.0f, 0.0f, attr, allow_empty) {}
9707
9708 /// Constructs a primitive descriptor for a GRU forward propagation
9709 /// primitive from a C API primitive descriptor that must have a
9710 /// matching kind.
9711 ///
9712 /// @param pd C API primitive descriptor for a GRU forward
9713 /// propagation primitive.
9714 primitive_desc(dnnl_primitive_desc_t pd)
9715 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
9716 dnnl::prop_kind::forward_inference,
9717 dnnl::algorithm::vanilla_gru) {}
9718
9719 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
9720 memory::desc src_layer_desc() const {
9721 return rnn_base::src_layer_desc();
9722 }
9723
9724 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9725 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
9726
9727 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
9728 memory::desc weights_layer_desc() const {
9729 return rnn_base::weights_layer_desc();
9730 }
9731
9732 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
9733 memory::desc weights_iter_desc() const {
9734 return rnn_base::weights_iter_desc();
9735 }
9736
9737 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
9738 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
9739
9740 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
9741 memory::desc dst_layer_desc() const {
9742 return rnn_base::dst_layer_desc();
9743 }
9744
9745 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
9746 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
9747
9748 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
9749 memory::desc workspace_desc() const {
9750 return rnn_base::workspace_desc();
9751 }
9752
9753 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
9754 algorithm get_cell_kind() const { return base::get_cell_kind(); }
9755
9756 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
9757 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
9758
9759 /// @copydoc dnnl::primitive_desc_base::get_direction()const
9760 rnn_direction get_direction() const { return base::get_direction(); }
9761 };
9762
9763 /// Default constructor. Produces an empty object.
9764 gru_forward() = default;
9765
9766 /// Constructs a GRU forward propagation primitive.
9767 /// @param pd Primitive descriptor for a GRU forward propagation
9768 /// primitive.
9769 gru_forward(const primitive_desc &pd) : primitive(pd) {}
9770
9771 /// Constructs a GRU forward propagation primitive from a cache blob.
9772 /// @param pd Primitive descriptor for a GRU forward propagation
9773 /// primitive.
9774 /// @param cache_blob Cache blob.
9775 gru_forward(
9776 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
9777 : primitive(pd, cache_blob) {}
9778};
9779
9780/// GRU backward propagation primitive.
9781struct gru_backward : public primitive {
9782 /// Primitive descriptor for a GRU backward propagation primitive.
9783 struct primitive_desc : public rnn_primitive_desc_base {
9784 /// Default constructor. Produces an empty object.
9785 primitive_desc() = default;
9786
9787 /// Constructs a primitive descriptor for a GRU backward propagation
9788 /// primitive.
9789 ///
9790 /// The following arguments may point to a zero memory descriptor:
9791 /// - @p src_iter_desc together with @p diff_src_iter_desc,
9792 /// - @p bias_desc together with @p diff_bias_desc,
9793 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
9794 ///
9795 /// This would then indicate that the GRU backward propagation
9796 /// primitive should not use them and should default to zero values
9797 /// instead.
9798 ///
9799 /// @note
9800 /// All memory descriptors may be initialized with
9801 /// #dnnl::memory::format_tag::any value of @p format_tag.
9802 ///
9803 /// @param aengine Engine to use.
9804 /// @param aprop_kind Propagation kind. Must be
9805 /// #dnnl::prop_kind::backward.
9806 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
9807 /// more info.
9808 /// @param src_layer_desc Memory descriptor for the input vector.
9809 /// @param src_iter_desc Memory descriptor for the input recurrent
9810 /// hidden state vector.
9811 /// @param weights_layer_desc Memory descriptor for the weights
9812 /// applied to the layer input.
9813 /// @param weights_iter_desc Memory descriptor for the weights applied
9814 /// to the recurrent input.
9815 /// @param bias_desc Bias memory descriptor.
9816 /// @param dst_layer_desc Memory descriptor for the output vector.
9817 /// @param dst_iter_desc Memory descriptor for the output recurrent
9818 /// hidden state vector.
9819 /// @param diff_src_layer_desc Memory descriptor for the diff of input
9820 /// vector.
9821 /// @param diff_src_iter_desc Memory descriptor for the diff of input
9822 /// recurrent hidden state vector.
9823 /// @param diff_weights_layer_desc Memory descriptor for the diff of
9824 /// weights applied to the layer input.
9825 /// @param diff_weights_iter_desc Memory descriptor for the diff of
9826 /// weights applied to the recurrent input.
9827 /// @param diff_bias_desc Diff bias memory descriptor.
9828 /// @param diff_dst_layer_desc Memory descriptor for the diff of
9829 /// output vector.
9830 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
9831 /// recurrent hidden state vector.
9832 /// @param hint_fwd_pd Primitive descriptor for a GRU
9833 /// forward propagation primitive. It is used as a hint for
9834 /// deciding which memory format to use.
9835 /// @param attr Primitive attributes to use. Attributes are optional
9836 /// and default to empty attributes.
9837 /// @param allow_empty A flag signifying whether construction is
9838 /// allowed to fail without throwing an exception. In this case an
9839 /// empty object will be produced. This flag is optional and
9840 /// defaults to false.
9841 primitive_desc(const engine &aengine, prop_kind aprop_kind,
9842 rnn_direction direction, const memory::desc &src_layer_desc,
9843 const memory::desc &src_iter_desc,
9844 const memory::desc &weights_layer_desc,
9845 const memory::desc &weights_iter_desc,
9846 const memory::desc &bias_desc,
9847 const memory::desc &dst_layer_desc,
9848 const memory::desc &dst_iter_desc,
9849 const memory::desc &diff_src_layer_desc,
9850 const memory::desc &diff_src_iter_desc,
9851 const memory::desc &diff_weights_layer_desc,
9852 const memory::desc &diff_weights_iter_desc,
9853 const memory::desc &diff_bias_desc,
9854 const memory::desc &diff_dst_layer_desc,
9855 const memory::desc &diff_dst_iter_desc,
9856 const gru_forward::primitive_desc &hint_fwd_pd,
9857 const primitive_attr &attr = default_attr(),
9858 bool allow_empty = false)
9859 : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
9860 aprop_kind, algorithm::undef, direction, src_layer_desc,
9861 src_iter_desc, nullptr, nullptr, weights_layer_desc,
9862 weights_iter_desc, nullptr, nullptr, bias_desc,
9863 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
9864 diff_src_iter_desc, nullptr, nullptr,
9865 diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
9866 nullptr, diff_bias_desc, diff_dst_layer_desc,
9867 diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
9868 hint_fwd_pd, attr, allow_empty) {}
9869
9870 /// Constructs a primitive descriptor for a GRU backward propagation
9871 /// primitive from a C API primitive descriptor that must have a
9872 /// matching kind.
9873 ///
9874 /// @param pd C API primitive descriptor for a GRU backward
9875 /// propagation primitive.
9876 primitive_desc(dnnl_primitive_desc_t pd)
9877 : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
9878 dnnl::algorithm::vanilla_gru) {}
9879
9880 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
9881 memory::desc src_layer_desc() const {
9882 return rnn_base::src_layer_desc();
9883 }
9884
9885 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
9886 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
9887
9888 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
9889 memory::desc weights_layer_desc() const {
9890 return rnn_base::weights_layer_desc();
9891 }
9892
9893 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
9894 memory::desc weights_iter_desc() const {
9895 return rnn_base::weights_iter_desc();
9896 }
9897
9898 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
9899 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
9900
9901 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
9902 memory::desc dst_layer_desc() const {
9903 return rnn_base::dst_layer_desc();
9904 }
9905
9906 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
9907 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
9908
9909 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
9910 memory::desc workspace_desc() const {
9911 return rnn_base::workspace_desc();
9912 }
9913
9914 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
9915 memory::desc diff_src_layer_desc() const {
9916 return rnn_base::diff_src_layer_desc();
9917 }
9918
9919 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
9920 memory::desc diff_src_iter_desc() const {
9921 return rnn_base::diff_src_iter_desc();
9922 }
9923
9924 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
9925 memory::desc diff_weights_layer_desc() const {
9926 return rnn_base::diff_weights_layer_desc();
9927 }
9928
9929 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
9930 memory::desc diff_weights_iter_desc() const {
9931 return rnn_base::diff_weights_iter_desc();
9932 }
9933
9934 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
9935 memory::desc diff_bias_desc() const {
9936 return rnn_base::diff_bias_desc();
9937 }
9938
9939 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
9940 memory::desc diff_dst_layer_desc() const {
9941 return rnn_base::diff_dst_layer_desc();
9942 }
9943
9944 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
9945 memory::desc diff_dst_iter_desc() const {
9946 return rnn_base::diff_dst_iter_desc();
9947 }
9948
9949 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
9950 algorithm get_cell_kind() const { return base::get_cell_kind(); }
9951
9952 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
9953 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
9954
9955 /// @copydoc dnnl::primitive_desc_base::get_direction()const
9956 rnn_direction get_direction() const { return base::get_direction(); }
9957 };
9958
9959 /// Default constructor. Produces an empty object.
9960 gru_backward() = default;
9961
9962 /// Constructs a GRU backward propagation primitive.
9963 /// @param pd Primitive descriptor for a GRU backward propagation
9964 /// primitive.
9965 gru_backward(const primitive_desc &pd) : primitive(pd) {}
9966
9967 /// Constructs a GRU backward propagation primitive from a cache blob.
9968 /// @param pd Primitive descriptor for a GRU backward propagation
9969 /// primitive.
9970 /// @param cache_blob Cache blob.
9971 gru_backward(
9972 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
9973 : primitive(pd, cache_blob) {}
9974};
9975
9976/// LBR GRU forward propagation primitive.
9977struct lbr_gru_forward : public primitive {
9978 /// Primitive descriptor for an LBR GRU forward propagation primitive.
9979 struct primitive_desc : public rnn_primitive_desc_base {
9980 /// Default constructor. Produces an empty object.
9981 primitive_desc() = default;
9982
9983 /// Constructs a primitive descriptor for LBR GRU forward propagation
9984 /// primitive.
9985 ///
9986 /// The following arguments may point to a zero memory descriptor:
9987 /// - @p src_iter_desc,
9988 /// - @p bias_desc,
9989 /// - @p dst_iter_desc.
9990 ///
9991 /// This would then indicate that the LBR GRU forward propagation
9992 /// primitive should not use them and should default to zero values
9993 /// instead.
9994 ///
9995 /// @note
9996 /// All memory descriptors except @p src_iter_desc may be
9997 /// initialized with an #dnnl::memory::format_tag::any value of @p
9998 /// format_tag.
9999 ///
10000 /// @param aengine Engine to use.
10001 /// @param aprop_kind Propagation kind. Possible values are
10002 /// #dnnl::prop_kind::forward_training, and
10003 /// #dnnl::prop_kind::forward_inference.
10004 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10005 /// more info.
10006 /// @param src_layer_desc Memory descriptor for the input vector.
10007 /// @param src_iter_desc Memory descriptor for the input recurrent
10008 /// hidden state vector.
10009 /// @param weights_layer_desc Memory descriptor for the weights
10010 /// applied to the layer input.
10011 /// @param weights_iter_desc Memory descriptor for the weights applied
10012 /// to the recurrent input.
10013 /// @param bias_desc Bias memory descriptor.
10014 /// @param dst_layer_desc Memory descriptor for the output vector.
10015 /// @param dst_iter_desc Memory descriptor for the output recurrent
10016 /// hidden state vector.
10017 /// @param attr Primitive attributes to use. Attributes are optional
10018 /// and default to empty attributes.
10019 /// @param allow_empty A flag signifying whether construction is
10020 /// allowed to fail without throwing an exception. In this case an
10021 /// empty object will be produced. This flag is optional and
10022 /// defaults to false.
10023 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10024 rnn_direction direction, const memory::desc &src_layer_desc,
10025 const memory::desc &src_iter_desc,
10026 const memory::desc &weights_layer_desc,
10027 const memory::desc &weights_iter_desc,
10028 const memory::desc &bias_desc,
10029 const memory::desc &dst_layer_desc,
10030 const memory::desc &dst_iter_desc,
10031 const primitive_attr &attr = default_attr(),
10032 bool allow_empty = false)
10033 : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
10034 algorithm::undef, direction, src_layer_desc, src_iter_desc,
10035 nullptr, nullptr, weights_layer_desc, weights_iter_desc,
10036 nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
10037 nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
10038
10039 /// Constructs a primitive descriptor for a LBR GRU forward propagation
10040 /// primitive from a C API primitive descriptor that must have a
10041 /// matching kind.
10042 ///
10043 /// @param pd C API primitive descriptor for a LBR GRU forward
10044 /// propagation primitive.
10045 primitive_desc(dnnl_primitive_desc_t pd)
10046 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
10047 dnnl::prop_kind::forward_inference,
10048 dnnl::algorithm::lbr_gru) {}
10049
10050 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10051 memory::desc src_layer_desc() const {
10052 return rnn_base::src_layer_desc();
10053 }
10054
10055 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10056 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10057
10058 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10059 memory::desc weights_layer_desc() const {
10060 return rnn_base::weights_layer_desc();
10061 }
10062
10063 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10064 memory::desc weights_iter_desc() const {
10065 return rnn_base::weights_iter_desc();
10066 }
10067
10068 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10069 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10070
10071 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10072 memory::desc dst_layer_desc() const {
10073 return rnn_base::dst_layer_desc();
10074 }
10075
10076 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10077 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10078
10079 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10080 memory::desc workspace_desc() const {
10081 return rnn_base::workspace_desc();
10082 }
10083
10084 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10085 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10086
10087 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10088 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10089
10090 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10091 rnn_direction get_direction() const { return base::get_direction(); }
10092 };
10093
10094 /// Default constructor. Produces an empty object.
10095 lbr_gru_forward() = default;
10096
10097 /// Constructs an LBR GRU forward propagation primitive.
10098 /// @param pd Primitive descriptor for an LBR GRU forward propagation
10099 /// primitive.
10100 lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
10101
10102 /// Constructs an LBR GRU forward propagation primitive from a cache blob.
10103 /// @param pd Primitive descriptor for an LBR GRU forward propagation
10104 /// primitive.
10105 /// @param cache_blob Cache blob.
10106 lbr_gru_forward(
10107 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
10108 : primitive(pd, cache_blob) {}
10109};
10110
10111/// LBR GRU backward propagation primitive.
10112struct lbr_gru_backward : public primitive {
10113 /// Primitive descriptor for an LBR GRU backward propagation primitive.
10114 struct primitive_desc : public rnn_primitive_desc_base {
10115 /// Default constructor. Produces an empty object.
10116 primitive_desc() = default;
10117
10118 /// Constructs a primitive descriptor for LBR GRU backward propagation
10119 /// primitive.
10120 ///
10121 /// The following arguments may point to a zero memory descriptor:
10122 /// - @p src_iter_desc together with @p diff_src_iter_desc,
10123 /// - @p bias_desc together with @p diff_bias_desc,
10124 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
10125 ///
10126 /// This would then indicate that the LBR GRU backward propagation
10127 /// primitive should not use them and should default to zero values
10128 /// instead.
10129 ///
10130 /// @note
10131 /// All memory descriptors may be initialized with
10132 /// #dnnl::memory::format_tag::any value of @p format_tag.
10133 ///
10134 /// @param aengine Engine to use.
10135 /// @param aprop_kind Propagation kind. Must be
10136 /// #dnnl::prop_kind::backward.
10137 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10138 /// more info.
10139 /// @param src_layer_desc Memory descriptor for the input vector.
10140 /// @param src_iter_desc Memory descriptor for the input recurrent
10141 /// hidden state vector.
10142 /// @param weights_layer_desc Memory descriptor for the weights
10143 /// applied to the layer input.
10144 /// @param weights_iter_desc Memory descriptor for the weights applied
10145 /// to the recurrent input.
10146 /// @param bias_desc Bias memory descriptor.
10147 /// @param dst_layer_desc Memory descriptor for the output vector.
10148 /// @param dst_iter_desc Memory descriptor for the output recurrent
10149 /// hidden state vector.
10150 /// @param diff_src_layer_desc Memory descriptor for the diff of input
10151 /// vector.
10152 /// @param diff_src_iter_desc Memory descriptor for the diff of input
10153 /// recurrent hidden state vector.
10154 /// @param diff_weights_layer_desc Memory descriptor for the diff of
10155 /// weights applied to the layer input.
10156 /// @param diff_weights_iter_desc Memory descriptor for the diff of
10157 /// weights applied to the recurrent input.
10158 /// @param diff_bias_desc Diff bias memory descriptor.
10159 /// @param diff_dst_layer_desc Memory descriptor for the diff of
10160 /// output vector.
10161 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
10162 /// recurrent hidden state vector.
10163 /// @param hint_fwd_pd Primitive descriptor for an LBR GRU
10164 /// forward propagation primitive. It is used as a hint for
10165 /// deciding which memory format to use.
10166 /// @param attr Primitive attributes to use. Attributes are optional
10167 /// and default to empty attributes.
10168 /// @param allow_empty A flag signifying whether construction is
10169 /// allowed to fail without throwing an exception. In this case an
10170 /// empty object will be produced. This flag is optional and
10171 /// defaults to false.
10172 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10173 rnn_direction direction, const memory::desc &src_layer_desc,
10174 const memory::desc &src_iter_desc,
10175 const memory::desc &weights_layer_desc,
10176 const memory::desc &weights_iter_desc,
10177 const memory::desc &bias_desc,
10178 const memory::desc &dst_layer_desc,
10179 const memory::desc &dst_iter_desc,
10180 const memory::desc &diff_src_layer_desc,
10181 const memory::desc &diff_src_iter_desc,
10182 const memory::desc &diff_weights_layer_desc,
10183 const memory::desc &diff_weights_iter_desc,
10184 const memory::desc &diff_bias_desc,
10185 const memory::desc &diff_dst_layer_desc,
10186 const memory::desc &diff_dst_iter_desc,
10187 const gru_forward::primitive_desc &hint_fwd_pd,
10188 const primitive_attr &attr = default_attr(),
10189 bool allow_empty = false)
10190 : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
10191 algorithm::undef, direction, src_layer_desc, src_iter_desc,
10192 nullptr, nullptr, weights_layer_desc, weights_iter_desc,
10193 nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
10194 nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr,
10195 nullptr, diff_weights_layer_desc, diff_weights_iter_desc,
10196 nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc,
10197 diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
10198 hint_fwd_pd, attr, allow_empty) {}
10199
10200 /// Constructs a primitive descriptor for a LBR GRU backward propagation
10201 /// primitive from a C API primitive descriptor that must have a
10202 /// matching kind.
10203 ///
10204 /// @param pd C API primitive descriptor for a LBR GRU backward
10205 /// propagation primitive.
10206 primitive_desc(dnnl_primitive_desc_t pd)
10207 : rnn_primitive_desc_base(
10208 pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
10209
10210 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10211 memory::desc src_layer_desc() const {
10212 return rnn_base::src_layer_desc();
10213 }
10214
10215 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10216 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10217
10218 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10219 memory::desc weights_layer_desc() const {
10220 return rnn_base::weights_layer_desc();
10221 }
10222
10223 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10224 memory::desc weights_iter_desc() const {
10225 return rnn_base::weights_iter_desc();
10226 }
10227
10228 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10229 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10230
10231 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10232 memory::desc dst_layer_desc() const {
10233 return rnn_base::dst_layer_desc();
10234 }
10235
10236 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10237 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10238
10239 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10240 memory::desc workspace_desc() const {
10241 return rnn_base::workspace_desc();
10242 }
10243
10244 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
10245 memory::desc diff_src_layer_desc() const {
10246 return rnn_base::diff_src_layer_desc();
10247 }
10248
10249 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
10250 memory::desc diff_src_iter_desc() const {
10251 return rnn_base::diff_src_iter_desc();
10252 }
10253
10254 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
10255 memory::desc diff_weights_layer_desc() const {
10256 return rnn_base::diff_weights_layer_desc();
10257 }
10258
10259 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
10260 memory::desc diff_weights_iter_desc() const {
10261 return rnn_base::diff_weights_iter_desc();
10262 }
10263
10264 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
10265 memory::desc diff_bias_desc() const {
10266 return rnn_base::diff_bias_desc();
10267 }
10268
10269 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
10270 memory::desc diff_dst_layer_desc() const {
10271 return rnn_base::diff_dst_layer_desc();
10272 }
10273
10274 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
10275 memory::desc diff_dst_iter_desc() const {
10276 return rnn_base::diff_dst_iter_desc();
10277 }
10278
10279 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10280 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10281
10282 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10283 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10284
10285 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10286 rnn_direction get_direction() const { return base::get_direction(); }
10287 };
10288
10289 /// Default constructor. Produces an empty object.
10290 lbr_gru_backward() = default;
10291
10292 /// Constructs an LBR GRU backward propagation primitive.
10293 /// @param pd Primitive descriptor for an LBR GRU backward propagation
10294 /// primitive.
10295 lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
10296
10297 /// Constructs an LBR GRU backward propagation primitive from a cache blob.
10298 /// @param pd Primitive descriptor for an LBR GRU backward propagation
10299 /// primitive.
10300 /// @param cache_blob Cache blob.
10301 lbr_gru_backward(
10302 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
10303 : primitive(pd, cache_blob) {}
10304};
10305
10306/// AUGRU forward propagation primitive.
10307struct augru_forward : public primitive {
10308 /// Primitive descriptor for an AUGRU forward propagation primitive.
10309 struct primitive_desc : public rnn_primitive_desc_base {
10310 /// Default constructor. Produces an empty object.
10311 primitive_desc() = default;
10312
10313 /// Constructs a primitive descriptor for an AUGRU forward propagation
10314 /// primitive.
10315 ///
10316 /// The following arguments may point to a zero memory descriptor:
10317 /// - @p src_iter_desc,
10318 /// - @p bias_desc,
10319 /// - @p dst_iter_desc.
10320 ///
10321 /// This would then indicate that the AUGRU forward propagation
10322 /// primitive should not use them and should default to zero values
10323 /// instead.
10324 ///
10325 /// @note
10326 /// All memory descriptors except @p src_iter_desc may be
10327 /// initialized with an #dnnl::memory::format_tag::any value of @p
10328 /// format_tag.
10329 ///
10330 /// @param aengine Engine to use.
10331 /// @param aprop_kind Propagation kind. Possible values are
10332 /// #dnnl::prop_kind::forward_training, and
10333 /// #dnnl::prop_kind::forward_inference.
10334 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10335 /// more info.
10336 /// @param src_layer_desc Memory descriptor for the input vector.
10337 /// @param src_iter_desc Memory descriptor for the input recurrent
10338 /// hidden state vector.
10339 /// @param attention_desc Memory descriptor for the attention vector.
10340 /// @param weights_layer_desc Memory descriptor for the weights
10341 /// applied to the layer input.
10342 /// @param weights_iter_desc Memory descriptor for the weights applied
10343 /// to the recurrent input.
10344 /// @param bias_desc Bias memory descriptor.
10345 /// @param dst_layer_desc Memory descriptor for the output vector.
10346 /// @param dst_iter_desc Memory descriptor for the output recurrent
10347 /// hidden state vector.
10348 /// @param attr Primitive attributes to use. Attributes are optional
10349 /// and default to empty attributes.
10350 /// @param allow_empty A flag signifying whether construction is
10351 /// allowed to fail without throwing an exception. In this case an
10352 /// empty object will be produced. This flag is optional and
10353 /// defaults to false.
10354 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10355 rnn_direction direction, const memory::desc &src_layer_desc,
10356 const memory::desc &src_iter_desc,
10357 const memory::desc &attention_desc,
10358 const memory::desc &weights_layer_desc,
10359 const memory::desc &weights_iter_desc,
10360 const memory::desc &bias_desc,
10361 const memory::desc &dst_layer_desc,
10362 const memory::desc &dst_iter_desc,
10363 const primitive_attr &attr = default_attr(),
10364 bool allow_empty = false)
10365 : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
10366 aprop_kind, algorithm::undef, direction, src_layer_desc,
10367 src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
10368 weights_iter_desc, nullptr, nullptr, bias_desc,
10369 dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
10370 0.0f, 0.0f, attr, allow_empty) {}
10371
10372 /// Constructs a primitive descriptor for an AUGRU forward propagation
10373 /// primitive from a C API primitive descriptor that must have a
10374 /// matching kind.
10375 ///
10376 /// @param pd C API primitive descriptor for an AUGRU forward
10377 /// propagation primitive.
10378 primitive_desc(dnnl_primitive_desc_t pd)
10379 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
10380 dnnl::prop_kind::forward_inference,
10381 dnnl::algorithm::vanilla_augru) {}
10382
10383 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10384 memory::desc src_layer_desc() const {
10385 return rnn_base::src_layer_desc();
10386 }
10387
10388 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10389 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10390
10391 /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
10392 memory::desc attention_desc() const {
10393 return rnn_base::augru_attention_desc();
10394 }
10395
10396 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10397 memory::desc weights_layer_desc() const {
10398 return rnn_base::weights_layer_desc();
10399 }
10400
10401 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10402 memory::desc weights_iter_desc() const {
10403 return rnn_base::weights_iter_desc();
10404 }
10405
10406 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10407 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10408
10409 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10410 memory::desc dst_layer_desc() const {
10411 return rnn_base::dst_layer_desc();
10412 }
10413
10414 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10415 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10416
10417 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10418 memory::desc workspace_desc() const {
10419 return rnn_base::workspace_desc();
10420 }
10421
10422 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10423 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10424
10425 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10426 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10427
10428 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10429 rnn_direction get_direction() const { return base::get_direction(); }
10430 };
10431
10432 /// Default constructor. Produces an empty object.
10433 augru_forward() = default;
10434
10435 /// Constructs an AUGRU forward propagation primitive.
10436 /// @param pd Primitive descriptor for an AUGRU forward propagation
10437 /// primitive.
10438 augru_forward(const primitive_desc &pd) : primitive(pd) {}
10439
10440 /// Constructs an AUGRU forward propagation primitive from a cache blob.
10441 /// @param pd Primitive descriptor for an AUGRU forward propagation
10442 /// primitive.
10443 /// @param cache_blob Cache blob.
10444 augru_forward(
10445 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
10446 : primitive(pd, cache_blob) {}
10447};
10448
10449/// AUGRU backward propagation primitive.
10450struct augru_backward : public primitive {
/// Primitive descriptor for an AUGRU backward propagation primitive.
10453 struct primitive_desc : public rnn_primitive_desc_base {
10454 /// Default constructor. Produces an empty object.
10455 primitive_desc() = default;
10456
10457 /// Constructs a primitive descriptor for an AUGRU backward propagation
10458 /// primitive.
10459 ///
10460 /// The following arguments may point to a zero memory descriptor:
10461 /// - @p src_iter_desc together with @p diff_src_iter_desc,
10462 /// - @p bias_desc together with @p diff_bias_desc,
10463 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
10464 ///
10465 /// This would then indicate that the AUGRU backward propagation
10466 /// primitive should not use them and should default to zero values
10467 /// instead.
10468 ///
10469 /// @note
10470 /// All memory descriptors may be initialized with
10471 /// #dnnl::memory::format_tag::any value of @p format_tag.
10472 ///
10473 /// @param aengine Engine to use.
10474 /// @param aprop_kind Propagation kind. Must be
10475 /// #dnnl::prop_kind::backward.
10476 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10477 /// more info.
10478 /// @param src_layer_desc Memory descriptor for the input vector.
10479 /// @param src_iter_desc Memory descriptor for the input recurrent
10480 /// hidden state vector.
10481 /// @param attention_desc Memory descriptor for the attention vector.
10482 /// @param weights_layer_desc Memory descriptor for the weights
10483 /// applied to the layer input.
10484 /// @param weights_iter_desc Memory descriptor for the weights applied
10485 /// to the recurrent input.
10486 /// @param bias_desc Bias memory descriptor.
10487 /// @param dst_layer_desc Memory descriptor for the output vector.
10488 /// @param dst_iter_desc Memory descriptor for the output recurrent
10489 /// hidden state vector.
10490 /// @param diff_src_layer_desc Memory descriptor for the diff of input
10491 /// vector.
10492 /// @param diff_src_iter_desc Memory descriptor for the diff of input
10493 /// recurrent hidden state vector.
10494 /// @param diff_attention_desc Memory descriptor for the diff of
10495 /// attention vector.
10496 /// @param diff_weights_layer_desc Memory descriptor for the diff of
10497 /// weights applied to the layer input.
10498 /// @param diff_weights_iter_desc Memory descriptor for the diff of
10499 /// weights applied to the recurrent input.
10500 /// @param diff_bias_desc Diff bias memory descriptor.
10501 /// @param diff_dst_layer_desc Memory descriptor for the diff of
10502 /// output vector.
10503 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
10504 /// recurrent hidden state vector.
10505 /// @param hint_fwd_pd Primitive descriptor for an AUGRU
10506 /// forward propagation primitive. It is used as a hint for
10507 /// deciding which memory format to use.
10508 /// @param attr Primitive attributes to use. Attributes are optional
10509 /// and default to empty attributes.
10510 /// @param allow_empty A flag signifying whether construction is
10511 /// allowed to fail without throwing an exception. In this case an
10512 /// empty object will be produced. This flag is optional and
10513 /// defaults to false.
10514 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10515 rnn_direction direction, const memory::desc &src_layer_desc,
10516 const memory::desc &src_iter_desc,
10517 const memory::desc &attention_desc,
10518 const memory::desc &weights_layer_desc,
10519 const memory::desc &weights_iter_desc,
10520 const memory::desc &bias_desc,
10521 const memory::desc &dst_layer_desc,
10522 const memory::desc &dst_iter_desc,
10523 const memory::desc &diff_src_layer_desc,
10524 const memory::desc &diff_src_iter_desc,
10525 const memory::desc &diff_attention_desc,
10526 const memory::desc &diff_weights_layer_desc,
10527 const memory::desc &diff_weights_iter_desc,
10528 const memory::desc &diff_bias_desc,
10529 const memory::desc &diff_dst_layer_desc,
10530 const memory::desc &diff_dst_iter_desc,
10531 const gru_forward::primitive_desc &hint_fwd_pd,
10532 const primitive_attr &attr = default_attr(),
10533 bool allow_empty = false)
10534 : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
10535 aprop_kind, algorithm::undef, direction, src_layer_desc,
10536 src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
10537 weights_iter_desc, nullptr, nullptr, bias_desc,
10538 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
10539 diff_src_iter_desc, nullptr, &diff_attention_desc,
10540 diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
10541 nullptr, diff_bias_desc, diff_dst_layer_desc,
10542 diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
10543 hint_fwd_pd, attr, allow_empty) {}
10544
10545 /// Constructs a primitive descriptor for an AUGRU backward propagation
10546 /// primitive from a C API primitive descriptor that must have a
10547 /// matching kind.
10548 ///
10549 /// @param pd C API primitive descriptor for an AUGRU backward
10550 /// propagation primitive.
10551 primitive_desc(dnnl_primitive_desc_t pd)
10552 : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
10553 dnnl::algorithm::vanilla_augru) {}
10554
10555 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10556 memory::desc src_layer_desc() const {
10557 return rnn_base::src_layer_desc();
10558 }
10559
10560 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10561 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10562
10563 /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
10564 memory::desc attention_desc() const {
10565 return rnn_base::augru_attention_desc();
10566 }
10567
10568 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10569 memory::desc weights_layer_desc() const {
10570 return rnn_base::weights_layer_desc();
10571 }
10572
10573 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10574 memory::desc weights_iter_desc() const {
10575 return rnn_base::weights_iter_desc();
10576 }
10577
10578 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10579 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10580
10581 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10582 memory::desc dst_layer_desc() const {
10583 return rnn_base::dst_layer_desc();
10584 }
10585
10586 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10587 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10588
10589 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10590 memory::desc workspace_desc() const {
10591 return rnn_base::workspace_desc();
10592 }
10593
10594 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
10595 memory::desc diff_src_layer_desc() const {
10596 return rnn_base::diff_src_layer_desc();
10597 }
10598
10599 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
10600 memory::desc diff_src_iter_desc() const {
10601 return rnn_base::diff_src_iter_desc();
10602 }
10603
10604 /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
10605 memory::desc diff_attention_desc() const {
10606 return rnn_base::diff_augru_attention_desc();
10607 }
10608
10609 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
10610 memory::desc diff_weights_layer_desc() const {
10611 return rnn_base::diff_weights_layer_desc();
10612 }
10613
10614 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
10615 memory::desc diff_weights_iter_desc() const {
10616 return rnn_base::diff_weights_iter_desc();
10617 }
10618
10619 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
10620 memory::desc diff_bias_desc() const {
10621 return rnn_base::diff_bias_desc();
10622 }
10623
10624 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
10625 memory::desc diff_dst_layer_desc() const {
10626 return rnn_base::diff_dst_layer_desc();
10627 }
10628
10629 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
10630 memory::desc diff_dst_iter_desc() const {
10631 return rnn_base::diff_dst_iter_desc();
10632 }
10633
10634 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10635 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10636
10637 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10638 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10639
10640 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10641 rnn_direction get_direction() const { return base::get_direction(); }
10642 };
10643
10644 /// Default constructor. Produces an empty object.
10645 augru_backward() = default;
10646
10647 /// Constructs an AUGRU backward propagation primitive.
10648 /// @param pd Primitive descriptor for an AUGRU backward propagation
10649 /// primitive.
10650 augru_backward(const primitive_desc &pd) : primitive(pd) {}
10651
10652 /// Constructs an AUGRU backward propagation primitive from a cache blob.
10653 /// @param pd Primitive descriptor for an AUGRU backward propagation
10654 /// primitive.
10655 /// @param cache_blob Cache blob.
10656 augru_backward(
10657 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
10658 : primitive(pd, cache_blob) {}
10659};
10660
10661/// LBR AUGRU forward propagation primitive.
10662struct lbr_augru_forward : public primitive {
10663 /// Descriptor for an LBR AUGRU forward propagation primitive.
10664
10665 /// Primitive descriptor for an LBR AUGRU forward propagation primitive.
10666 struct primitive_desc : public rnn_primitive_desc_base {
10667 /// Default constructor. Produces an empty object.
10668 primitive_desc() = default;
10669
10670 /// Constructs a primitive descriptor for LBR AUGRU forward propagation
10671 /// primitive.
10672 ///
10673 /// The following arguments may point to a zero memory descriptor:
10674 /// - @p src_iter_desc,
10675 /// - @p bias_desc,
10676 /// - @p dst_iter_desc.
10677 ///
10678 /// This would then indicate that the LBR AUGRU forward propagation
10679 /// primitive should not use them and should default to zero values
10680 /// instead.
10681 ///
10682 /// @note
10683 /// All memory descriptors except @p src_iter_desc may be
10684 /// initialized with an #dnnl::memory::format_tag::any value of @p
10685 /// format_tag.
10686 ///
10687 /// @param aengine Engine to use.
10688 /// @param aprop_kind Propagation kind. Possible values are
10689 /// #dnnl::prop_kind::forward_training, and
10690 /// #dnnl::prop_kind::forward_inference.
10691 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10692 /// more info.
10693 /// @param src_layer_desc Memory descriptor for the input vector.
10694 /// @param src_iter_desc Memory descriptor for the input recurrent
10695 /// hidden state vector.
10696 /// @param attention_desc Memory descriptor for the attention vector.
10697 /// @param weights_layer_desc Memory descriptor for the weights
10698 /// applied to the layer input.
10699 /// @param weights_iter_desc Memory descriptor for the weights applied
10700 /// to the recurrent input.
10701 /// @param bias_desc Bias memory descriptor.
10702 /// @param dst_layer_desc Memory descriptor for the output vector.
10703 /// @param dst_iter_desc Memory descriptor for the output recurrent
10704 /// hidden state vector.
10705 /// @param attr Primitive attributes to use. Attributes are optional
10706 /// and default to empty attributes.
10707 /// @param allow_empty A flag signifying whether construction is
10708 /// allowed to fail without throwing an exception. In this case an
10709 /// empty object will be produced. This flag is optional and
10710 /// defaults to false.
10711 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10712 rnn_direction direction, const memory::desc &src_layer_desc,
10713 const memory::desc &src_iter_desc,
10714 const memory::desc &attention_desc,
10715 const memory::desc &weights_layer_desc,
10716 const memory::desc &weights_iter_desc,
10717 const memory::desc &bias_desc,
10718 const memory::desc &dst_layer_desc,
10719 const memory::desc &dst_iter_desc,
10720 const primitive_attr &attr = default_attr(),
10721 bool allow_empty = false)
10722 : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
10723 algorithm::undef, direction, src_layer_desc, src_iter_desc,
10724 nullptr, &attention_desc, weights_layer_desc,
10725 weights_iter_desc, nullptr, nullptr, bias_desc,
10726 dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
10727 0.0f, 0.0f, attr, allow_empty) {}
10728
10729 /// Constructs a primitive descriptor for an LBR AUGRU forward propagation
10730 /// primitive from a C API primitive descriptor that must have a
10731 /// matching kind.
10732 ///
10733 /// @param pd C API primitive descriptor for an LBR AUGRU forward
10734 /// propagation primitive.
10735 primitive_desc(dnnl_primitive_desc_t pd)
10736 : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
10737 dnnl::prop_kind::forward_inference,
10738 dnnl::algorithm::lbr_augru) {}
10739
10740 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10741 memory::desc src_layer_desc() const {
10742 return rnn_base::src_layer_desc();
10743 }
10744
10745 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10746 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10747
10748 /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
10749 memory::desc attention_desc() const {
10750 return rnn_base::augru_attention_desc();
10751 }
10752
10753 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10754 memory::desc weights_layer_desc() const {
10755 return rnn_base::weights_layer_desc();
10756 }
10757
10758 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10759 memory::desc weights_iter_desc() const {
10760 return rnn_base::weights_iter_desc();
10761 }
10762
10763 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10764 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10765
10766 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10767 memory::desc dst_layer_desc() const {
10768 return rnn_base::dst_layer_desc();
10769 }
10770
10771 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10772 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10773
10774 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10775 memory::desc workspace_desc() const {
10776 return rnn_base::workspace_desc();
10777 }
10778
10779 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10780 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10781
10782 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10783 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10784
10785 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10786 rnn_direction get_direction() const { return base::get_direction(); }
10787 };
10788
10789 /// Default constructor. Produces an empty object.
10790 lbr_augru_forward() = default;
10791
10792 /// Constructs an LBR AUGRU forward propagation primitive.
10793 /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
10794 /// primitive.
10795 lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {}
10796
10797 /// Constructs an LBR AUGRU forward propagation primitive from a cache blob.
10798 /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
10799 /// primitive.
10800 /// @param cache_blob Cache blob.
10801 lbr_augru_forward(
10802 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
10803 : primitive(pd, cache_blob) {}
10804};
10805
10806/// LBR AUGRU backward propagation primitive.
10807struct lbr_augru_backward : public primitive {
10808 /// Primitive descriptor for an LBR AUGRU backward propagation primitive.
10809 struct primitive_desc : public rnn_primitive_desc_base {
10810 /// Default constructor. Produces an empty object.
10811 primitive_desc() = default;
10812
10813 /// Constructs a primitive descriptor for LBR AUGRU backward propagation
10814 /// primitive.
10815 ///
10816 /// The following arguments may point to a zero memory descriptor:
10817 /// - @p src_iter_desc together with @p diff_src_iter_desc,
10818 /// - @p bias_desc together with @p diff_bias_desc,
10819 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
10820 ///
10821 /// This would then indicate that the LBR AUGRU backward propagation
10822 /// primitive should not use them and should default to zero values
10823 /// instead.
10824 ///
10825 /// @note
10826 /// All memory descriptors may be initialized with
10827 /// #dnnl::memory::format_tag::any value of @p format_tag.
10828 ///
10829 /// @param aengine Engine to use.
10830 /// @param aprop_kind Propagation kind. Must be
10831 /// #dnnl::prop_kind::backward.
10832 /// @param direction RNN direction. See @ref dnnl::rnn_direction for
10833 /// more info.
10834 /// @param src_layer_desc Memory descriptor for the input vector.
10835 /// @param src_iter_desc Memory descriptor for the input recurrent
10836 /// hidden state vector.
10837 /// @param attention_desc Memory descriptor for the attention vector.
10838 /// @param weights_layer_desc Memory descriptor for the weights
10839 /// applied to the layer input.
10840 /// @param weights_iter_desc Memory descriptor for the weights applied
10841 /// to the recurrent input.
10842 /// @param bias_desc Bias memory descriptor.
10843 /// @param dst_layer_desc Memory descriptor for the output vector.
10844 /// @param dst_iter_desc Memory descriptor for the output recurrent
10845 /// hidden state vector.
10846 /// @param diff_src_layer_desc Memory descriptor for the diff of input
10847 /// vector.
10848 /// @param diff_src_iter_desc Memory descriptor for the diff of input
10849 /// recurrent hidden state vector.
10850 /// @param diff_attention_desc Memory descriptor for the diff of
10851 /// attention vector.
10852 /// @param diff_weights_layer_desc Memory descriptor for the diff of
10853 /// weights applied to the layer input.
10854 /// @param diff_weights_iter_desc Memory descriptor for the diff of
10855 /// weights applied to the recurrent input.
10856 /// @param diff_bias_desc Diff bias memory descriptor.
10857 /// @param diff_dst_layer_desc Memory descriptor for the diff of
10858 /// output vector.
10859 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
10860 /// recurrent hidden state vector.
10861 /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU
10862 /// forward propagation primitive. It is used as a hint for
10863 /// deciding which memory format to use.
10864 /// @param attr Primitive attributes to use. Attributes are optional
10865 /// and default to empty attributes.
10866 /// @param allow_empty A flag signifying whether construction is
10867 /// allowed to fail without throwing an exception. In this case an
10868 /// empty object will be produced. This flag is optional and
10869 /// defaults to false.
10870 primitive_desc(const engine &aengine, prop_kind aprop_kind,
10871 rnn_direction direction, const memory::desc &src_layer_desc,
10872 const memory::desc &src_iter_desc,
10873 const memory::desc &attention_desc,
10874 const memory::desc &weights_layer_desc,
10875 const memory::desc &weights_iter_desc,
10876 const memory::desc &bias_desc,
10877 const memory::desc &dst_layer_desc,
10878 const memory::desc &dst_iter_desc,
10879 const memory::desc &diff_src_layer_desc,
10880 const memory::desc &diff_src_iter_desc,
10881 const memory::desc &diff_attention_desc,
10882 const memory::desc &diff_weights_layer_desc,
10883 const memory::desc &diff_weights_iter_desc,
10884 const memory::desc &diff_bias_desc,
10885 const memory::desc &diff_dst_layer_desc,
10886 const memory::desc &diff_dst_iter_desc,
10887 const gru_forward::primitive_desc &hint_fwd_pd,
10888 const primitive_attr &attr = default_attr(),
10889 bool allow_empty = false)
10890 : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
10891 algorithm::undef, direction, src_layer_desc, src_iter_desc,
10892 nullptr, &attention_desc, weights_layer_desc,
10893 weights_iter_desc, nullptr, nullptr, bias_desc,
10894 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
10895 diff_src_iter_desc, nullptr, &diff_attention_desc,
10896 diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
10897 nullptr, diff_bias_desc, diff_dst_layer_desc,
10898 diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
10899 hint_fwd_pd, attr, allow_empty) {}
10900
10901 /// Constructs a primitive descriptor for an LBR AUGRU backward
10902 /// propagation primitive from a C API primitive descriptor that must
10903 /// have a matching kind.
10904 ///
10905 /// @param pd C API primitive descriptor for an LBR AUGRU backward
10906 /// propagation primitive.
10907 primitive_desc(dnnl_primitive_desc_t pd)
10908 : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
10909 dnnl::algorithm::lbr_augru) {}
10910
10911 /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
10912 memory::desc src_layer_desc() const {
10913 return rnn_base::src_layer_desc();
10914 }
10915
10916 /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
10917 memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
10918
10919 /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
10920 memory::desc attention_desc() const {
10921 return rnn_base::augru_attention_desc();
10922 }
10923
10924 /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
10925 memory::desc weights_layer_desc() const {
10926 return rnn_base::weights_layer_desc();
10927 }
10928
10929 /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
10930 memory::desc weights_iter_desc() const {
10931 return rnn_base::weights_iter_desc();
10932 }
10933
10934 /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
10935 memory::desc bias_desc() const { return rnn_base::bias_desc(); }
10936
10937 /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
10938 memory::desc dst_layer_desc() const {
10939 return rnn_base::dst_layer_desc();
10940 }
10941
10942 /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
10943 memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
10944
10945 /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
10946 memory::desc workspace_desc() const {
10947 return rnn_base::workspace_desc();
10948 }
10949
10950 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
10951 memory::desc diff_src_layer_desc() const {
10952 return rnn_base::diff_src_layer_desc();
10953 }
10954
10955 /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
10956 memory::desc diff_src_iter_desc() const {
10957 return rnn_base::diff_src_iter_desc();
10958 }
10959
10960 /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
10961 memory::desc diff_attention_desc() const {
10962 return rnn_base::diff_augru_attention_desc();
10963 }
10964
10965 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
10966 memory::desc diff_weights_layer_desc() const {
10967 return rnn_base::diff_weights_layer_desc();
10968 }
10969
10970 /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
10971 memory::desc diff_weights_iter_desc() const {
10972 return rnn_base::diff_weights_iter_desc();
10973 }
10974
10975 /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
10976 memory::desc diff_bias_desc() const {
10977 return rnn_base::diff_bias_desc();
10978 }
10979
10980 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
10981 memory::desc diff_dst_layer_desc() const {
10982 return rnn_base::diff_dst_layer_desc();
10983 }
10984
10985 /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
10986 memory::desc diff_dst_iter_desc() const {
10987 return rnn_base::diff_dst_iter_desc();
10988 }
10989
10990 /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
10991 algorithm get_cell_kind() const { return base::get_cell_kind(); }
10992
10993 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
10994 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
10995
10996 /// @copydoc dnnl::primitive_desc_base::get_direction()const
10997 rnn_direction get_direction() const { return base::get_direction(); }
10998 };
10999
11000 /// Default constructor. Produces an empty object.
11001 lbr_augru_backward() = default;
11002
11003 /// Constructs an LBR AUGRU backward propagation primitive.
11004 /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
11005 /// primitive.
11006 lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {}
11007
11008 /// Constructs an LBR AUGRU backward propagation primitive from a cache blob.
11009 /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
11010 /// primitive.
11011 /// @param cache_blob Cache blob.
11012 lbr_augru_backward(
11013 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11014 : primitive(pd, cache_blob) {}
11015};
11016
11017/// @} dnnl_api_rnn
11018
11019/// @addtogroup dnnl_api_shuffle Shuffle
11020///
11021/// A primitive to shuffle tensor data along an axis.
11022///
11023/// @sa @ref dev_guide_shuffle in developer guide
11024///
11025/// @{
11026
11027/// Shuffle forward propagation primitive.
11028struct shuffle_forward : public primitive {
11029 /// Primitive descriptor for a shuffle forward propagation primitive.
11030 struct primitive_desc : public dnnl::primitive_desc {
11031 /// Default constructor. Produces an empty object.
11032 primitive_desc() = default;
11033
11034 /// Constructs a primitive descriptor for a shuffle forward propagation
11035 /// primitive.
11036 ///
11037 /// @param aengine Engine to use.
11038 /// @param aprop_kind Propagation kind. Possible values are
11039 /// #dnnl::prop_kind::forward_training, and
11040 /// #dnnl::prop_kind::forward_inference.
11041 /// @param src_desc Source memory descriptor.
11042 /// @param dst_desc Destination memory descriptor.
11043 /// @param axis The axis along which the data is shuffled.
11044 /// @param group_size Shuffle group size.
11045 /// @param attr Primitive attributes to use. Attributes are optional
11046 /// and default to empty attributes.
11047 /// @param allow_empty A flag signifying whether construction is
11048 /// allowed to fail without throwing an exception. In this case an
11049 /// empty object will be produced. This flag is optional and
11050 /// defaults to false.
11051 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11052 const memory::desc &src_desc, const memory::desc &dst_desc,
11053 int axis, int group_size,
11054 const primitive_attr &attr = default_attr(),
11055 bool allow_empty = false) {
11056
11057 dnnl_primitive_desc_t pd = nullptr;
11058 dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create(
11059 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
11060 src_desc.get(), dst_desc.get(), axis, group_size,
11061 attr.get());
11062
11063 if (!allow_empty)
11064 error::wrap_c_api(status,
11065 "could not create a primitive descriptor for a shuffle "
11066 "forward propagation primitive");
11067 reset(pd);
11068 }
11069
11070 /// Constructs a primitive descriptor for a shuffle forward propagation
11071 /// primitive from a C API primitive descriptor that must have a
11072 /// matching kind.
11073 ///
11074 /// @param pd C API primitive descriptor for a shuffle forward
11075 /// propagation primitive.
11076 primitive_desc(dnnl_primitive_desc_t pd)
11077 : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
11078 dnnl::prop_kind::forward_training,
11079 dnnl::prop_kind::forward_inference) {}
11080
11081 /// @copydoc dnnl::primitive_desc_base::src_desc()const
11082 memory::desc src_desc() const { return base::src_desc(0); }
11083
11084 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
11085 memory::desc dst_desc() const { return base::dst_desc(0); }
11086
11087 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
11088 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
11089
11090 /// @copydoc dnnl::primitive_desc_base::get_axis()const
11091 int get_axis() const { return base::get_axis(); }
11092
11093 /// @copydoc dnnl::primitive_desc_base::get_group_size()const
11094 memory::dim get_group_size() const { return base::get_group_size(); }
11095 };
11096
11097 /// Default constructor. Produces an empty object.
11098 shuffle_forward() = default;
11099
11100 /// Constructs a shuffle forward propagation primitive.
11101 /// @param pd Primitive descriptor for a shuffle forward propagation
11102 /// primitive.
11103 shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
11104
11105 /// Constructs a shuffle forward propagation primitive from a cache blob.
11106 /// @param pd Primitive descriptor for a shuffle forward propagation
11107 /// primitive.
11108 /// @param cache_blob Cache blob.
11109 shuffle_forward(
11110 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11111 : primitive(pd, cache_blob) {}
11112};
11113
11114/// Shuffle backward propagation primitive.
11115struct shuffle_backward : public primitive {
11116 /// Primitive descriptor for a shuffle backward propagation primitive.
11117 struct primitive_desc : public dnnl::primitive_desc {
11118 /// Default constructor. Produces an empty object.
11119 primitive_desc() = default;
11120
11121 /// Constructs a primitive descriptor for a shuffle backward propagation
11122 /// primitive.
11123 ///
11124 /// @param aengine Engine to use.
11125 /// @param diff_src_desc Diff source memory descriptor.
11126 /// @param diff_dst_desc Diff destination memory descriptor.
11127 /// @param axis The axis along which the data is shuffled.
11128 /// @param group_size Shuffle group size.
11129 /// @param hint_fwd_pd Primitive descriptor for a shuffle forward
11130 /// propagation primitive. It is used as a hint for deciding which
11131 /// memory format to use.
11132 /// @param attr Primitive attributes to use. Attributes are optional
11133 /// and default to empty attributes.
11134 /// @param allow_empty A flag signifying whether construction is
11135 /// allowed to fail without throwing an exception. In this case an
11136 /// empty object will be produced. This flag is optional and
11137 /// defaults to false.
11138 primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
11139 const memory::desc &diff_dst_desc, int axis, int group_size,
11140 const shuffle_forward::primitive_desc &hint_fwd_pd,
11141 const primitive_attr &attr = default_attr(),
11142 bool allow_empty = false) {
11143
11144 dnnl_primitive_desc_t pd = nullptr;
11145 dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create(
11146 &pd, aengine.get(), diff_src_desc.get(),
11147 diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(),
11148 attr.get());
11149
11150 if (!allow_empty)
11151 error::wrap_c_api(status,
11152 "could not create a primitive descriptor for a shuffle "
11153 "backward propagation primitive");
11154 reset(pd);
11155 }
11156
11157 /// Constructs a primitive descriptor for a shuffle backward
11158 /// propagation primitive from a C API primitive descriptor that must
11159 /// have a matching kind.
11160 ///
11161 /// @param pd C API primitive descriptor for a shuffle backward
11162 /// propagation primitive.
11163 primitive_desc(dnnl_primitive_desc_t pd)
11164 : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
11165 dnnl::prop_kind::backward_data) {}
11166
11167 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
11168 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
11169
11170 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
11171 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
11172
11173 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
11174 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
11175
11176 /// @copydoc dnnl::primitive_desc_base::get_axis()const
11177 int get_axis() const { return base::get_axis(); }
11178
11179 /// @copydoc dnnl::primitive_desc_base::get_group_size()const
11180 memory::dim get_group_size() const { return base::get_group_size(); }
11181 };
11182
11183 /// Default constructor. Produces an empty object.
11184 shuffle_backward() = default;
11185
11186 /// Constructs a shuffle backward propagation primitive.
11187 /// @param pd Primitive descriptor for a shuffle backward propagation
11188 /// primitive.
11189 shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
11190
11191 /// Constructs a shuffle backward propagation primitive from a cache blob.
11192 /// @param pd Primitive descriptor for a shuffle backward propagation
11193 /// primitive.
11194 /// @param cache_blob Cache blob.
11195 shuffle_backward(
11196 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11197 : primitive(pd, cache_blob) {}
11198};
11199
11200/// @} dnnl_api_shuffle
11201
11202/// @addtogroup dnnl_api_binary Binary
11203///
11204/// A primitive to perform tensor operations over two tensors.
11205///
11206/// @sa @ref dev_guide_binary in developer guide
11207///
11208/// @{
11209
11210/// Elementwise binary operator primitive.
11211struct binary : public primitive {
11212 /// Primitive descriptor for an elementwise binary operator primitive.
11213 struct primitive_desc : public dnnl::primitive_desc {
11214 /// Default constructor. Produces an empty object.
11215 primitive_desc() = default;
11216
11217 /// Constructs a primitive descriptor for an elementwise binary operator
11218 /// primitive.
11219 ///
11220 /// @param aengine Engine to use.
11221 /// @param aalgorithm Elementwise binary algorithm.
11222 /// @param src0 Memory descriptor for source tensor #0.
11223 /// @param src1 Memory descriptor for source tensor #1.
11224 /// @param dst Memory descriptor for destination tensor.
11225 /// @param attr Primitive attributes to use. Attributes are optional
11226 /// and default to empty attributes.
11227 /// @param allow_empty A flag signifying whether construction is
11228 /// allowed to fail without throwing an exception. In this case an
11229 /// empty object will be produced. This flag is optional and
11230 /// defaults to false.
11231 primitive_desc(const engine &aengine, algorithm aalgorithm,
11232 const memory::desc &src0, const memory::desc &src1,
11233 const memory::desc &dst,
11234 const primitive_attr &attr = default_attr(),
11235 bool allow_empty = false) {
11236
11237 dnnl_primitive_desc_t pd = nullptr;
11238 dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd,
11239 aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
11240 src1.get(), dst.get(), attr.get());
11241
11242 if (!allow_empty)
11243 error::wrap_c_api(status,
11244 "could not create a primitive descriptor for a binary "
11245 "operation primitive");
11246 reset(pd);
11247 }
11248
11249 /// Constructs a primitive descriptor for a binary primitive from a C
11250 /// API primitive descriptor that must have a matching kind.
11251 ///
11252 /// @param pd C API primitive descriptor for a binary primitive.
11253 primitive_desc(dnnl_primitive_desc_t pd)
11254 : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
11255
11256 /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
11257 memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
11258
11259 /// Returns the memory descriptor for source #0.
11260 memory::desc src0_desc() const { return base::src_desc(0); }
11261
11262 /// Returns the memory descriptor for source #1.
11263 memory::desc src1_desc() const { return base::src_desc(1); }
11264
11265 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
11266 memory::desc dst_desc() const { return base::dst_desc(0); }
11267
11268 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
11269 algorithm get_algorithm() const { return base::get_algorithm(); }
11270 };
11271
11272 /// Default constructor. Produces an empty object.
11273 binary() = default;
11274
11275 /// Constructs an elementwise binary operation primitive.
11276 /// @param pd Primitive descriptor for an elementwise binary operation
11277 /// primitive.
11278 binary(const primitive_desc &pd) : primitive(pd) {}
11279
11280 /// Constructs an elementwise binary operation primitive from a cache blob.
11281 /// @param pd Primitive descriptor for an elementwise binary operation
11282 /// primitive.
11283 /// @param cache_blob Cache blob.
11284 binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11285 : primitive(pd, cache_blob) {}
11286};
11287
11288/// @} dnnl_api_binary
11289
11290/// @addtogroup dnnl_api_matmul Matrix Multiplication
11291///
11292/// A primitive to perform matrix-matrix multiplication. The batched mode
11293/// is supported with 3D tensors.
11294///
11295/// @sa @ref dev_guide_matmul in developer guide
11296///
11297///
11298/// @{
11299
11300/// Matrix multiplication (matmul) primitive.
11301struct matmul : public primitive {
11302 /// Primitive descriptor for a matmul primitive.
11303 struct primitive_desc : public dnnl::primitive_desc {
11304 /// Default constructor. Produces an empty object.
11305 primitive_desc() = default;
11306
11307 /// Constructs a primitive descriptor for a matmul primitive
11308 /// without bias.
11309 ///
11310 /// @param aengine Engine to use.
11311 /// @param src_desc Memory descriptor for source (matrix A).
11312 /// @param weights_desc Memory descriptor for weights (matrix B).
11313 /// @param dst_desc Memory descriptor for destination (matrix C).
11314 /// @param attr Primitive attributes to use. Attributes are optional
11315 /// and default to empty attributes.
11316 /// @param allow_empty A flag signifying whether construction is
11317 /// allowed to fail without throwing an exception. In this case an
11318 /// empty object will be produced. This flag is optional and
11319 /// defaults to false.
11320 primitive_desc(const engine &aengine, const memory::desc &src_desc,
11321 const memory::desc &weights_desc, const memory::desc &dst_desc,
11322 const primitive_attr &attr = default_attr(),
11323 bool allow_empty = false)
11324 : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc,
11325 attr, allow_empty) {}
11326
11327 /// Constructs a primitive descriptor for a matmul primitive with bias.
11328 ///
11329 /// @param aengine Engine to use.
11330 /// @param src_desc Memory descriptor for source (matrix A).
11331 /// @param weights_desc Memory descriptor for weights (matrix B).
11332 /// @param dst_desc Memory descriptor for destination (matrix C).
11333 /// @param bias_desc Memory descriptor for bias.
11334 /// @param attr Primitive attributes to use. Attributes are optional
11335 /// and default to empty attributes.
11336 /// @param allow_empty A flag signifying whether construction is
11337 /// allowed to fail without throwing an exception. In this case an
11338 /// empty object will be produced. This flag is optional and
11339 /// defaults to false.
11340 primitive_desc(const engine &aengine, const memory::desc &src_desc,
11341 const memory::desc &weights_desc, const memory::desc &bias_desc,
11342 const memory::desc &dst_desc,
11343 const primitive_attr &attr = default_attr(),
11344 bool allow_empty = false)
11345 : primitive_desc(aengine, src_desc, weights_desc, &bias_desc,
11346 dst_desc, attr, allow_empty) {}
11347
11348 /// Constructs a primitive descriptor for a matmul primitive from a C
11349 /// API primitive descriptor that must have a matching kind.
11350 ///
11351 /// @param pd C API primitive descriptor for a matmul primitive.
11352 primitive_desc(dnnl_primitive_desc_t pd)
11353 : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
11354
11355 /// @copydoc dnnl::primitive_desc_base::src_desc()const
11356 memory::desc src_desc() const { return query_md(query::src_md, 0); }
11357
11358 /// @copydoc dnnl::primitive_desc_base::weights_desc()const
11359 memory::desc weights_desc() const {
11360 return query_md(query::weights_md, 0);
11361 }
11362
11363 /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
11364 memory::desc bias_desc() const {
11365 return query_md(query::weights_md, 1);
11366 }
11367
11368 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
11369 memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
11370
11371 private:
11372 primitive_desc(const engine &aengine, const memory::desc &src_desc,
11373 const memory::desc &weights_desc, const memory::desc *bias_desc,
11374 const memory::desc &dst_desc, const primitive_attr &attr,
11375 bool allow_empty) {
11376
11377 dnnl_primitive_desc_t pd = nullptr;
11378 dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd,
11379 aengine.get(), src_desc.get(), weights_desc.get(),
11380 optional_arg(bias_desc), dst_desc.get(), attr.get());
11381
11382 if (!allow_empty)
11383 error::wrap_c_api(status,
11384 "could not create a primitive descriptor for a matmul "
11385 "primitive");
11386 reset(pd);
11387 }
11388 };
11389
11390 /// Default constructor. Produces an empty object.
11391 matmul() = default;
11392
11393 /// Constructs a matmul primitive.
11394 /// @param pd Primitive descriptor for a matmul primitive.
11395 matmul(const primitive_desc &pd) : primitive(pd) {}
11396
11397 /// Constructs a matmul primitive from a cache blob.
11398 /// @param pd Primitive descriptor for a matmul primitive.
11399 /// @param cache_blob Cache blob.
11400 matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11401 : primitive(pd, cache_blob) {}
11402};
11403
11404/// @} dnnl_api_matmul
11405
11406/// @addtogroup dnnl_api_resampling Resampling
11407///
11408/// A primitive to compute resampling operation on 1D, 2D or 3D data tensor
11409/// using Nearest Neighbor, or Linear (Bilinear, Trilinear) interpolation
11410/// method.
11411///
11412/// @sa @ref dev_guide_resampling in developer guide
11413///
11414/// @{
11415
11416/// Resampling forward propagation.
11417struct resampling_forward : public primitive {
11418 /// Primitive descriptor for a resampling forward propagation primitive.
11419 struct primitive_desc : public dnnl::primitive_desc {
11420 /// Default constructor. Produces an empty object.
11421 primitive_desc() = default;
11422
11423 /// Constructs a primitive descriptor for a resampling forward
11424 /// propagation primitive using source and destination memory
11425 /// descriptors.
11426 ///
11427 /// @note
11428 /// Destination memory descriptor may be initialized with
11429 /// #dnnl::memory::format_tag::any value of @p format_tag.
11430 ///
11431 /// @param aengine Engine to use.
11432 /// @param aprop_kind Propagation kind. Possible values are
11433 /// #dnnl::prop_kind::forward_training, and
11434 /// #dnnl::prop_kind::forward_inference.
11435 /// @param aalgorithm resampling algorithm kind: either
11436 /// #dnnl::algorithm::resampling_nearest, or
11437 /// #dnnl::algorithm::resampling_linear
11438 /// @param src_desc Source memory descriptor.
11439 /// @param dst_desc Destination memory descriptor.
11440 /// @param attr Primitive attributes to use. Attributes are optional
11441 /// and default to empty attributes.
11442 /// @param allow_empty A flag signifying whether construction is
11443 /// allowed to fail without throwing an exception. In this case an
11444 /// empty object will be produced. This flag is optional and
11445 /// defaults to false.
11446 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11447 algorithm aalgorithm, const memory::desc &src_desc,
11448 const memory::desc &dst_desc,
11449 const primitive_attr &attr = default_attr(),
11450 bool allow_empty = false)
11451 : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc,
11452 &dst_desc, attr, allow_empty) {}
11453
11454 /// Constructs a primitive descriptor for a resampling forward
11455 /// propagation primitive using source memory descriptor and
11456 /// factors.
11457 ///
11458 /// @param aengine Engine to use.
11459 /// @param aprop_kind Propagation kind. Possible values are
11460 /// #dnnl::prop_kind::forward_training, and
11461 /// #dnnl::prop_kind::forward_inference.
11462 /// @param aalgorithm resampling algorithm kind: either
11463 /// #dnnl::algorithm::resampling_nearest, or
11464 /// #dnnl::algorithm::resampling_linear
11465 /// @param factors Vector of scaling factors for spatial dimension.
11466 /// @param src_desc Source memory descriptor.
11467 /// @param attr Primitive attributes to use. Attributes are optional
11468 /// and default to empty attributes.
11469 /// @param allow_empty A flag signifying whether construction is
11470 /// allowed to fail without throwing an exception. In this case an
11471 /// empty object will be produced. This flag is optional and
11472 /// defaults to false.
11473 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11474 algorithm aalgorithm, const std::vector<float> &factors,
11475 const memory::desc &src_desc,
11476 const primitive_attr &attr = default_attr(),
11477 bool allow_empty = false)
11478 : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
11479 src_desc, nullptr, attr, allow_empty) {}
11480
11481 /// Constructs a primitive descriptor for a resampling forward
11482 /// propagation primitive.
11483 ///
11484 /// @note
11485 /// The destination memory descriptor may be initialized with
11486 /// #dnnl::memory::format_tag::any value of @p format_tag.
11487 ///
11488 /// @param aengine Engine to use.
11489 /// @param aprop_kind Propagation kind. Possible values are
11490 /// #dnnl::prop_kind::forward_training, and
11491 /// #dnnl::prop_kind::forward_inference.
11492 /// @param aalgorithm resampling algorithm kind: either
11493 /// #dnnl::algorithm::resampling_nearest, or
11494 /// #dnnl::algorithm::resampling_linear
11495 /// @param factors Vector of scaling factors for spatial dimension.
11496 /// @param src_desc Source memory descriptor.
11497 /// @param dst_desc Destination memory descriptor.
11498 /// @param attr Primitive attributes to use. Attributes are optional
11499 /// and default to empty attributes.
11500 /// @param allow_empty A flag signifying whether construction is
11501 /// allowed to fail without throwing an exception. In this case an
11502 /// empty object will be produced. This flag is optional and
11503 /// defaults to false.
11504 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11505 algorithm aalgorithm, const std::vector<float> &factors,
11506 const memory::desc &src_desc, const memory::desc &dst_desc,
11507 const primitive_attr &attr = default_attr(),
11508 bool allow_empty = false)
11509 : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
11510 src_desc, &dst_desc, attr, allow_empty) {}
11511
11512 /// Constructs a primitive descriptor for a resampling forward
11513 /// propagation primitive from a C API primitive descriptor that must
11514 /// have a matching kind.
11515 ///
11516 /// @param pd C API primitive descriptor for a resampling forward
11517 /// propagation primitive.
11518 primitive_desc(dnnl_primitive_desc_t pd)
11519 : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
11520 dnnl::prop_kind::forward_training,
11521 dnnl::prop_kind::forward_inference) {}
11522
11523 /// @copydoc dnnl::primitive_desc_base::src_desc()const
11524 memory::desc src_desc() const { return base::src_desc(0); }
11525
11526 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
11527 memory::desc dst_desc() const { return base::dst_desc(0); }
11528
11529 private:
11530 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11531 algorithm aalgorithm, const std::vector<float> *factors,
11532 const memory::desc &src_desc, const memory::desc *dst_desc,
11533 const primitive_attr &attr, bool allow_empty) {
11534
11535 if (factors)
11536 memory::validate_dims(*factors, src_desc.get_ndims() - 2);
11537
11538 dnnl_primitive_desc_t pd = nullptr;
11539 dnnl_status_t status
11540 = dnnl_resampling_forward_primitive_desc_create(&pd,
11541 aengine.get(), dnnl::convert_to_c(aprop_kind),
11542 convert_to_c(aalgorithm), optional_arg(factors),
11543 src_desc.get(), optional_arg(dst_desc), attr.get());
11544
11545 if (!allow_empty)
11546 error::wrap_c_api(status,
11547 "could not create a primitive descriptor for a "
11548 "resampling forward propagation primitive");
11549 reset(pd);
11550 }
11551 };
11552
11553 /// Default constructor. Produces an empty object.
11554 resampling_forward() = default;
11555
11556 /// Constructs a resampling forward propagation primitive.
11557 /// @param pd Primitive descriptor for a resampling forward propagation
11558 /// primitive.
11559 resampling_forward(const primitive_desc &pd) : primitive(pd) {}
11560
11561 /// Constructs a resampling forward propagation primitive from a cache
11562 /// blob.
11563 /// @param pd Primitive descriptor for a resampling forward propagation
11564 /// primitive.
11565 /// @param cache_blob Cache blob.
11566 resampling_forward(
11567 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11568 : primitive(pd, cache_blob) {}
11569};
11570
11571/// Resampling backward propagation primitive.
11572struct resampling_backward : public primitive {
11573 /// Primitive descriptor for resampling backward propagation primitive.
11574 struct primitive_desc : public dnnl::primitive_desc {
11575 /// Default constructor. Produces an empty object.
11576 primitive_desc() = default;
11577
11578 /// Constructs a primitive descriptor for a resampling backward
11579 /// propagation primitive using source and destination memory
11580 /// descriptors.
11581 ///
11582 /// @param aengine Engine to use.
11583 /// @param aalgorithm resampling algorithm kind: either
11584 /// #dnnl::algorithm::resampling_nearest, or
11585 /// #dnnl::algorithm::resampling_linear
11586 /// @param diff_src_desc Diff source memory descriptor.
11587 /// @param diff_dst_desc Diff destination memory descriptor.
11588 /// @param hint_fwd_pd Primitive descriptor for a resampling
11589 /// forward propagation primitive. It is used as a hint for
11590 /// deciding which memory format to use.
11591 /// @param attr Primitive attributes to use. Attributes are optional
11592 /// and default to empty attributes.
11593 /// @param allow_empty A flag signifying whether construction is
11594 /// allowed to fail without throwing an exception. In this case an
11595 /// empty object will be produced. This flag is optional and
11596 /// defaults to false.
11597 primitive_desc(const engine &aengine, algorithm aalgorithm,
11598 const memory::desc &diff_src_desc,
11599 const memory::desc &diff_dst_desc,
11600 const resampling_forward::primitive_desc &hint_fwd_pd,
11601 const primitive_attr &attr = default_attr(),
11602 bool allow_empty = false)
11603 : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc,
11604 diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
11605
11606 /// Constructs a primitive descriptor for resampling backward
11607 /// propagation primitive.
11608 ///
11609 /// @param aengine Engine to use.
11610 /// @param aalgorithm resampling algorithm kind: either
11611 /// #dnnl::algorithm::resampling_nearest, or
11612 /// #dnnl::algorithm::resampling_linear
11613 /// @param factors Vector of scaling factors for spatial dimension.
11614 /// @param diff_src_desc Diff source memory descriptor.
11615 /// @param diff_dst_desc Diff destination memory descriptor.
11616 /// @param hint_fwd_pd Primitive descriptor for a resampling
11617 /// forward propagation primitive. It is used as a hint for
11618 /// deciding which memory format to use.
11619 /// @param attr Primitive attributes to use. Attributes are optional
11620 /// and default to empty attributes.
11621 /// @param allow_empty A flag signifying whether construction is
11622 /// allowed to fail without throwing an exception. In this case an
11623 /// empty object will be produced. This flag is optional and
11624 /// defaults to false.
11625 primitive_desc(const engine &aengine, algorithm aalgorithm,
11626 const std::vector<float> &factors,
11627 const memory::desc &diff_src_desc,
11628 const memory::desc &diff_dst_desc,
11629 const resampling_forward::primitive_desc &hint_fwd_pd,
11630 const primitive_attr &attr = default_attr(),
11631 bool allow_empty = false)
11632 : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc,
11633 diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
11634
11635 /// Constructs a primitive descriptor for a resampling backward
11636 /// propagation primitive from a C API primitive descriptor that must
11637 /// have a matching kind.
11638 ///
11639 /// @param pd C API primitive descriptor for a resampling backward
11640 /// propagation primitive.
11641 primitive_desc(dnnl_primitive_desc_t pd)
11642 : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
11643 dnnl::prop_kind::backward_data) {}
11644
11645 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
11646 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
11647
11648 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
11649 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
11650
11651 private:
11652 primitive_desc(const engine &aengine, algorithm aalgorithm,
11653 const std::vector<float> *factors,
11654 const memory::desc &diff_src_desc,
11655 const memory::desc &diff_dst_desc,
11656 const resampling_forward::primitive_desc &hint_fwd_pd,
11657 const primitive_attr &attr, bool allow_empty) {
11658
11659 if (factors)
11660 memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2);
11661
11662 dnnl_primitive_desc_t pd = nullptr;
11663 dnnl_status_t status
11664 = dnnl_resampling_backward_primitive_desc_create(&pd,
11665 aengine.get(), convert_to_c(aalgorithm),
11666 optional_arg(factors), diff_src_desc.get(),
11667 diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
11668
11669 if (!allow_empty)
11670 error::wrap_c_api(status,
11671 "could not create a primitive descriptor for a "
11672 "resampling backward propagation primitive");
11673 reset(pd);
11674 }
11675 };
11676
11677 /// Default constructor. Produces an empty object.
11678 resampling_backward() = default;
11679
11680 /// Constructs a resampling backward propagation primitive.
11681 /// @param pd Primitive descriptor for a resampling backward propagation
11682 /// primitive.
11683 resampling_backward(const primitive_desc &pd) : primitive(pd) {}
11684
11685 /// Constructs a resampling backward propagation primitive from a cache
11686 /// blob.
11687 /// @param pd Primitive descriptor for a resampling backward propagation
11688 /// primitive.
11689 /// @param cache_blob Cache blob.
11690 resampling_backward(
11691 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11692 : primitive(pd, cache_blob) {}
11693};
11694
11695/// @} dnnl_api_resampling
11696
11697/// @addtogroup dnnl_api_pooling Pooling
11698///
11699/// A primitive to perform max or average pooling with dilation.
11700///
11701/// @sa @ref dev_guide_pooling in developer guide
11702///
11703/// @{
11704
11705/// Pooling forward propagation primitive.
11706struct pooling_forward : public primitive {
11707 /// Primitive descriptor for a pooling forward propagation primitive.
11708 struct primitive_desc : public dnnl::primitive_desc {
11709 /// Default constructor. Produces an empty object.
11710 primitive_desc() = default;
11711
11712 /// Constructs a primitive descriptor for pooling forward propagation
11713 /// primitive.
11714 ///
11715 /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
11716 /// and @p padding_r contain values for spatial dimensions only and
11717 /// hence must have the same number of elements as there are spatial
11718 /// dimensions. The order of values is the same as in the tensor:
11719 /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
11720 ///
11721 /// @param aengine Engine to use.
11722 /// @param aprop_kind Propagation kind. Possible values are
11723 /// #dnnl::prop_kind::forward_training, and
11724 /// #dnnl::prop_kind::forward_inference.
11725 /// @param aalgorithm Pooling algorithm kind: either
11726 /// #dnnl::algorithm::pooling_max,
11727 /// #dnnl::algorithm::pooling_avg_include_padding,
11728 /// or #dnnl::algorithm::pooling_avg_exclude_padding.
11729 /// @param src_desc Source memory descriptor.
11730 /// @param dst_desc Destination memory descriptor.
11731 /// @param strides Vector of strides for spatial dimension.
11732 /// @param kernel Vector of kernel spatial dimensions.
11733 /// @param dilation Array of dilations for spatial dimension.
11734 /// @param padding_l Vector of padding values for low indices for each
11735 /// spatial dimension `([[front,] top,] left)`.
11736 /// @param padding_r Vector of padding values for high indices for
11737 /// each spatial dimension `([[back,] bottom,] right)`.
11738 /// @param attr Primitive attributes to use. Attributes are optional
11739 /// and default to empty attributes.
11740 /// @param allow_empty A flag signifying whether construction is
11741 /// allowed to fail without throwing an exception. In this case an
11742 /// empty object will be produced. This flag is optional and
11743 /// defaults to false.
11744 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11745 algorithm aalgorithm, const memory::desc &src_desc,
11746 const memory::desc &dst_desc, const memory::dims &strides,
11747 const memory::dims &kernel, const memory::dims &dilation,
11748 const memory::dims &padding_l, const memory::dims &padding_r,
11749 const primitive_attr &attr = default_attr(),
11750 bool allow_empty = false) {
11751
11752 memory::validate_dims(strides, src_desc.get_ndims() - 2);
11753 memory::validate_dims(kernel, src_desc.get_ndims() - 2);
11754 memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
11755 memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
11756 memory::validate_dims(dilation, src_desc.get_ndims() - 2);
11757
11758 dnnl_primitive_desc_t pd = nullptr;
11759 dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create(
11760 &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
11761 convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
11762 &strides[0], &kernel[0], &dilation[0], &padding_l[0],
11763 &padding_r[0], attr.get());
11764
11765 if (!allow_empty)
11766 error::wrap_c_api(status,
11767 "could not create a descriptor for a pooling forward "
11768 "propagation primitive");
11769 reset(pd);
11770 }
11771
11772 /// Constructs a primitive descriptor for a pooling forward propagation
11773 /// primitive from a C API primitive descriptor that must have a
11774 /// matching kind.
11775 ///
11776 /// @param pd C API primitive descriptor for a pooling forward
11777 /// propagation primitive.
11778 primitive_desc(dnnl_primitive_desc_t pd)
11779 : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
11780 dnnl::prop_kind::forward_training,
11781 dnnl::prop_kind::forward_inference) {}
11782
11783 /// @copydoc dnnl::primitive_desc_base::src_desc()const
11784 memory::desc src_desc() const { return base::src_desc(0); }
11785
11786 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
11787 memory::desc dst_desc() const { return base::dst_desc(0); }
11788
11789 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
11790 memory::desc workspace_desc() const { return base::workspace_desc(); }
11791
11792 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
11793 algorithm get_algorithm() const { return base::get_algorithm(); }
11794
11795 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
11796 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
11797
11798 /// @copydoc dnnl::primitive_desc_base::get_strides()const
11799 memory::dims get_strides() const { return base::get_strides(); }
11800
11801 /// @copydoc dnnl::primitive_desc_base::get_kernel()const
11802 memory::dims get_kernel() const { return base::get_kernel(); }
11803
11804 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
11805 memory::dims get_dilations() const { return base::get_dilations(); }
11806
11807 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
11808 memory::dims get_padding_l() const { return base::get_padding_l(); }
11809
11810 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
11811 memory::dims get_padding_r() const { return base::get_padding_r(); }
11812 };
11813
11814 /// Default constructor. Produces an empty object.
11815 pooling_forward() = default;
11816
11817 /// Constructs a pooling forward propagation primitive.
11818 ///
11819 /// @param pd Primitive descriptor for a pooling forward propagation
11820 /// primitive.
11821 pooling_forward(const primitive_desc &pd) : primitive(pd) {}
11822
11823 /// Constructs a pooling forward propagation primitive from a cache blob.
11824 ///
11825 /// @param pd Primitive descriptor for a pooling forward propagation
11826 /// primitive.
11827 /// @param cache_blob Cache blob.
11828 pooling_forward(
11829 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11830 : primitive(pd, cache_blob) {}
11831};
11832
11833/// Pooling backward propagation primitive.
11834struct pooling_backward : public primitive {
11835 /// Primitive descriptor for a pooling backward propagation primitive.
11836 struct primitive_desc : public dnnl::primitive_desc {
11837 /// Default constructor. Produces an empty object.
        primitive_desc() = default; // compiler-generated; no C API call is made
11839
11840 /// Constructs a primitive descriptor for a pooling backward propagation
11841 /// primitive.
11842 ///
11843 /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
11844 /// and @p padding_r contain values for spatial dimensions only and
11845 /// hence must have the same number of elements as there are spatial
11846 /// dimensions. The order of values is the same as in the tensor:
11847 /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
11848 ///
11849 /// @param aengine Engine to use.
11850 /// @param aalgorithm Pooling algorithm kind: either
11851 /// #dnnl::algorithm::pooling_max,
11852 /// #dnnl::algorithm::pooling_avg_include_padding,
11853 /// or #dnnl::algorithm::pooling_avg_exclude_padding.
11854 /// @param diff_src_desc Diff source memory descriptor.
11855 /// @param diff_dst_desc Diff destination memory descriptor.
11856 /// @param strides Vector of strides for spatial dimension.
11857 /// @param kernel Vector of kernel spatial dimensions.
        /// @param dilation Vector of dilations for spatial dimension.
11859 /// @param padding_l Vector of padding values for low indices for each
11860 /// spatial dimension `([[front,] top,] left)`.
11861 /// @param padding_r Vector of padding values for high indices for
11862 /// each spatial dimension `([[back,] bottom,] right)`.
11863 /// @param hint_fwd_pd Primitive descriptor for a pooling
11864 /// forward propagation primitive. It is used as a hint for
11865 /// deciding which memory format to use.
11866 /// @param attr Primitive attributes to use. Attributes are optional
11867 /// and default to empty attributes.
11868 /// @param allow_empty A flag signifying whether construction is
11869 /// allowed to fail without throwing an exception. In this case an
11870 /// empty object will be produced. This flag is optional and
11871 /// defaults to false.
11872 primitive_desc(const engine &aengine, algorithm aalgorithm,
11873 const memory::desc &diff_src_desc,
11874 const memory::desc &diff_dst_desc, const memory::dims &strides,
11875 const memory::dims &kernel, const memory::dims &dilation,
11876 const memory::dims &padding_l, const memory::dims &padding_r,
11877 const pooling_forward::primitive_desc &hint_fwd_pd,
11878 const primitive_attr &attr = default_attr(),
11879 bool allow_empty = false) {
11880
11881 memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
11882 memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2);
11883 memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
11884 memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
11885 memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2);
11886
11887 dnnl_primitive_desc_t pd = nullptr;
11888 dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create(
11889 &pd, aengine.get(), convert_to_c(aalgorithm),
11890 diff_src_desc.get(), diff_dst_desc.get(), &strides[0],
11891 &kernel[0], &dilation[0], &padding_l[0], &padding_r[0],
11892 hint_fwd_pd.get(), attr.get());
11893 if (!allow_empty)
11894 error::wrap_c_api(status,
11895 "could not create a descriptor for a pooling backward "
11896 "propagation primitive");
11897 reset(pd);
11898 }
11899
11900 /// Constructs a primitive descriptor for a pooling backward propagation
11901 /// primitive from a C API primitive descriptor that must have a
11902 /// matching kind.
11903 ///
11904 /// @param pd C API primitive descriptor for a pooling backward
11905 /// propagation primitive.
11906 primitive_desc(dnnl_primitive_desc_t pd)
11907 : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
11908 dnnl::prop_kind::backward_data) {}
11909
11910 /// @copydoc dnnl::primitive_desc_base::src_desc()const
11911 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
11912
11913 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
11914 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
11915
11916 /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
11917 memory::desc workspace_desc() const { return base::workspace_desc(); }
11918
11919 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
11920 algorithm get_algorithm() const { return base::get_algorithm(); }
11921
11922 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
11923 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
11924
11925 /// @copydoc dnnl::primitive_desc_base::get_strides()const
11926 memory::dims get_strides() const { return base::get_strides(); }
11927
11928 /// @copydoc dnnl::primitive_desc_base::get_kernel()const
11929 memory::dims get_kernel() const { return base::get_kernel(); }
11930
11931 /// @copydoc dnnl::primitive_desc_base::get_dilations()const
11932 memory::dims get_dilations() const { return base::get_dilations(); }
11933
11934 /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
11935 memory::dims get_padding_l() const { return base::get_padding_l(); }
11936
11937 /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
11938 memory::dims get_padding_r() const { return base::get_padding_r(); }
11939 };
11940
11941 /// Default constructor. Produces an empty object.
11942 pooling_backward() = default;
11943
11944 /// Constructs a pooling backward propagation primitive.
11945 ///
11946 /// @param pd Primitive descriptor for a pooling backward propagation
11947 /// primitive.
11948 pooling_backward(const primitive_desc &pd) : primitive(pd) {}
11949
11950 /// Constructs a pooling backward propagation primitive from a cache blob.
11951 ///
11952 /// @param pd Primitive descriptor for a pooling backward propagation
11953 /// primitive.
11954 /// @param cache_blob Cache blob.
11955 pooling_backward(
11956 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
11957 : primitive(pd, cache_blob) {}
11958};
11959
11960/// @} dnnl_api_pooling
11961
11962/// @addtogroup dnnl_api_prelu PReLU
11963///
/// PReLU primitive.
/// A primitive to perform PReLU (leaky ReLU with a trainable alpha
/// parameter).
11966///
11967/// @sa @ref dev_guide_prelu in developer guide
11968///
11969/// @{
11970
11971/// PReLU forward propagation primitive.
11972struct prelu_forward : public primitive {
11973 /// Primitive descriptor for a PReLU forward propagation primitive.
11974 struct primitive_desc : public dnnl::primitive_desc {
11975 /// Default constructor. Produces an empty object.
11976 primitive_desc() = default;
11977
11978 /// Constructs a primitive descriptor for a PReLU forward propagation
11979 /// primitive.
11980 ///
11981 /// @param aengine Engine to use.
11982 /// @param aprop_kind Propagation kind. Possible values are
11983 /// #dnnl::prop_kind::forward_training, and
11984 /// #dnnl::prop_kind::forward_inference.
11985 /// @param src_desc Source memory descriptor.
11986 /// @param weight_desc Alpha parameters memory descriptor.
11987 /// @param dst_desc Destination memory descriptor.
11988 /// @param attr Primitive attributes to use. Attributes are optional
11989 /// and default to empty attributes.
11990 /// @param allow_empty A flag signifying whether construction is
11991 /// allowed to fail without throwing an exception. In this case an
11992 /// empty object will be produced. This flag is optional and
11993 /// defaults to false.
11994 primitive_desc(const engine &aengine, prop_kind aprop_kind,
11995 const memory::desc &src_desc, const memory::desc &weight_desc,
11996 const memory::desc &dst_desc,
11997 const primitive_attr &attr = default_attr(),
11998 bool allow_empty = false) {
11999
12000 dnnl_primitive_desc_t pd = nullptr;
12001 dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd,
12002 aengine.get(), dnnl::convert_to_c(aprop_kind),
12003 src_desc.get(), weight_desc.get(), dst_desc.get(),
12004 attr.get());
12005
12006 if (!allow_empty)
12007 error::wrap_c_api(status,
12008 "could not create a primitive descriptor for a prelu "
12009 "forward propagation primitive");
12010 reset(pd);
12011 }
12012
12013 /// Constructs a primitive descriptor for a prelu forward
12014 /// propagation primitive from a C API primitive descriptor that must
12015 /// have a matching kind.
12016 ///
12017 /// @param pd C API primitive descriptor for a prelu forward
12018 /// propagation primitive.
12019 primitive_desc(dnnl_primitive_desc_t pd)
12020 : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
12021 dnnl::prop_kind::forward_training,
12022 dnnl::prop_kind::forward_inference) {}
12023
12024 /// @copydoc dnnl::primitive_desc_base::src_desc()const
12025 memory::desc src_desc() const { return base::src_desc(0); }
12026
12027 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
12028 memory::desc dst_desc() const { return base::dst_desc(0); }
12029
12030 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
12031 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
12032 };
12033
12034 /// Default constructor. Produces an empty object.
12035 prelu_forward() = default;
12036
12037 /// Constructs a prelu forward propagation primitive.
12038 /// @param pd Primitive descriptor for a prelu forward propagation
12039 /// primitive.
12040 prelu_forward(const primitive_desc &pd) : primitive(pd) {}
12041
12042 /// Constructs a prelu forward propagation primitive from a cache blob.
12043 /// @param pd Primitive descriptor for a prelu forward propagation
12044 /// primitive.
12045 /// @param cache_blob Cache blob.
12046 prelu_forward(
12047 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
12048 : primitive(pd, cache_blob) {}
12049};
12050
12051/// PReLU backward propagation primitive.
12052struct prelu_backward : public primitive {
12053 /// Primitive descriptor for prelu backward propagation.
12054 struct primitive_desc : public dnnl::primitive_desc {
12055 /// Default constructor. Produces an empty object.
12056 primitive_desc() = default;
12057
12058 /// Constructs a descriptor for a PReLU backward propagation
12059 /// primitive.
12060 ///
12061 /// @param aengine Engine to use.
12062 /// @param src_desc Source memory descriptor.
12063 /// @param weight_desc Alpha parameters memory descriptor.
12064 /// @param diff_src_desc Diff source memory descriptor.
12065 /// @param diff_weights_desc Diff alpha parameters memory descriptor.
12066 /// @param diff_dst_desc Diff destination memory descriptor.
12067 /// @param hint_fwd_pd Primitive descriptor for a PReLU
12068 /// forward propagation primitive. It is used as a hint for
12069 /// deciding which memory format to use.
12070 /// @param attr Primitive attributes to use. Attributes are optional
12071 /// and default to empty attributes.
12072 /// @param allow_empty A flag signifying whether construction is
12073 /// allowed to fail without throwing an exception. In this case an
12074 /// empty object will be produced. This flag is optional and
12075 /// defaults to false.
12076 primitive_desc(const engine &aengine, const memory::desc &src_desc,
12077 const memory::desc &weight_desc,
12078 const memory::desc &diff_src_desc,
12079 const memory::desc &diff_weights_desc,
12080 const memory::desc &diff_dst_desc,
12081 const prelu_forward::primitive_desc &hint_fwd_pd,
12082 const primitive_attr &attr = default_attr(),
12083 bool allow_empty = false) {
12084
12085 dnnl_primitive_desc_t pd = nullptr;
12086 dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create(
12087 &pd, aengine.get(), src_desc.get(), weight_desc.get(),
12088 diff_src_desc.get(), diff_weights_desc.get(),
12089 diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
12090
12091 if (!allow_empty)
12092 error::wrap_c_api(status,
12093 "could not create a primitive descriptor for a prelu "
12094 "backward propagation primitive");
12095 reset(pd);
12096 }
12097
12098 /// Constructs a primitive descriptor for a prelu backward
12099 /// propagation primitive from a C API primitive descriptor that must
12100 /// have a matching kind.
12101 ///
12102 /// @param pd C API primitive descriptor for a prelu backward
12103 /// propagation primitive.
12104 primitive_desc(dnnl_primitive_desc_t pd)
12105 : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
12106 dnnl::prop_kind::backward) {}
12107
12108 /// @copydoc dnnl::primitive_desc_base::src_desc()const
12109 memory::desc src_desc() const { return base::src_desc(0); }
12110
12111 /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
12112 memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
12113
12114 /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
12115 memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
12116
12117 /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
12118 prop_kind get_prop_kind() const { return base::get_prop_kind(); }
12119 };
12120
12121 /// Default constructor. Produces an empty object.
12122 prelu_backward() = default;
12123
12124 /// Constructs a prelu backward propagation primitive.
12125 /// @param pd Primitive descriptor for a prelu backward propagation
12126 /// primitive.
12127 prelu_backward(const primitive_desc &pd) : primitive(pd) {}
12128
12129 /// Constructs a prelu backward propagation primitive from a cache blob.
12130 /// @param pd Primitive descriptor for a prelu backward propagation
12131 /// primitive.
12132 /// @param cache_blob Cache blob.
12133 prelu_backward(
12134 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
12135 : primitive(pd, cache_blob) {}
12136};
12137
12138/// @} dnnl_api_prelu
12139
12140/// @addtogroup dnnl_api_reduction Reduction
12141///
12142/// A primitive to compute reduction operation on data tensor
12143/// using min, max, mul, sum, mean and norm_lp operations.
12144///
12145/// @sa @ref dev_guide_reduction in developer guide
12146///
12147/// @{
12148
12149/// Reduction.
12150struct reduction : public primitive {
12151 /// Primitive descriptor for a reduction primitive.
12152 struct primitive_desc : public dnnl::primitive_desc {
12153 /// Default constructor. Produces an empty object.
12154 primitive_desc() = default;
12155
12156 /// Constructs a primitive descriptor for a reduction primitive using
12157 /// algorithm specific parameters, source and destination memory
12158 /// descriptors.
12159 ///
12160 /// @note
12161 /// Destination memory descriptor may be initialized with
12162 /// #dnnl::memory::format_tag::any value of @p format_tag.
12163 ///
12164 /// @param aengine Engine to use.
12165 /// @param aalgorithm reduction algorithm kind. Possible values:
12166 /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
12167 /// #dnnl_reduction_mul, #dnnl_reduction_mean,
12168 /// #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum,
12169 /// #dnnl_reduction_norm_lp_power_p_max,
12170 /// #dnnl_reduction_norm_lp_power_p_sum.
12171 /// @param p algorithm specific parameter.
12172 /// @param eps algorithm specific parameter.
12173 /// @param src_desc Source memory descriptor.
12174 /// @param dst_desc Destination memory descriptor.
12175 /// @param attr Primitive attributes to use. Attributes are optional
12176 /// and default to empty attributes.
12177 /// @param allow_empty A flag signifying whether construction is
12178 /// allowed to fail without throwing an exception. In this case an
12179 /// empty object will be produced. This flag is optional and
12180 /// defaults to false.
12181 primitive_desc(const engine &aengine, algorithm aalgorithm,
12182 const memory::desc &src_desc, const memory::desc &dst_desc,
12183 float p, float eps, const primitive_attr &attr = default_attr(),
12184 bool allow_empty = false) {
12185
12186 dnnl_primitive_desc_t pd = nullptr;
12187 dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd,
12188 aengine.get(), convert_to_c(aalgorithm), src_desc.get(),
12189 dst_desc.get(), p, eps, attr.get());
12190
12191 if (!allow_empty)
12192 error::wrap_c_api(status,
12193 "could not create a primitive descriptor for a "
12194 "reduction primitive descriptor");
12195 reset(pd);
12196 }
12197
12198 /// Constructs a primitive descriptor for a reduction primitive from a C
12199 /// API primitive descriptor that must have a matching kind.
12200 ///
12201 /// @param pd C API primitive descriptor for a reduction primitive.
12202 primitive_desc(dnnl_primitive_desc_t pd)
12203 : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {}
12204
12205 /// @copydoc dnnl::primitive_desc_base::src_desc()const
12206 memory::desc src_desc() const { return base::src_desc(0); }
12207
12208 /// @copydoc dnnl::primitive_desc_base::dst_desc()const
12209 memory::desc dst_desc() const { return base::dst_desc(0); }
12210
12211 /// @copydoc dnnl::primitive_desc_base::get_p()const
12212 float get_p() const { return base::get_p(); }
12213
12214 /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
12215 float get_epsilon() const { return base::get_epsilon(); }
12216
12217 /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
12218 algorithm get_algorithm() const { return base::get_algorithm(); }
12219 };
12220
12221 /// Default constructor. Produces an empty object.
12222 reduction() = default;
12223
12224 /// Constructs a reduction primitive.
12225 /// @param pd Primitive descriptor for a reduction primitive.
12226 reduction(const primitive_desc &pd) : primitive(pd) {}
12227
12228 /// Constructs a reduction primitive from a cache blob.
12229 /// @param pd Primitive descriptor for a reduction primitive.
12230 /// @param cache_blob Cache blob.
12231 reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
12232 : primitive(pd, cache_blob) {}
12233};
12234
12235/// @} dnnl_api_reduction
12236
12237/// @} dnnl_api_primitives
12238
12239/// @addtogroup dnnl_api_service Service
12240///
12241/// A set of functions that aid in oneDNN debugging and profiling.
12242///
12243/// @{
12244
12245/// @copydoc dnnl_version_t
12246using version_t = dnnl_version_t;
12247
12248/// Status values returned by the library functions.
12249enum class status {
12250 /// @copydoc dnnl_success
12251 success = dnnl_success,
12252 /// @copydoc dnnl_out_of_memory
12253 out_of_memory = dnnl_out_of_memory,
12254 /// @copydoc dnnl_invalid_arguments
12255 invalid_arguments = dnnl_invalid_arguments,
12256 /// @copydoc dnnl_unimplemented
12257 unimplemented = dnnl_unimplemented,
12258 /// @copydoc dnnl_last_impl_reached
12259 last_impl_reached = dnnl_last_impl_reached,
12260 /// @copydoc dnnl_runtime_error
12261 runtime_error = dnnl_runtime_error,
12262 /// @copydoc dnnl_not_required
12263 not_required = dnnl_not_required,
12264};
12265
12266/// @copydoc dnnl_set_verbose()
12267inline status set_verbose(int level) {
12268 return static_cast<status>(dnnl_set_verbose(level));
12269}
12270
12271/// @copydoc dnnl_version()
12272inline const version_t *version() {
12273 return dnnl_version();
12274}
12275
12276/// Returns the floating-point math mode that will be used by default
12277/// for all subsequently created primitives.
12278///
12279/// @returns Output FP math mode.
12280inline fpmath_mode get_default_fpmath_mode() {
12281 dnnl_fpmath_mode_t mode;
12282 error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode),
12283 "could not get a default fpmath mode");
12284 return static_cast<fpmath_mode>(mode);
12285}
12286
12287/// @copydoc dnnl_set_default_fpmath_mode()
12288inline status set_default_fpmath_mode(fpmath_mode mode) {
12289 return static_cast<status>(
12290 dnnl_set_default_fpmath_mode(convert_to_c(mode)));
12291}
12292
12293/// @copydoc dnnl_set_jit_dump()
12294inline status set_jit_dump(int enable) {
12295 return static_cast<status>(dnnl_set_jit_dump(enable));
12296}
12297
12298/// @copydoc dnnl_set_jit_profiling_flags()
12299inline status set_jit_profiling_flags(unsigned flags) {
12300 return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
12301}
12302
12303/// @copydoc dnnl_set_jit_profiling_jitdumpdir()
12304inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
12305 return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
12306}
12307
12308/// @copydoc dnnl_cpu_isa_t
12309enum class cpu_isa {
12310 /// @copydoc dnnl_cpu_isa_default
12311 isa_default = dnnl_cpu_isa_default,
12312 /// @copydoc dnnl_cpu_isa_sse41
12313 sse41 = dnnl_cpu_isa_sse41,
12314 /// @copydoc dnnl_cpu_isa_avx
12315 avx = dnnl_cpu_isa_avx,
12316 /// @copydoc dnnl_cpu_isa_avx2
12317 avx2 = dnnl_cpu_isa_avx2,
12318 /// @copydoc dnnl_cpu_isa_avx2_vnni
12319 avx2_vnni = dnnl_cpu_isa_avx2_vnni,
12320 /// @copydoc dnnl_cpu_isa_avx2_vnni_2
12321 avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
12322 /// @copydoc dnnl_cpu_isa_avx512_core
12323 avx512_core = dnnl_cpu_isa_avx512_core,
12324 /// @copydoc dnnl_cpu_isa_avx512_core_vnni
12325 avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
12326 /// @copydoc dnnl_cpu_isa_avx512_core_bf16
12327 avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
12328 /// @copydoc dnnl_cpu_isa_avx512_core_fp16
12329 avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
12330 /// @copydoc dnnl_cpu_isa_avx512_core_amx
12331 avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
12332 /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
12333 avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
12334};
12335
12336/// @copydoc dnnl_set_max_cpu_isa()
12337inline status set_max_cpu_isa(cpu_isa isa) {
12338 return static_cast<status>(
12339 dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
12340}
12341
12342/// @copydoc dnnl_get_effective_cpu_isa()
12343inline cpu_isa get_effective_cpu_isa() {
12344 return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
12345}
12346
12347/// @copydoc dnnl_cpu_isa_hints_t
12348enum class cpu_isa_hints {
12349 /// @copydoc dnnl_cpu_isa_no_hints
12350 no_hints = dnnl_cpu_isa_no_hints,
12351 /// @copydoc dnnl_cpu_isa_prefer_ymm
12352 prefer_ymm = dnnl_cpu_isa_prefer_ymm,
12353};
12354
12355/// @copydoc dnnl_set_cpu_isa_hints()
12356inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) {
12357 return static_cast<status>(dnnl_set_cpu_isa_hints(
12358 static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
12359}
12360
12361/// @copydoc dnnl_get_cpu_isa_hints()
12362inline cpu_isa_hints get_cpu_isa_hints() {
12363 return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
12364}
12365
12366/// @} dnnl_api_service
12367
12368/// @addtogroup dnnl_api_primitive_cache Primitive Cache
12369///
12370/// A set of functions that provide primitive cache control.
12371///
12372/// @{
12373
12374/// Returns the number of primitives that can be held in the primitive cache
12375/// at the same time.
12376inline int get_primitive_cache_capacity() {
12377 int result = 0;
12378 error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result),
12379 "could not get primitive cache capacity");
12380 return result;
12381}
12382
12383/// @copydoc dnnl_set_primitive_cache_capacity(int capacity)
12384inline void set_primitive_cache_capacity(int capacity) {
12385 error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity),
12386 "could not set primitive cache capacity");
12387}
12388
12389/// @} dnnl_api_primitive_cache
12390
12391/// @addtogroup dnnl_api_blas BLAS functions
12392///
12393/// A subset of Basic Linear Algebra (BLAS) functions that perform
12394/// matrix-matrix multiplication.
12395///
12396/// @{
12397
12398/// @copydoc dnnl_sgemm()
12399inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
12400 dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
12401 const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
12402 return static_cast<status>(dnnl_sgemm(
12403 transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
12404}
12405
12406/// @copydoc dnnl_gemm_u8s8s32()
12407inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
12408 dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
12409 dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
12410 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
12411 return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
12412 K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
12413}
12414
12415/// @copydoc dnnl_gemm_s8s8s32()
12416inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
12417 dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
12418 dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
12419 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
12420 return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
12421 K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
12422}
12423
12424/// @} dnnl_api_blas
12425
12426// implementation section
12427
12428/// @cond DO_NOT_DOCUMENT_THIS
12429inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
12430 dnnl_primitive_t result;
12431 error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
12432 "could not create a primitive");
12433 reset(result);
12434}
12435
12436inline primitive::primitive(const_dnnl_primitive_desc_t c_pd,
12437 const std::vector<uint8_t> &cache_blob) {
12438 dnnl_primitive_t result;
12439 size_t size = cache_blob.size();
12440 const uint8_t *cache_blob_data = cache_blob.data();
12441 error::wrap_c_api(dnnl_primitive_create_from_cache_blob(
12442 &result, c_pd, size, cache_blob_data),
12443 "could not create a primitive from a cache blob");
12444 reset(result);
12445}
12446
12447inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
12448inline primitive::primitive(
12449 const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
12450 : primitive(pd.get(), cache_blob) {}
12451
12452inline void primitive::execute(const stream &astream,
12453 const std::unordered_map<int, memory> &args) const {
12454 std::vector<dnnl_exec_arg_t> c_args;
12455 c_args.reserve(args.size());
12456 for (const auto &a : args)
12457 c_args.push_back({a.first, a.second.get(true)});
12458
12459 error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
12460 (int)c_args.size(), c_args.data()),
12461 "could not execute a primitive");
12462}
12463
12464/// @endcond
12465
12466#undef DNNL_DEFINE_BITMASK_OPS
12467
12468} // namespace dnnl
12469
12470/// oneAPI namespace
12471
12472/// The oneAPI namespace.
12473/// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace.
12474namespace oneapi {
12475// Note: without this guard, doxygen warns of potentially recursive namespace
12476#ifndef DOXYGEN_SHOULD_SKIP_THIS
12477/// oneDNN alias namespace
12478namespace dnnl = ::dnnl;
12479#endif
12480} // namespace oneapi
12481
12482/// @} dnnl_api
12483
12484#endif /* ONEAPI_DNNL_DNNL_HPP */
12485