1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @file |
18 | /// C++ API |
19 | |
20 | #ifndef ONEAPI_DNNL_DNNL_HPP |
21 | #define ONEAPI_DNNL_DNNL_HPP |
22 | |
23 | #include "oneapi/dnnl/dnnl_config.h" |
24 | |
25 | /// @cond DO_NOT_DOCUMENT_THIS |
26 | #include <algorithm> |
27 | #include <cstdlib> |
28 | #include <iterator> |
29 | #include <memory> |
30 | #include <string> |
31 | #include <vector> |
32 | #include <unordered_map> |
33 | |
34 | #include "oneapi/dnnl/dnnl.h" |
35 | #include "oneapi/dnnl/dnnl_common.hpp" |
36 | |
37 | /// @endcond |
38 | |
39 | /// @addtogroup dnnl_api oneDNN API |
40 | /// @{ |
41 | |
42 | /// oneDNN namespace |
43 | namespace dnnl { |
44 | |
45 | /// @addtogroup dnnl_api_utils Utilities |
46 | /// Utility types and definitions. |
47 | /// @{ |
48 | |
49 | /// @cond DO_NOT_DOCUMENT_THIS |
50 | template <typename T> |
51 | void validate_container_size(const T &v, const char *error_message, |
52 | int min_size = 1, int max_size = -1) { |
53 | const int size = (int)v.size(); |
54 | if (size < min_size || (max_size >= 0 && size > max_size)) |
55 | DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message); |
56 | } |
57 | /// @endcond |
58 | |
59 | /// @cond DO_NOT_DOCUMENT_THIS |
60 | template <> |
61 | struct handle_traits<dnnl_memory_desc_t> { |
62 | static dnnl_status_t destructor(dnnl_memory_desc_t p) { |
63 | return dnnl_memory_desc_destroy(p); |
64 | } |
65 | }; |
66 | |
67 | template <> |
68 | struct handle_traits<dnnl_memory_t> { |
69 | static dnnl_status_t destructor(dnnl_memory_t p) { |
70 | return dnnl_memory_destroy(p); |
71 | } |
72 | }; |
73 | |
74 | template <> |
75 | struct handle_traits<dnnl_primitive_desc_t> { |
76 | static dnnl_status_t destructor(dnnl_primitive_desc_t p) { |
77 | return dnnl_primitive_desc_destroy(p); |
78 | } |
79 | }; |
80 | |
81 | template <> |
82 | struct handle_traits<dnnl_primitive_t> { |
83 | static dnnl_status_t destructor(dnnl_primitive_t p) { |
84 | return dnnl_primitive_destroy(p); |
85 | } |
86 | }; |
87 | |
88 | /// @endcond |
89 | |
90 | /// @} dnnl_api_utils |
91 | |
// Forward declarations: the primitive class refers to these types in its
// constructor and execute() signatures before they are defined.
struct primitive_desc;
struct memory;
struct stream;
95 | |
96 | /// @addtogroup dnnl_api_primitives Primitives |
97 | /// Compute primitives |
98 | /// @sa @ref dev_guide_basic_concepts |
99 | /// @{ |
100 | |
101 | /// @addtogroup dnnl_api_primitives_common Common |
102 | /// Common operations to create, destroy and inspect primitives |
103 | /// @{ |
104 | |
105 | /// Base class for all computational primitives. |
106 | struct primitive : public handle<dnnl_primitive_t> { |
107 | /// Kinds of primitives supported by the library. |
108 | enum class kind { |
109 | /// Undefined primitive |
110 | undef = dnnl_undefined_primitive, |
111 | /// A reorder primitive. |
112 | reorder = dnnl_reorder, |
113 | /// A shuffle primitive. |
114 | shuffle = dnnl_shuffle, |
115 | /// A (out-of-place) tensor concatenation primitive. |
116 | concat = dnnl_concat, |
117 | /// A summation primitive. |
118 | sum = dnnl_sum, |
119 | /// A convolution primitive. |
120 | convolution = dnnl_convolution, |
121 | /// A deconvolution primitive. |
122 | deconvolution = dnnl_deconvolution, |
123 | /// An element-wise primitive. |
124 | eltwise = dnnl_eltwise, |
125 | /// An LRN primitive. |
126 | lrn = dnnl_lrn, |
127 | /// A batch normalization primitive. |
128 | batch_normalization = dnnl_batch_normalization, |
129 | /// An inner product primitive. |
130 | inner_product = dnnl_inner_product, |
131 | /// An RNN primitive. |
132 | rnn = dnnl_rnn, |
133 | /// A binary primitive. |
134 | binary = dnnl_binary, |
135 | /// A matmul (matrix multiplication) primitive. |
136 | matmul = dnnl_matmul, |
137 | /// A resampling primitive. |
138 | resampling = dnnl_resampling, |
139 | /// A pooling primitive. |
140 | pooling = dnnl_pooling, |
141 | /// A reduction primitive. |
142 | reduction = dnnl_reduction, |
143 | /// A PReLU primitive. |
144 | prelu = dnnl_prelu, |
145 | /// A softmax primitive. |
146 | softmax = dnnl_softmax, |
147 | /// A layer normalization primitive. |
148 | layer_normalization = dnnl_layer_normalization, |
149 | }; |
150 | |
151 | using handle::handle; |
152 | |
153 | /// Default constructor. Constructs an empty object. |
154 | primitive() = default; |
155 | |
156 | /// Constructs a primitive from a C API primitive descriptor. |
157 | /// |
158 | /// @param c_pd C API primitive descriptor. |
159 | primitive(const_dnnl_primitive_desc_t c_pd); |
160 | |
161 | /// Constructs a primitive from a C API primitive descriptor and a cache blob. |
162 | /// |
163 | /// @param c_pd C API primitive descriptor. |
164 | /// @param cache_blob Cache blob. |
165 | primitive(const_dnnl_primitive_desc_t c_pd, |
166 | const std::vector<uint8_t> &cache_blob); |
167 | |
168 | /// Constructs a primitive from a primitive descriptor. |
169 | /// |
170 | /// @param pd Primitive descriptor. |
171 | primitive(const primitive_desc &pd); |
172 | |
173 | /// Constructs a primitive from a primitive descriptor and a cache blob. |
174 | /// |
175 | /// @param pd Primitive descriptor. |
176 | /// @param cache_blob Cache blob. |
177 | primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob); |
178 | |
179 | /// Returns the C API primitive descriptor of the underlying C API |
180 | /// primitive. |
181 | /// |
182 | /// @returns The underlying C API primitive descriptor. |
183 | inline const_dnnl_primitive_desc_t get_primitive_desc() const; |
184 | |
185 | /// Returns the kind of the primitive. |
186 | /// |
187 | /// @returns The primitive kind. |
188 | inline kind get_kind() const; |
189 | |
190 | /// Returns a cache blob for the primitive. |
191 | /// |
192 | /// @returns Vector containing the cache blob. |
193 | /// |
194 | /// @note The cache blob can be empty. It's the user's responsibility to |
195 | /// check whether it's empty prior to passing it to the primitive |
196 | /// constructor. |
197 | inline std::vector<uint8_t> get_cache_blob() const; |
198 | |
199 | /// Executes computations specified by the primitive in a specified stream. |
200 | /// |
201 | /// Arguments are passed via an arguments map containing <index, |
202 | /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values |
203 | /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor |
204 | /// matching the one returned by |
205 | /// primitive_desc::query_md(#query::exec_arg_md, index) unless using |
206 | /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL). |
207 | /// |
208 | /// @param astream Stream object. The stream must belong to the same engine |
209 | /// as the primitive. |
210 | /// @param args Arguments map. |
211 | void execute(const stream &astream, |
212 | const std::unordered_map<int, memory> &args) const; |
213 | }; |
214 | |
215 | /// Converts primitive kind enum value from C++ API to C API type. |
216 | /// |
217 | /// @param akind C++ API primitive kind enum value. |
218 | /// @returns Corresponding C API primitive kind enum value. |
219 | inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) { |
220 | return static_cast<dnnl_primitive_kind_t>(akind); |
221 | } |
222 | |
223 | const_dnnl_primitive_desc_t primitive::get_primitive_desc() const { |
224 | const_dnnl_primitive_desc_t pd; |
225 | error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd), |
226 | "could not get a primitive descriptor from a primitive" ); |
227 | return pd; |
228 | } |
229 | |
230 | dnnl::primitive::kind primitive::get_kind() const { |
231 | const_dnnl_primitive_desc_t pd = get_primitive_desc(); |
232 | // TODO (Roma): the code below is only needed because get_primitive_desc |
233 | // returns a C type. |
234 | dnnl_primitive_kind_t kind; |
235 | error::wrap_c_api(dnnl_primitive_desc_query( |
236 | pd, dnnl_query_primitive_kind, 0, (void *)&kind), |
237 | "could not get a primitive kind from a primitive descriptor" ); |
238 | return static_cast<dnnl::primitive::kind>(kind); |
239 | } |
240 | |
241 | std::vector<uint8_t> primitive::get_cache_blob() const { |
242 | size_t size; |
243 | error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr), |
244 | "could not get cache blob size from a primitive" ); |
245 | |
246 | std::vector<uint8_t> cache_blob(size); |
247 | error::wrap_c_api( |
248 | dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()), |
249 | "could not get a cache blob from a primitive" ); |
250 | return cache_blob; |
251 | } |
252 | |
253 | /// @} dnnl_api_primitives_common |
254 | |
255 | /// @addtogroup dnnl_api_attributes |
256 | /// |
/// A container for parameters that extend primitives' behavior.
258 | /// |
259 | /// Attributes can also contain Post-ops, which are computations executed |
260 | /// after the primitive. |
261 | /// |
262 | /// @sa @ref dev_guide_attributes |
263 | /// @sa @ref dev_guide_attributes_post_ops |
264 | /// |
265 | /// @{ |
266 | |
267 | /// Scratchpad mode |
268 | enum class scratchpad_mode { |
269 | /// The library manages the scratchpad allocation according to the policy |
270 | /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC` |
271 | /// [build option](@ref dev_guide_build_options) (default). |
272 | /// |
273 | /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library |
274 | /// scratchpad is common to all primitives to reduce the memory footprint. |
275 | /// This configuration comes with limited thread-safety properties, namely |
276 | /// primitives can be created and executed in parallel but cannot migrate |
277 | /// between threads (in other words, each primitive should be executed in |
278 | /// the same thread it was created in). |
279 | /// |
280 | /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is |
281 | /// private to each primitive. The memory footprint is larger than when |
282 | /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be |
283 | /// created and run concurrently (the same primitive cannot be run |
284 | /// concurrently from two different threads though). |
285 | library = dnnl_scratchpad_mode_library, |
286 | /// The user manages the scratchpad allocation by querying and providing |
287 | /// the scratchpad memory to primitives. This mode is thread-safe as long |
288 | /// as the scratchpad buffers are not used concurrently by two primitive |
289 | /// executions. |
290 | user = dnnl_scratchpad_mode_user, |
291 | }; |
292 | |
293 | /// Converts a scratchpad mode enum value from C++ API to C API type. |
294 | /// |
295 | /// @param mode C++ API scratchpad mode enum value. |
296 | /// @returns Corresponding C API scratchpad mode enum value. |
297 | inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { |
298 | return static_cast<dnnl_scratchpad_mode_t>(mode); |
299 | } |
300 | |
301 | /// Propagation kind. |
302 | enum class prop_kind { |
303 | /// Undefined propagation kind. |
304 | undef = dnnl_prop_kind_undef, |
305 | /// Forward data propagation (training mode). In this mode, primitives |
306 | /// perform computations necessary for subsequent backward propagation. |
307 | forward_training = dnnl_forward_training, |
308 | /// Forward data propagation (inference mode). In this mode, primitives |
309 | /// perform only computations that are necessary for inference and omit |
310 | /// computations that are necessary only for backward propagation. |
311 | forward_inference = dnnl_forward_inference, |
312 | /// Forward data propagation, |
313 | /// alias for #dnnl::prop_kind::forward_training. |
314 | forward = dnnl_forward, |
315 | /// Backward propagation (with respect to all parameters). |
316 | backward = dnnl_backward, |
317 | /// Backward data propagation. |
318 | backward_data = dnnl_backward_data, |
319 | /// Backward weights propagation. |
320 | backward_weights = dnnl_backward_weights, |
321 | /// Backward bias propagation. |
322 | backward_bias = dnnl_backward_bias |
323 | }; |
324 | |
325 | /// Converts propagation kind enum value from C++ API to C API type. |
326 | /// |
327 | /// @param akind C++ API propagation kind enum value. |
328 | /// @returns Corresponding C API propagation kind enum value. |
329 | inline dnnl_prop_kind_t convert_to_c(prop_kind akind) { |
330 | return static_cast<dnnl_prop_kind_t>(akind); |
331 | } |
332 | |
333 | /// Kinds of algorithms. |
334 | enum class algorithm { |
335 | /// Undefined algorithm |
336 | undef = dnnl_alg_kind_undef, |
337 | /// Convolution algorithm that is chosen to be either direct or Winograd |
338 | /// automatically |
339 | convolution_auto = dnnl_convolution_auto, |
340 | /// Direct convolution |
341 | convolution_direct = dnnl_convolution_direct, |
342 | /// Winograd convolution |
343 | convolution_winograd = dnnl_convolution_winograd, |
344 | /// Direct deconvolution |
345 | deconvolution_direct = dnnl_deconvolution_direct, |
346 | /// Winograd deconvolution |
347 | deconvolution_winograd = dnnl_deconvolution_winograd, |
348 | /// Elementwise: rectified linear unit (ReLU) |
349 | eltwise_relu = dnnl_eltwise_relu, |
350 | /// Elementwise: hyperbolic tangent non-linearity (tanh) |
351 | eltwise_tanh = dnnl_eltwise_tanh, |
352 | /// Elementwise: exponential linear unit (ELU) |
353 | eltwise_elu = dnnl_eltwise_elu, |
354 | /// Elementwise: square |
355 | eltwise_square = dnnl_eltwise_square, |
356 | /// Elementwise: abs |
357 | eltwise_abs = dnnl_eltwise_abs, |
358 | /// Elementwise: square root |
359 | eltwise_sqrt = dnnl_eltwise_sqrt, |
360 | /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$) |
361 | eltwise_swish = dnnl_eltwise_swish, |
362 | /// Elementwise: linear |
363 | eltwise_linear = dnnl_eltwise_linear, |
364 | /// Elementwise: soft_relu |
365 | eltwise_soft_relu = dnnl_eltwise_soft_relu, |
366 | /// Elementwise: mish |
367 | eltwise_mish = dnnl_eltwise_mish, |
368 | /// Elementwise: logistic |
369 | eltwise_logistic = dnnl_eltwise_logistic, |
370 | /// Elementwise: exponent |
371 | eltwise_exp = dnnl_eltwise_exp, |
372 | /// Elementwise: tanh-based gelu |
373 | eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh, |
374 | /// Elementwise: erf-based gelu |
375 | eltwise_gelu_erf = dnnl_eltwise_gelu_erf, |
376 | /// Elementwise: natural logarithm |
377 | eltwise_log = dnnl_eltwise_log, |
378 | /// Elementwise: clip |
379 | eltwise_clip = dnnl_eltwise_clip, |
380 | /// Eltwise: clip version 2 |
381 | eltwise_clip_v2 = dnnl_eltwise_clip_v2, |
382 | /// Elementwise: pow |
383 | eltwise_pow = dnnl_eltwise_pow, |
384 | /// Elementwise: round |
385 | eltwise_round = dnnl_eltwise_round, |
386 | /// Elementwise: hardswish |
387 | eltwise_hardswish = dnnl_eltwise_hardswish, |
388 | /// Elementwise: hardsigmoid |
389 | eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid, |
390 | /// Elementwise: rectified linar unit (ReLU) (dst for backward) |
391 | eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd, |
392 | /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward) |
393 | eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd, |
394 | /// Elementwise: exponential linear unit (ELU) (dst for backward) |
395 | eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd, |
396 | /// Elementwise: square root (dst for backward) |
397 | eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd, |
398 | /// Elementwise: logistic (dst for backward) |
399 | eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd, |
400 | /// Elementwise: exponent (dst for backward) |
401 | eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd, |
402 | /// Elementwise: clip version 2 (dst for backward) |
403 | eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd, |
404 | /// Local response normalization (LRN) across multiple channels |
405 | lrn_across_channels = dnnl_lrn_across_channels, |
406 | /// LRN within a single channel |
407 | lrn_within_channel = dnnl_lrn_within_channel, |
408 | /// Max pooling |
409 | pooling_max = dnnl_pooling_max, |
410 | /// Average pooling include padding |
411 | pooling_avg_include_padding = dnnl_pooling_avg_include_padding, |
412 | /// Average pooling exclude padding |
413 | pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding, |
414 | /// RNN cell |
415 | vanilla_rnn = dnnl_vanilla_rnn, |
416 | /// LSTM cell |
417 | vanilla_lstm = dnnl_vanilla_lstm, |
418 | /// GRU cell |
419 | vanilla_gru = dnnl_vanilla_gru, |
420 | /// GRU cell with linear before reset. Differs from the vanilla GRU |
421 | /// in how the new memory gate is calculated: |
422 | /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$ |
423 | /// LRB GRU expects 4 bias tensors on input: |
424 | /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ |
425 | lbr_gru = dnnl_lbr_gru, |
426 | /// AUGRU cell |
427 | vanilla_augru = dnnl_vanilla_augru, |
428 | /// AUGRU cell with linear before reset |
429 | lbr_augru = dnnl_lbr_augru, |
430 | /// Binary add |
431 | binary_add = dnnl_binary_add, |
432 | /// Binary mul |
433 | binary_mul = dnnl_binary_mul, |
434 | /// Binary max |
435 | binary_max = dnnl_binary_max, |
436 | /// Binary min |
437 | binary_min = dnnl_binary_min, |
438 | /// Binary div |
439 | binary_div = dnnl_binary_div, |
440 | /// Binary sub |
441 | binary_sub = dnnl_binary_sub, |
442 | /// Binary greater than or equal |
443 | binary_ge = dnnl_binary_ge, |
444 | /// Binary greater than |
445 | binary_gt = dnnl_binary_gt, |
446 | /// Binary less than or equal |
447 | binary_le = dnnl_binary_le, |
448 | /// Binary less than |
449 | binary_lt = dnnl_binary_lt, |
450 | /// Binary equal |
451 | binary_eq = dnnl_binary_eq, |
452 | /// Binary not equal |
453 | binary_ne = dnnl_binary_ne, |
454 | /// Nearest Neighbor resampling method |
455 | resampling_nearest = dnnl_resampling_nearest, |
456 | /// Linear (Bilinear, Trilinear) resampling method |
457 | resampling_linear = dnnl_resampling_linear, |
458 | /// Reduction using max operation |
459 | reduction_max = dnnl_reduction_max, |
460 | /// Reduction using min operation |
461 | reduction_min = dnnl_reduction_min, |
462 | /// Reduction using sum operation |
463 | reduction_sum = dnnl_reduction_sum, |
464 | /// Reduction using mul operation |
465 | reduction_mul = dnnl_reduction_mul, |
466 | /// Reduction using mean operation |
467 | reduction_mean = dnnl_reduction_mean, |
468 | /// Reduction using norm_lp_max operation |
469 | reduction_norm_lp_max = dnnl_reduction_norm_lp_max, |
470 | /// Reduction using norm_lp_sum operation |
471 | reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum, |
472 | /// Reduction using norm_lp_power_p_max operation |
473 | reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max, |
474 | /// Reduction using norm_lp_power_p_sum operation |
475 | reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum, |
476 | /// Softmax, numerically stable |
477 | softmax_accurate = dnnl_softmax_accurate, |
478 | /// LogSoftmax, numerically stable |
479 | softmax_log = dnnl_softmax_log, |
480 | }; |
481 | |
482 | /// Converts algorithm kind enum value from C++ API to C API type. |
483 | /// @param aalgorithm C++ API algorithm kind enum value. |
484 | /// @returns Corresponding C API algorithm kind enum value. |
485 | inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) { |
486 | return static_cast<dnnl_alg_kind_t>(aalgorithm); |
487 | } |
488 | |
489 | /// @} dnnl_api_attributes |
490 | |
491 | /// @addtogroup dnnl_api_primitives_common |
492 | /// @{ |
493 | |
494 | /// Flags for normalization primitives. |
495 | enum class normalization_flags : unsigned { |
496 | /// Use no normalization flags. If specified, the library computes mean and |
497 | /// variance on forward propagation for training and inference, outputs them |
498 | /// on forward propagation for training, and computes the respective |
499 | /// derivatives on backward propagation. |
500 | none = dnnl_normalization_flags_none, |
501 | |
502 | /// Use global statistics. If specified, the library uses mean and |
503 | /// variance provided by the user as an input on forward propagation and |
504 | /// does not compute their derivatives on backward propagation. Otherwise, |
505 | /// the library computes mean and variance on forward propagation for |
506 | /// training and inference, outputs them on forward propagation for |
507 | /// training, and computes the respective derivatives on backward |
508 | /// propagation. |
509 | use_global_stats = dnnl_use_global_stats, |
510 | |
511 | /// Use scale parameter. If specified, the user is expected to pass scale as |
512 | /// input on forward propagation. On backward propagation of type |
513 | /// #dnnl::prop_kind::backward, the library computes its derivative. |
514 | use_scale = dnnl_use_scale, |
515 | |
516 | /// Use shift parameter. If specified, the user is expected to pass shift as |
517 | /// input on forward propagation. On backward propagation of type |
518 | /// #dnnl::prop_kind::backward, the library computes its derivative. |
519 | use_shift = dnnl_use_shift, |
520 | |
521 | /// Fuse normalization with ReLU. On training, normalization will require |
522 | /// the workspace to implement backward propagation. On inference, the |
523 | /// workspace is not required and behavior is the same as when normalization |
524 | /// is fused with ReLU using the post-ops API. |
525 | fuse_norm_relu = dnnl_fuse_norm_relu, |
526 | |
527 | /// Fuse normalization with elementwise binary Add and then fuse with ReLU. |
528 | /// On training, normalization will require the workspace to implement |
529 | /// backward propagation. On inference, the workspace is not required. |
530 | fuse_norm_add_relu = dnnl_fuse_norm_add_relu, |
531 | }; |
532 | |
533 | /// Converts normalization flags enum value from C++ API to C API type. |
534 | /// @param flags C++ API normalization flags enum value. |
535 | /// @returns Corresponding C API normalization flags enum value. |
536 | inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) { |
537 | return static_cast<dnnl_normalization_flags_t>(flags); |
538 | } |
539 | |
540 | /// @} dnnl_api_primitives_common |
541 | |
542 | /// @addtogroup dnnl_api_rnn |
543 | /// @{ |
544 | |
545 | /// RNN cell flags. |
546 | enum class rnn_flags : unsigned { |
547 | /// Undefined RNN flags |
548 | undef = dnnl_rnn_flags_undef |
549 | }; |
550 | |
551 | /// Converts RNN cell flags enum value from C++ API to C API type. |
552 | /// @param flags C++ API RNN cell flags enum value. |
553 | /// @returns Corresponding C API RNN cell flags enum value. |
554 | inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) { |
555 | return static_cast<dnnl_rnn_flags_t>(flags); |
556 | } |
557 | |
558 | DNNL_DEFINE_BITMASK_OPS(normalization_flags) |
559 | DNNL_DEFINE_BITMASK_OPS(rnn_flags) |
560 | |
561 | /// A direction of RNN primitive execution |
562 | enum class rnn_direction { |
563 | /// Undefined RNN direction. |
564 | undef = dnnl_rnn_direction_undef, |
565 | /// Unidirectional execution of RNN primitive from left to right. |
566 | unidirectional_left2right = dnnl_unidirectional_left2right, |
567 | /// Unidirectional execution of RNN primitive from right to left. |
568 | unidirectional_right2left = dnnl_unidirectional_right2left, |
569 | /// Bidirectional execution of RNN primitive with concatenation of the |
570 | /// results. |
571 | bidirectional_concat = dnnl_bidirectional_concat, |
572 | /// Bidirectional execution of RNN primitive with summation of the |
573 | /// results. |
574 | bidirectional_sum = dnnl_bidirectional_sum, |
575 | }; |
576 | |
577 | /// Converts RNN direction enum value from C++ API to C API type. |
578 | /// @param dir C++ API RNN direction enum value. |
579 | /// @returns Corresponding C API RNN direction enum value. |
580 | inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) { |
581 | return static_cast<dnnl_rnn_direction_t>(dir); |
582 | } |
583 | |
584 | /// @} dnnl_api_rnn |
585 | |
586 | /// @addtogroup dnnl_api_primitives_common |
587 | /// @{ |
588 | |
589 | /// Primitive descriptor query specification. |
590 | /// |
591 | /// In general, queries are not used with the C++ API because most queries are |
592 | /// implemented as class members. |
593 | /// |
594 | /// See @ref dnnl_query_t for more information. |
595 | enum class query { |
596 | /// no query |
597 | undef = dnnl_query_undef, |
598 | |
599 | /// execution engine |
600 | engine = dnnl_query_engine, |
601 | /// primitive kind |
602 | primitive_kind = dnnl_query_primitive_kind, |
603 | |
604 | /// number of inputs expected |
605 | num_of_inputs_s32 = dnnl_query_num_of_inputs_s32, |
606 | /// number of outputs expected |
607 | num_of_outputs_s32 = dnnl_query_num_of_outputs_s32, |
608 | |
609 | /// runtime estimation (seconds), unimplemented |
610 | time_estimate_f64 = dnnl_query_time_estimate_f64, |
611 | /// memory required for scratchpad (bytes) |
612 | /// |
613 | /// @sa @ref dev_guide_attributes_scratchpad |
614 | memory_consumption_s64 = dnnl_query_memory_consumption_s64, |
615 | |
616 | /// scratchpad engine |
617 | /// |
618 | /// engine to be used for creating scratchpad memory |
619 | scratchpad_engine = dnnl_query_scratchpad_engine, |
620 | |
621 | /// reorder source engine |
622 | reorder_src_engine = dnnl_query_reorder_src_engine, |
623 | /// reorder destination engine |
624 | reorder_dst_engine = dnnl_query_reorder_dst_engine, |
625 | |
626 | /// implementation name |
627 | impl_info_str = dnnl_query_impl_info_str, |
628 | |
629 | /// propagation kind |
630 | prop_kind = dnnl_query_prop_kind, |
631 | |
632 | /// size of cache blob ID in bytes |
633 | cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64, |
634 | |
635 | /// cache blob ID (pointer to array) |
636 | cache_blob_id = dnnl_query_cache_blob_id, |
637 | |
638 | /// strides |
639 | strides = dnnl_query_strides, |
640 | /// dilations |
641 | dilations = dnnl_query_dilations, |
642 | /// left padding |
643 | padding_l = dnnl_query_padding_l, |
644 | /// right padding |
645 | padding_r = dnnl_query_padding_r, |
646 | /// epsilon |
647 | epsilon_f32 = dnnl_query_epsilon_f32, |
648 | /// flags |
649 | flags = dnnl_query_flags, |
650 | /// algorithm kind |
651 | alg_kind = dnnl_query_alg_kind, |
652 | /// alpha |
653 | alpha_f32 = dnnl_query_alpha_f32, |
654 | /// beta |
655 | beta_f32 = dnnl_query_beta_f32, |
656 | /// axis |
657 | axis_s32 = dnnl_query_axis_s32, |
658 | /// LRN parameter local size |
659 | local_size_s64 = dnnl_query_local_size_s64, |
660 | /// LRN parameter K |
661 | k_f32 = dnnl_query_k_f32, |
662 | /// Reduction parameter P |
663 | p_f32 = dnnl_query_p_f32, |
664 | /// Resampling parameter factors |
665 | factors = dnnl_query_factors, |
666 | /// RNN parameter cell kind |
667 | cell_kind = dnnl_query_cell_kind, |
668 | /// RNN parameter direction |
669 | direction = dnnl_query_direction, |
670 | /// RNN parameter activation kind |
671 | activation_kind = dnnl_query_activation_kind, |
672 | /// Pooling parameter kernel |
673 | kernel = dnnl_query_kernel, |
674 | /// Shuffle parameter group size |
675 | group_size_s64 = dnnl_query_group_size_s64, |
676 | |
677 | /// source memory desc |
678 | src_md = dnnl_query_src_md, |
679 | /// source gradient (diff) memory desc |
680 | diff_src_md = dnnl_query_diff_src_md, |
681 | /// weights memory descriptor desc |
682 | weights_md = dnnl_query_weights_md, |
683 | /// weights gradient (diff) memory desc |
684 | diff_weights_md = dnnl_query_diff_weights_md, |
685 | /// destination memory desc |
686 | dst_md = dnnl_query_dst_md, |
687 | /// destination gradient (diff) memory desc |
688 | diff_dst_md = dnnl_query_diff_dst_md, |
689 | /// workspace memory desc |
690 | workspace_md = dnnl_query_workspace_md, |
691 | /// scratchpad memory desc |
692 | scratchpad_md = dnnl_query_scratchpad_md, |
693 | /// memory desc of an execute argument |
694 | exec_arg_md = dnnl_query_exec_arg_md, |
695 | |
696 | /// number of dimensions |
697 | ndims_s32 = dnnl_query_ndims_s32, |
698 | /// vector of dimensions |
699 | dims = dnnl_query_dims, |
700 | /// data type |
701 | data_type = dnnl_query_data_type, |
702 | /// submemory offset |
703 | submemory_offset_s64 = dnnl_query_submemory_offset_s64, |
704 | /// vector of padded dimensions |
705 | padded_dims = dnnl_query_padded_dims, |
706 | /// vector of padded offsets |
707 | padded_offsets = dnnl_query_padded_offsets, |
708 | /// format kind |
709 | format_kind = dnnl_query_format_kind, |
710 | /// number of innermost blocks |
711 | inner_nblks_s32 = dnnl_query_inner_nblks_s32, |
712 | /// vector of sizes of the innermost blocks |
713 | inner_blks = dnnl_query_inner_blks, |
714 | /// vector of logical indices of the blocks |
715 | inner_idxs = dnnl_query_inner_idxs, |
716 | }; |
717 | |
718 | /// Converts query enum value from C++ API to C API type. |
719 | /// @param aquery C++ API query enum value. |
720 | /// @returns Corresponding C API query enum value. |
721 | inline dnnl_query_t convert_to_c(query aquery) { |
722 | return static_cast<dnnl_query_t>(aquery); |
723 | } |
724 | |
725 | /// @} dnnl_api_primitives_common |
726 | |
727 | /// @} dnnl_api_primitives |
728 | |
729 | /// @addtogroup dnnl_api_memory Memory |
730 | /// |
731 | /// A container that describes and stores data. Memory objects can contain |
732 | /// data of various types and formats. There are two levels of abstraction: |
733 | /// |
734 | /// 1. **Memory descriptor** -- engine-agnostic logical description of data |
735 | /// (number of dimensions, dimension sizes, and data type), and, |
736 | /// optionally, the information about the physical format of data in |
737 | /// memory. If this information is not known yet, a memory descriptor can |
738 | /// be created with #dnnl::memory::format_tag::any. This allows |
739 | /// compute-intensive primitives to choose the best format for |
740 | /// computation. The user is responsible for reordering the data into the |
741 | /// chosen format when formats do not match. |
742 | /// |
743 | /// A memory descriptor can be initialized either by specifying dimensions |
744 | /// and a memory format tag or strides for each of them, or by |
745 | /// manipulating the dnnl_memory_desc_t structure directly. |
746 | /// |
747 | /// @warning |
748 | /// The latter approach requires understanding how the physical data |
749 | /// representation is mapped to the structure and is discouraged. This |
750 | /// topic is discussed in @ref dev_guide_understanding_memory_formats. |
751 | /// |
752 | /// The user can query the amount of memory required by a memory |
753 | /// descriptor using the #dnnl::memory::desc::get_size() function. The |
754 | /// size of data in general cannot be computed as the product of |
755 | /// dimensions multiplied by the size of the data type. So users are |
756 | /// required to use this function for better code portability. |
757 | /// |
758 | /// Two memory descriptors can be compared using the equality and |
759 | /// inequality operators. The comparison is especially useful when |
760 | /// checking whether it is necessary to reorder data from the user's data |
761 | /// format to a primitive's format. |
762 | /// |
763 | /// 2. **Memory object** -- an engine-specific object that handles the memory |
764 | /// buffer and its description (a memory descriptor). For the CPU engine or |
765 | /// with USM, the memory buffer handle is simply a pointer to @c void. The |
766 | /// memory buffer can be queried using #dnnl::memory::get_data_handle() and |
767 | /// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer, |
768 | /// when used, can be queried using #dnnl::sycl_interop::get_buffer and set |
769 | /// using #dnnl::sycl_interop::set_buffer. A memory object can also be |
770 | /// queried for the underlying memory descriptor and for its engine using |
///    #dnnl::memory::get_desc() and #dnnl::memory::get_engine().
772 | /// |
773 | /// Along with ordinary memory descriptors with all dimensions being positive, |
774 | /// the library supports *zero-volume* memory descriptors with one or more |
775 | /// dimensions set to zero. This is used to support the NumPy\* convention. |
776 | /// If a zero-volume memory is passed to a primitive, the primitive typically |
777 | /// does not perform any computations with this memory. For example: |
778 | /// |
/// - A concatenation primitive would ignore all memory objects with zeroes in
780 | /// the concat dimension / axis. |
781 | /// |
782 | /// - A forward convolution with a source memory object with zero in the |
783 | /// minibatch dimension would always produce a destination memory object |
784 | /// with a zero in the minibatch dimension and perform no computations. |
785 | /// |
786 | /// - However, a forward convolution with a zero in one of the weights |
787 | /// dimensions is ill-defined and is considered to be an error by the |
788 | /// library because there is no clear definition of what the output values |
789 | /// should be. |
790 | /// |
791 | /// Memory buffer of a zero-volume memory is never accessed. |
792 | /// |
793 | /// @{ |
794 | |
795 | /// Memory object. |
796 | /// |
797 | /// A memory object encapsulates a handle to a memory buffer allocated on a |
798 | /// specific engine, tensor dimensions, data type, and memory format, which is |
799 | /// the way tensor indices map to offsets in linear memory space. Memory |
800 | /// objects are passed to primitives during execution. |
801 | struct memory : public handle<dnnl_memory_t> { |
802 | using handle::handle; |
803 | |
804 | /// Integer type for representing dimension sizes and indices. |
805 | typedef dnnl_dim_t dim; |
806 | /// Vector of dimensions. Implementations are free to force a limit on the |
807 | /// vector's length. |
808 | typedef std::vector<dim> dims; |
809 | |
810 | /// Helper function that validates that an `std::vector` of dimensions can |
811 | /// be safely converted to the C API array ::dnnl_dims_t. Throws if |
812 | /// validation fails. |
813 | /// |
814 | /// @param v Vector of dimensions. |
815 | /// @param min_size Minimum expected size of the vector. |
816 | template <typename T> |
817 | static void validate_dims(const std::vector<T> &v, int min_size = 0) { |
818 | validate_container_size( |
819 | v, "dimensions are invalid" , min_size, DNNL_MAX_NDIMS); |
820 | } |
821 | |
822 | /// Data type specification. |
823 | enum class data_type { |
824 | /// Undefined data type (used for empty memory descriptors). |
825 | undef = dnnl_data_type_undef, |
826 | /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format). |
827 | f16 = dnnl_f16, |
828 | /// non-standard |
829 | /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format). |
830 | bf16 = dnnl_bf16, |
831 | /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format). |
832 | f32 = dnnl_f32, |
833 | //// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format). |
834 | f64 = dnnl_f64, |
835 | /// 32-bit signed integer. |
836 | s32 = dnnl_s32, |
837 | /// 8-bit signed integer. |
838 | s8 = dnnl_s8, |
839 | /// 8-bit unsigned integer. |
840 | u8 = dnnl_u8, |
841 | }; |
842 | |
843 | /// Returns size of data type in bytes. |
844 | /// @returns The number of bytes occupied by data type. |
845 | static size_t data_type_size(data_type adata_type) { |
846 | return dnnl_data_type_size(convert_to_c(adata_type)); |
847 | } |
848 | |
849 | /// Memory format kind |
850 | enum class format_kind { |
851 | /// Undefined memory format kind, used for empty memory descriptors. |
852 | undef = dnnl_format_kind_undef, |
853 | /// A special format kind that indicates that the actual format will be |
854 | /// selected by a primitive automatically. |
855 | any = dnnl_format_kind_any, |
856 | /// A tensor in a generic format described by the stride and blocking |
857 | /// values in each dimension. |
858 | blocked = dnnl_blocked, |
859 | /// A special format kind that indicates that tensor format is opaque. |
860 | opaque = dnnl_format_kind_opaque, |
861 | }; |
862 | |
863 | /// Memory format tag specification. |
864 | /// |
865 | /// Memory format tags can be further divided into two categories: |
866 | /// |
867 | /// - Domain-agnostic names, i.e. names that do not depend on the tensor |
    ///   usage in the specific primitive. These names use letters from `a`
    ///   to `l` to denote logical dimensions and form the order in which the
    ///   dimensions are laid out in memory. For example,
871 | /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the |
872 | /// second logical dimension (denoted as `b`) is the innermost, i.e. |
873 | /// has stride = 1, and the first logical dimension (`a`) is laid out in |
874 | /// memory with stride equal to the size of the second dimension. On the |
875 | /// other hand, #dnnl::memory::format_tag::ba is the transposed version |
876 | /// of the same tensor: the outermost dimension (`a`) becomes the |
877 | /// innermost one. |
878 | /// |
879 | /// - Domain-specific names, i.e. names that make sense only in the |
880 | /// context of a certain domain, such as CNN. These names are |
881 | /// aliases to the corresponding domain-agnostic tags and used mostly |
882 | /// for convenience. For example, #dnnl::memory::format_tag::nc |
883 | /// is used to denote 2D CNN activations tensor memory format, where |
884 | /// the channels dimension is the innermost one and the batch dimension |
885 | /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is |
886 | /// an alias for #dnnl::memory::format_tag::ab, because for |
887 | /// CNN primitives the logical dimensions of activations tensors come |
888 | /// in order: batch, channels, spatial. In other words, batch |
889 | /// corresponds to the first logical dimension (`a`), and channels |
890 | /// correspond to the second one (`b`). |
891 | /// |
892 | /// The following domain-specific notation applies to memory format tags: |
893 | /// - @c 'n' denotes the mini-batch dimension |
894 | /// - @c 'c' denotes a channels dimension |
895 | /// - When there are multiple channel dimensions (for example, |
896 | /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions |
897 | /// of input and output channels |
898 | /// - @c 'g' denotes a groups dimension for convolution weights |
899 | /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width |
900 | /// respectively |
901 | /// |
902 | /// See @ref dnnl_format_tag_t for a detailed description. |
903 | enum class format_tag { |
904 | /// Undefined memory format tag |
905 | undef = dnnl_format_tag_undef, |
906 | /// Placeholder memory format tag. Used to instruct the primitive to |
907 | /// select a format automatically. |
908 | any = dnnl_format_tag_any, |
909 | |
910 | /// plain 1D tensor |
911 | a = dnnl_a, |
912 | |
913 | /// plain 2D tensor |
914 | ab = dnnl_ab, |
915 | /// permuted 2D tensor |
916 | ba = dnnl_ba, |
917 | |
918 | /// plain 3D tensor |
919 | abc = dnnl_abc, |
920 | /// permuted 3D tensor |
921 | acb = dnnl_acb, |
922 | /// permuted 3D tensor |
923 | bac = dnnl_bac, |
924 | /// permuted 3D tensor |
925 | bca = dnnl_bca, |
926 | /// permuted 3D tensor |
927 | cba = dnnl_cba, |
928 | |
929 | /// plain 4D tensor |
930 | abcd = dnnl_abcd, |
931 | /// permuted 4D tensor |
932 | abdc = dnnl_abdc, |
933 | /// permuted 4D tensor |
934 | acbd = dnnl_acbd, |
935 | /// permuted 4D tensor |
936 | acdb = dnnl_acdb, |
937 | /// permuted 4D tensor |
938 | adbc = dnnl_adbc, |
939 | /// permuted 4D tensor |
940 | bacd = dnnl_bacd, |
941 | /// permuted 4D tensor |
942 | bcda = dnnl_bcda, |
943 | /// permuted 4D tensor |
944 | cdba = dnnl_cdba, |
945 | /// permuted 4D tensor |
946 | dcab = dnnl_dcab, |
947 | |
948 | /// plain 5D tensor |
949 | abcde = dnnl_abcde, |
950 | /// permuted 5D tensor |
951 | abdec = dnnl_abdec, |
952 | /// permuted 5D tensor |
953 | acbde = dnnl_acbde, |
954 | /// permuted 5D tensor |
955 | acdeb = dnnl_acdeb, |
956 | /// permuted 5D tensor |
957 | bacde = dnnl_bacde, |
958 | /// permuted 5D tensor |
959 | bcdea = dnnl_bcdea, |
960 | /// permuted 5D tensor |
961 | cdeba = dnnl_cdeba, |
962 | /// permuted 5D tensor |
963 | decab = dnnl_decab, |
964 | /// permuted 5D tensor |
965 | abced = dnnl_abced, |
966 | |
967 | /// plain 6D tensor |
968 | abcdef = dnnl_abcdef, |
969 | /// permuted 6D tensor |
970 | abdfce = dnnl_abdfce, |
971 | /// permuted 6D tensor |
972 | acbdef = dnnl_acbdef, |
973 | /// permuted 6D tensor |
974 | abdefc = dnnl_abdefc, |
975 | /// permuted 6D tensor |
976 | defcab = dnnl_defcab, |
977 | /// permuted 6D tensor |
978 | abcdfe = dnnl_abcdfe, |
979 | |
980 | /// plain 7D tensor |
981 | abcdefg = dnnl_abcdefg, |
982 | /// permuted 7D tensor |
983 | abcdegf = dnnl_abcdegf, |
984 | |
985 | /// plain 8D tensor |
986 | abcdefgh = dnnl_abcdefgh, |
987 | /// permuted 8D tensor |
988 | abcdefhg = dnnl_abcdefhg, |
989 | |
990 | /// plain 9D tensor |
991 | abcdefghi = dnnl_abcdefghi, |
992 | /// permuted 9D tensor |
993 | abcdefgih = dnnl_abcdefgih, |
994 | |
995 | /// plain 10D tensor |
996 | abcdefghij = dnnl_abcdefghij, |
997 | /// permuted 10D tensor |
998 | abcdefghji = dnnl_abcdefghji, |
999 | |
1000 | /// plain 11D tensor |
1001 | abcdefghijk = dnnl_abcdefghijk, |
1002 | /// permuted 11D tensor |
1003 | abcdefghikj = dnnl_abcdefghikj, |
1004 | |
1005 | /// plain 12D tensor |
1006 | abcdefghijkl = dnnl_abcdefghijkl, |
1007 | /// permuted 12D tensor |
1008 | abcdefghijlk = dnnl_abcdefghijlk, |
1009 | |
1010 | /// 1D tensor; an alias for #dnnl::memory::format_tag::a |
1011 | x = a, |
1012 | /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab |
1013 | nc = ab, |
1014 | /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba |
1015 | cn = ba, |
1016 | /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab |
1017 | tn = ab, |
1018 | /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba |
1019 | nt = ba, |
1020 | /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc |
1021 | ncw = abc, |
1022 | /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb |
1023 | nwc = acb, |
1024 | /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd |
1025 | nchw = abcd, |
1026 | /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb |
1027 | nhwc = acdb, |
1028 | /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda |
1029 | chwn = bcda, |
1030 | /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde |
1031 | ncdhw = abcde, |
1032 | /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb |
1033 | ndhwc = acdeb, |
1034 | |
1035 | /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab |
1036 | oi = ab, |
1037 | /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba |
1038 | io = ba, |
1039 | /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc |
1040 | oiw = abc, |
1041 | /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb |
1042 | owi = acb, |
1043 | /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba |
1044 | wio = cba, |
1045 | /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca |
1046 | iwo = bca, |
1047 | /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd |
1048 | oihw = abcd, |
1049 | /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba |
1050 | hwio = cdba, |
1051 | /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb |
1052 | ohwi = acdb, |
1053 | /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda |
1054 | ihwo = bcda, |
1055 | /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd |
1056 | iohw = bacd, |
1057 | /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde |
1058 | oidhw = abcde, |
1059 | /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba |
1060 | dhwio = cdeba, |
1061 | /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb |
1062 | odhwi = acdeb, |
1063 | /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde |
1064 | iodhw = bacde, |
1065 | /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea |
1066 | idhwo = bcdea, |
1067 | |
1068 | /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd |
1069 | goiw = abcd, |
1070 | /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc |
1071 | gowi = abdc, |
1072 | /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab |
1073 | wigo = dcab, |
1074 | /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec |
1075 | gohwi = abdec, |
1076 | /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde |
1077 | goihw = abcde, |
1078 | /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab |
1079 | hwigo = decab, |
1080 | /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde |
1081 | giohw = acbde, |
1082 | /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef |
1083 | goidhw = abcdef, |
        /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
1085 | giodhw = acbdef, |
1086 | /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc |
1087 | godhwi = abdefc, |
1088 | /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab |
1089 | dhwigo = defcab, |
1090 | |
1091 | /// 3D RNN data tensor in the format (seq_length, batch, input |
1092 | /// channels); an alias for #dnnl::memory::format_tag::abc. |
1093 | tnc = abc, |
1094 | /// 3D RNN data tensor in the format (batch, seq_length, input |
1095 | /// channels); an alias for #dnnl::memory::format_tag::bac. |
1096 | ntc = bac, |
1097 | /// 4D RNN states tensor in the format (num_layers, num_directions, |
1098 | /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd. |
1099 | ldnc = abcd, |
1100 | /// 5D RNN weights tensor in the format (num_layers, num_directions, |
1101 | /// input_channels, num_gates, output_channels); |
1102 | /// an alias for #dnnl::memory::format_tag::abcde. |
1103 | /// |
1104 | /// - For LSTM cells, the gates order is input, forget, candidate |
1105 | /// and output gate. |
1106 | /// - For GRU cells, the gates order is update, reset and output gate. |
1107 | ldigo = abcde, |
1108 | /// 5D RNN weights tensor in the format (num_layers, num_directions, |
1109 | /// num_gates, output_channels, input_channels); |
1110 | /// an alias for #dnnl::memory::format_tag::abdec. |
1111 | /// |
1112 | /// - For LSTM cells, the gates order is input, forget, candidate |
1113 | /// and output gate. |
1114 | /// - For GRU cells, the gates order is update, reset and output gate. |
1115 | ldgoi = abdec, |
1116 | /// 4D LSTM projection tensor in the format (num_layers, num_directions, |
1117 | /// num_channels_in_hidden_state, num_channels_in_recurrent_projection); |
1118 | /// an alias for #dnnl::memory::format_tag::abcd. |
1119 | ldio = abcd, |
1120 | /// 4D LSTM projection tensor in the format (num_layers, num_directions, |
1121 | /// num_channels_in_recurrent_projection, num_channels_in_hidden_state); |
1122 | /// an alias for #dnnl::memory::format_tag::abdc. |
1123 | ldoi = abdc, |
1124 | /// 4D RNN bias tensor in the format (num_layers, num_directions, |
1125 | /// num_gates, output_channels); |
1126 | /// an alias for #dnnl::memory::format_tag::abcd. |
1127 | /// |
1128 | /// - For LSTM cells, the gates order is input, forget, candidate |
1129 | /// and output gate. |
1130 | /// - For GRU cells, the gates order is update, reset and output gate. |
1131 | ldgo = abcd, |
1132 | |
1133 | // Opaque blocked formats |
1134 | |
1135 | AB16b16a = dnnl_AB16b16a, |
1136 | AB16b32a = dnnl_AB16b32a, |
1137 | AB16b64a = dnnl_AB16b64a, |
1138 | AB8b16a2b = dnnl_AB8b16a2b, |
1139 | AB8b32a2b = dnnl_AB8b32a2b, |
1140 | AB8b64a2b = dnnl_AB8b64a2b, |
1141 | AB4b16a4b = dnnl_AB4b16a4b, |
1142 | AB4b32a4b = dnnl_AB4b32a4b, |
1143 | AB4b64a4b = dnnl_AB4b64a4b, |
1144 | AB16b16a4b = dnnl_AB16b16a4b, |
1145 | AB16b32a4b = dnnl_AB16b32a4b, |
1146 | AB16b48a4b = dnnl_AB16b48a4b, |
1147 | AB16b64a4b = dnnl_AB16b64a4b, |
1148 | AB16b16a2b = dnnl_AB16b16a2b, |
1149 | AB16b32a2b = dnnl_AB16b32a2b, |
1150 | AB16b48a2b = dnnl_AB16b48a2b, |
1151 | AB16b64a2b = dnnl_AB16b64a2b, |
1152 | Abc16a = dnnl_Abc16a, |
1153 | ABc16a16b = dnnl_ABc16a16b, |
1154 | ABc4a4b = dnnl_ABc4a4b, |
1155 | aBc16b = dnnl_aBc16b, |
1156 | aBc32b = dnnl_aBc32b, |
1157 | ABc16b16a = dnnl_ABc16b16a, |
1158 | ABc16b32a = dnnl_ABc16b32a, |
1159 | ABc16b64a = dnnl_ABc16b64a, |
1160 | Abc4a = dnnl_Abc4a, |
1161 | aBc4b = dnnl_aBc4b, |
1162 | ABc4b16a4b = dnnl_ABc4b16a4b, |
1163 | ABc4b32a4b = dnnl_ABc4b32a4b, |
1164 | ABc4b64a4b = dnnl_ABc4b64a4b, |
1165 | ABc2b8a4b = dnnl_ABc2b8a4b, |
1166 | ABc16a16b2a = dnnl_ABc16a16b2a, |
1167 | ABc16b16a4b = dnnl_ABc16b16a4b, |
1168 | ABc16b32a4b = dnnl_ABc16b32a4b, |
1169 | ABc16b48a4b = dnnl_ABc16b48a4b, |
1170 | ABc16b64a4b = dnnl_ABc16b64a4b, |
1171 | ABc16b16a2b = dnnl_ABc16b16a2b, |
1172 | ABc16b32a2b = dnnl_ABc16b32a2b, |
1173 | ABc16b48a2b = dnnl_ABc16b48a2b, |
1174 | ABc16b64a2b = dnnl_ABc16b64a2b, |
1175 | ABc4b4a = dnnl_ABc4b4a, |
1176 | ABc8a16b2a = dnnl_ABc8a16b2a, |
1177 | ABc8a8b = dnnl_ABc8a8b, |
1178 | ABc8a4b = dnnl_ABc8a4b, |
1179 | aBc8b = dnnl_aBc8b, |
1180 | ABc8b16a2b = dnnl_ABc8b16a2b, |
1181 | ABc8b32a2b = dnnl_ABc8b32a2b, |
1182 | ABc8b64a2b = dnnl_ABc8b64a2b, |
1183 | ABc8b8a = dnnl_ABc8b8a, |
1184 | Abcd8a = dnnl_Abcd8a, |
1185 | Abcd16a = dnnl_Abcd16a, |
1186 | Abcd32a = dnnl_Abcd32a, |
1187 | ABcd16a16b = dnnl_ABcd16a16b, |
1188 | aBcd16b = dnnl_aBcd16b, |
1189 | aBcd32b = dnnl_aBcd32b, |
1190 | ABcd16b16a = dnnl_ABcd16b16a, |
1191 | ABcd16b32a = dnnl_ABcd16b32a, |
1192 | ABcd16b64a = dnnl_ABcd16b64a, |
1193 | aBCd16b16c = dnnl_aBCd16b16c, |
1194 | aBCd16c16b = dnnl_aBCd16c16b, |
1195 | Abcd4a = dnnl_Abcd4a, |
1196 | aBcd4b = dnnl_aBcd4b, |
1197 | ABcd4b16a4b = dnnl_ABcd4b16a4b, |
1198 | ABcd4b32a4b = dnnl_ABcd4b32a4b, |
1199 | ABcd4b64a4b = dnnl_ABcd4b64a4b, |
1200 | ABcd2b8a4b = dnnl_ABcd2b8a4b, |
1201 | ABcd4b4a = dnnl_ABcd4b4a, |
1202 | ABcd4a4b = dnnl_ABcd4a4b, |
1203 | aBCd4c16b4c = dnnl_aBCd4c16b4c, |
1204 | aBCd2c8b4c = dnnl_aBCd2c8b4c, |
1205 | ABcd16a16b2a = dnnl_ABcd16a16b2a, |
1206 | ABcd16b16a4b = dnnl_ABcd16b16a4b, |
1207 | ABcd16b32a4b = dnnl_ABcd16b32a4b, |
1208 | ABcd16b48a4b = dnnl_ABcd16b48a4b, |
1209 | ABcd16b64a4b = dnnl_ABcd16b64a4b, |
1210 | ABcd16b16a2b = dnnl_ABcd16b16a2b, |
1211 | ABcd16b32a2b = dnnl_ABcd16b32a2b, |
1212 | ABcd16b48a2b = dnnl_ABcd16b48a2b, |
1213 | ABcd16b64a2b = dnnl_ABcd16b64a2b, |
1214 | aBCd16b16c2b = dnnl_aBCd16b16c2b, |
1215 | aBCd16c16b4c = dnnl_aBCd16c16b4c, |
1216 | aBCd16c16b2c = dnnl_aBCd16c16b2c, |
1217 | aBCd4c4b = dnnl_aBCd4c4b, |
1218 | aBCd4b4c = dnnl_aBCd4b4c, |
1219 | ABcd8a16b2a = dnnl_ABcd8a16b2a, |
1220 | ABcd8a8b = dnnl_ABcd8a8b, |
1221 | ABcd8a4b = dnnl_ABcd8a4b, |
1222 | ABcd8a2b = dnnl_ABcd8a2b, |
1223 | /// 4D tensor blocked by 2nd dimension with block size 8 |
1224 | aBcd8b = dnnl_aBcd8b, |
1225 | ABcd8b16a2b = dnnl_ABcd8b16a2b, |
1226 | ABcd8b32a2b = dnnl_ABcd8b32a2b, |
1227 | ABcd8b64a2b = dnnl_ABcd8b64a2b, |
1228 | aBCd8b16c2b = dnnl_aBCd8b16c2b, |
1229 | /// 4D tensor blocked by 1st and 2nd dimension with block size 8 |
1230 | ABcd8b8a = dnnl_ABcd8b8a, |
1231 | aBCd8b8c = dnnl_aBCd8b8c, |
1232 | aBCd8b4c = dnnl_aBCd8b4c, |
1233 | aBCd8c16b2c = dnnl_aBCd8c16b2c, |
1234 | aBCd8c8b = dnnl_aBCd8c8b, |
1235 | Abcde16a = dnnl_Abcde16a, |
1236 | Abcde32a = dnnl_Abcde32a, |
1237 | ABcde16a16b = dnnl_ABcde16a16b, |
1238 | aBcde16b = dnnl_aBcde16b, |
1239 | aBcde32b = dnnl_aBcde32b, |
1240 | ABcde16b16a = dnnl_ABcde16b16a, |
1241 | ABcde16b32a = dnnl_ABcde16b32a, |
1242 | ABcde16b64a = dnnl_ABcde16b64a, |
1243 | aBCde16b16c = dnnl_aBCde16b16c, |
1244 | aBCde16c16b = dnnl_aBCde16c16b, |
1245 | aBCde2c8b4c = dnnl_aBCde2c8b4c, |
1246 | Abcde4a = dnnl_Abcde4a, |
1247 | aBcde4b = dnnl_aBcde4b, |
1248 | ABcde4b4a = dnnl_ABcde4b4a, |
1249 | ABcde4a4b = dnnl_ABcde4a4b, |
1250 | aBCde4b4c = dnnl_aBCde4b4c, |
1251 | aBCde4c16b4c = dnnl_aBCde4c16b4c, |
1252 | aBCde16b16c2b = dnnl_aBCde16b16c2b, |
1253 | aBCde16c16b4c = dnnl_aBCde16c16b4c, |
1254 | aBCde16c16b2c = dnnl_aBCde16c16b2c, |
1255 | aBCdef16c16b2c = dnnl_aBCdef16c16b2c, |
1256 | aBCde4c4b = dnnl_aBCde4c4b, |
1257 | Abcde8a = dnnl_Abcde8a, |
1258 | ABcde8a8b = dnnl_ABcde8a8b, |
1259 | ABcde8a4b = dnnl_ABcde8a4b, |
1260 | aBcde8b = dnnl_aBcde8b, |
1261 | ABcde8b16a2b = dnnl_ABcde8b16a2b, |
1262 | ABcde8b32a2b = dnnl_ABcde8b32a2b, |
1263 | ABcde8b64a2b = dnnl_ABcde8b64a2b, |
1264 | ABcde4b16a4b = dnnl_ABcde4b16a4b, |
1265 | ABcde4b32a4b = dnnl_ABcde4b32a4b, |
1266 | ABcde4b64a4b = dnnl_ABcde4b64a4b, |
1267 | ABcde16b16a4b = dnnl_ABcde16b16a4b, |
1268 | ABcde16b32a4b = dnnl_ABcde16b32a4b, |
1269 | ABcde16b48a4b = dnnl_ABcde16b48a4b, |
1270 | ABcde16b64a4b = dnnl_ABcde16b64a4b, |
1271 | ABcde16b16a2b = dnnl_ABcde16b16a2b, |
1272 | ABcde16b32a2b = dnnl_ABcde16b32a2b, |
1273 | ABcde16b48a2b = dnnl_ABcde16b48a2b, |
1274 | ABcde16b64a2b = dnnl_ABcde16b64a2b, |
1275 | ABcde2b8a4b = dnnl_ABcde2b8a4b, |
1276 | aBCde8b16c2b = dnnl_aBCde8b16c2b, |
1277 | ABcde8b8a = dnnl_ABcde8b8a, |
1278 | aBCde8b8c = dnnl_aBCde8b8c, |
1279 | aBCde8b4c = dnnl_aBCde8b4c, |
1280 | ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b, |
1281 | ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b, |
1282 | aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c, |
1283 | aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c, |
1284 | aBCde8c16b2c = dnnl_aBCde8c16b2c, |
1285 | aBCde8c8b = dnnl_aBCde8c8b, |
1286 | aBcdef16b = dnnl_aBcdef16b, |
1287 | aBCdef16b16c = dnnl_aBCdef16b16c, |
1288 | aBCdef16c16b = dnnl_aBCdef16c16b, |
1289 | aBcdef4b = dnnl_aBcdef4b, |
1290 | aBCdef2c8b4c = dnnl_aBCdef2c8b4c, |
1291 | aBCdef4c4b = dnnl_aBCdef4c4b, |
1292 | aBCdef4b4c = dnnl_aBCdef4b4c, |
1293 | aBCdef8b8c = dnnl_aBCdef8b8c, |
1294 | aBCdef8b4c = dnnl_aBCdef8b4c, |
1295 | aBCdef8c16b2c = dnnl_aBCdef8c16b2c, |
1296 | aBCdef4c16b4c = dnnl_aBCdef4c16b4c, |
1297 | aBCdef8c8b = dnnl_aBCdef8c8b, |
1298 | aBdc16b = dnnl_aBdc16b, |
1299 | aBdc4b = dnnl_aBdc4b, |
1300 | aBdc8b = dnnl_aBdc8b, |
1301 | aBdec16b = dnnl_aBdec16b, |
1302 | aBdec4b = dnnl_aBdec4b, |
1303 | aBdec8b = dnnl_aBdec8b, |
1304 | aBdefc16b = dnnl_aBdefc16b, |
1305 | aCBdef16c16b = dnnl_aCBdef16c16b, |
1306 | aCBdef16b16c = dnnl_aCBdef16b16c, |
1307 | aBdefc4b = dnnl_aBdefc4b, |
1308 | aBdefc8b = dnnl_aBdefc8b, |
1309 | Acb16a = dnnl_Acb16a, |
1310 | Acb4a = dnnl_Acb4a, |
1311 | Acb8a = dnnl_Acb8a, |
1312 | aCBd16b16c = dnnl_aCBd16b16c, |
1313 | aCBd16c16b = dnnl_aCBd16c16b, |
1314 | aCBde16b16c = dnnl_aCBde16b16c, |
1315 | aCBde16c16b = dnnl_aCBde16c16b, |
1316 | Acdb16a = dnnl_Acdb16a, |
1317 | Acdb4a = dnnl_Acdb4a, |
1318 | Acdb8a = dnnl_Acdb8a, |
1319 | Acdeb16a = dnnl_Acdeb16a, |
1320 | Acdeb4a = dnnl_Acdeb4a, |
1321 | Acdeb8a = dnnl_Acdeb8a, |
1322 | BAc16a16b = dnnl_BAc16a16b, |
1323 | BAc16b16a = dnnl_BAc16b16a, |
1324 | BAcd16a16b = dnnl_BAcd16a16b, |
1325 | BAcd16b16a = dnnl_BAcd16b16a, |
1326 | ABcd32a32b = dnnl_ABcd32a32b, |
1327 | BAcde16b16a = dnnl_BAcde16b16a, |
1328 | BAcde16a16b = dnnl_BAcde16a16b, |
1329 | aBdec32b = dnnl_aBdec32b, |
1330 | Abcdef16a = dnnl_Abcdef16a, |
1331 | Abcdef32a = dnnl_Abcdef32a, |
1332 | Acdb32a = dnnl_Acdb32a, |
1333 | aBCd2b4c2b = dnnl_aBCd2b4c2b, |
1334 | aBCde2b4c2b = dnnl_aBCde2b4c2b, |
1335 | aBCdef2b4c2b = dnnl_aBCdef2b4c2b, |
1336 | aBCd2c4b2c = dnnl_aBCd2c4b2c, |
1337 | aBCde2c4b2c = dnnl_aBCde2c4b2c, |
1338 | aBCdef2c4b2c = dnnl_aBCdef2c4b2c, |
1339 | aBCd4b8c2b = dnnl_aBCd4b8c2b, |
1340 | aBCde4b8c2b = dnnl_aBCde4b8c2b, |
1341 | aBCdef4b8c2b = dnnl_aBCdef4b8c2b, |
1342 | aBCd4c8b2c = dnnl_aBCd4c8b2c, |
1343 | aBCde4c8b2c = dnnl_aBCde4c8b2c, |
1344 | aBCdef4c8b2c = dnnl_aBCdef4c8b2c, |
1345 | AB32a32b8a4b = dnnl_AB32a32b8a4b, |
1346 | AB32a32b8a2b = dnnl_AB32a32b8a2b, |
1347 | AB8a4b = dnnl_AB8a4b, |
1348 | AB8a2b = dnnl_AB8a2b, |
1349 | abDc32d = dnnl_abDc32d, |
1350 | abDC32d4c = dnnl_abDC32d4c, |
1351 | abCd32c = dnnl_abCd32c, |
1352 | abdEc32e = dnnl_abdEc32e, |
1353 | abdEC32e2c = dnnl_abdEC32e2c, |
1354 | abdEC32e4c = dnnl_abdEC32e4c, |
1355 | abdCe32c = dnnl_abdCe32c, |
1356 | abdCE32c2e = dnnl_abdCE32c2e, |
1357 | aBCdef16c16b4c = dnnl_aBCdef16c16b4c, |
1358 | aBdC16b4c = dnnl_aBdC16b4c, |
1359 | aBdeC16b4c = dnnl_aBdeC16b4c, |
1360 | AcB16a4b = dnnl_AcB16a4b, |
1361 | AcdB16a2b = dnnl_AcdB16a2b, |
1362 | aBdefC16b4c = dnnl_aBdefC16b4c, |
1363 | AcdeB16a4b = dnnl_AcdeB16a4b, |
1364 | |
1365 | Acb32a = dnnl_Acb32a, |
1366 | AcB32a2b = dnnl_AcB32a2b, |
1367 | AcB32a4b = dnnl_AcB32a4b, |
1368 | Acb48a = dnnl_Acb48a, |
1369 | AcB48a2b = dnnl_AcB48a2b, |
1370 | AcB48a4b = dnnl_AcB48a4b, |
1371 | Acb64a = dnnl_Acb64a, |
1372 | AcB64a2b = dnnl_AcB64a2b, |
1373 | AcB64a4b = dnnl_AcB64a4b, |
1374 | cBa2b = dnnl_cBa2b, |
1375 | cBa4b = dnnl_cBa4b, |
1376 | aBdc32b = dnnl_aBdc32b, |
1377 | aBdC32b2c = dnnl_aBdC32b2c, |
1378 | aBdC32b4c = dnnl_aBdC32b4c, |
1379 | aBdc48b = dnnl_aBdc48b, |
1380 | aBdC48b2c = dnnl_aBdC48b2c, |
1381 | aBdC48b4c = dnnl_aBdC48b4c, |
1382 | aBdc64b = dnnl_aBdc64b, |
1383 | aBdC64b2c = dnnl_aBdC64b2c, |
1384 | aBdC64b4c = dnnl_aBdC64b4c, |
1385 | adcb = dnnl_adcb, |
1386 | adCb2c = dnnl_adCb2c, |
1387 | adCb4c = dnnl_adCb4c, |
1388 | AcdB32a2b = dnnl_AcdB32a2b, |
1389 | AcdB32a4b = dnnl_AcdB32a4b, |
1390 | Acdb48a = dnnl_Acdb48a, |
1391 | AcdB48a2b = dnnl_AcdB48a2b, |
1392 | AcdB48a4b = dnnl_AcdB48a4b, |
1393 | Acdb64a = dnnl_Acdb64a, |
1394 | AcdB64a2b = dnnl_AcdB64a2b, |
1395 | AcdB64a4b = dnnl_AcdB64a4b, |
1396 | cdBa2b = dnnl_cdBa2b, |
1397 | cdBa4b = dnnl_cdBa4b, |
1398 | aBdeC32b2c = dnnl_aBdeC32b2c, |
1399 | aBdeC32b4c = dnnl_aBdeC32b4c, |
1400 | aBdec48b = dnnl_aBdec48b, |
1401 | aBdeC48b2c = dnnl_aBdeC48b2c, |
1402 | aBdeC48b4c = dnnl_aBdeC48b4c, |
1403 | aBdec64b = dnnl_aBdec64b, |
1404 | aBdeC64b2c = dnnl_aBdeC64b2c, |
1405 | aBdeC64b4c = dnnl_aBdeC64b4c, |
1406 | adecb = dnnl_adecb, |
1407 | adeCb2c = dnnl_adeCb2c, |
1408 | adeCb4c = dnnl_adeCb4c, |
1409 | Acdeb32a = dnnl_Acdeb32a, |
1410 | AcdeB32a2b = dnnl_AcdeB32a2b, |
1411 | AcdeB32a4b = dnnl_AcdeB32a4b, |
1412 | Acdeb48a = dnnl_Acdeb48a, |
1413 | AcdeB48a2b = dnnl_AcdeB48a2b, |
1414 | AcdeB48a4b = dnnl_AcdeB48a4b, |
1415 | Acdeb64a = dnnl_Acdeb64a, |
1416 | AcdeB64a2b = dnnl_AcdeB64a2b, |
1417 | AcdeB64a4b = dnnl_AcdeB64a4b, |
1418 | cdeBa2b = dnnl_cdeBa2b, |
1419 | cdeBa4b = dnnl_cdeBa4b, |
1420 | aBdefc32b = dnnl_aBdefc32b, |
1421 | aBdefC32b2c = dnnl_aBdefC32b2c, |
1422 | aBdefC32b4c = dnnl_aBdefC32b4c, |
1423 | aBdefc48b = dnnl_aBdefc48b, |
1424 | aBdefC48b2c = dnnl_aBdefC48b2c, |
1425 | aBdefC48b4c = dnnl_aBdefC48b4c, |
1426 | aBdefc64b = dnnl_aBdefc64b, |
1427 | aBdefC64b2c = dnnl_aBdefC64b2c, |
1428 | aBdefC64b4c = dnnl_aBdefC64b4c, |
1429 | adefcb = dnnl_adefcb, |
1430 | adefCb2c = dnnl_adefCb2c, |
1431 | adefCb4c = dnnl_adefCb4c, |
1432 | ABc32a32b = dnnl_ABc32a32b, |
1433 | BAc8a16b2a = dnnl_BAc8a16b2a, |
1434 | BAcd8a16b2a = dnnl_BAcd8a16b2a, |
1435 | ABcde8a16b2a = dnnl_ABcde8a16b2a, |
1436 | aCBd8b16c2b = dnnl_aCBd8b16c2b, |
1437 | BAcde8a16b2a = dnnl_BAcde8a16b2a, |
1438 | aCBde8b16c2b = dnnl_aCBde8b16c2b, |
1439 | ABcde32a32b = dnnl_ABcde32a32b, |
1440 | ABc4a8b8a4b = dnnl_ABc4a8b8a4b, |
1441 | ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b, |
1442 | BAc4b8a8b4a = dnnl_BAc4b8a8b4a, |
1443 | BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a, |
1444 | BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a, |
1445 | aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c, |
1446 | aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c, |
1447 | aBCdef8b16c2b = dnnl_aBCdef8b16c2b, |
1448 | aCBdef8b16c2b = dnnl_aCBdef8b16c2b, |
1449 | aBdC16b2c = dnnl_aBdC16b2c, |
1450 | aBdeC16b2c = dnnl_aBdeC16b2c, |
1451 | aBdefC16b2c = dnnl_aBdefC16b2c, |
1452 | aBedc16b = dnnl_aBedc16b, |
1453 | AcB16a2b = dnnl_AcB16a2b, |
1454 | AcdB16a4b = dnnl_AcdB16a4b, |
1455 | AcdeB16a2b = dnnl_AcdeB16a2b, |
1456 | Adcb16a = dnnl_Adcb16a, |
1457 | aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b, |
1458 | aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b, |
1459 | aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b, |
1460 | ABc32a16b = dnnl_ABc32a16b, |
1461 | ABcd16a32b = dnnl_ABcd16a32b, |
1462 | ABcd32a16b = dnnl_ABcd32a16b, |
1463 | ABcde32a16b = dnnl_ABcde32a16b, |
1464 | AB48a16b = dnnl_AB48a16b, |
1465 | AB48a32b = dnnl_AB48a32b, |
1466 | ABc40a16b = dnnl_ABc40a16b, |
1467 | ABc40a32b = dnnl_ABc40a32b, |
1468 | aBC48b16c = dnnl_aBC48b16c, |
1469 | aBC48b32c = dnnl_aBC48b32c, |
1470 | ABcd40a16b = dnnl_ABcd40a16b, |
1471 | ABcd40a32b = dnnl_ABcd40a32b, |
1472 | BA16a16b = dnnl_BA16a16b, |
1473 | BA16a32b = dnnl_BA16a32b, |
1474 | BA16a48b = dnnl_BA16a48b, |
1475 | BA16a64b = dnnl_BA16a64b, |
1476 | BA16a16b2a = dnnl_BA16a16b2a, |
1477 | BA16a32b2a = dnnl_BA16a32b2a, |
1478 | BA16a48b2a = dnnl_BA16a48b2a, |
1479 | BA16a64b2a = dnnl_BA16a64b2a, |
1480 | BA16a16b4a = dnnl_BA16a16b4a, |
1481 | BA16a32b4a = dnnl_BA16a32b4a, |
1482 | BA16a48b4a = dnnl_BA16a48b4a, |
1483 | BA16a64b4a = dnnl_BA16a64b4a, |
1484 | decbA16a = dnnl_decbA16a, |
1485 | decbA8a = dnnl_decbA8a, |
1486 | aCB16b16c = dnnl_aCB16b16c, |
1487 | aCB16b32c = dnnl_aCB16b32c, |
1488 | aCB16b48c = dnnl_aCB16b48c, |
1489 | aCB16b64c = dnnl_aCB16b64c, |
1490 | aCB16b16c2b = dnnl_aCB16b16c2b, |
1491 | aCB16b32c2b = dnnl_aCB16b32c2b, |
1492 | aCB16b48c2b = dnnl_aCB16b48c2b, |
1493 | aCB16b64c2b = dnnl_aCB16b64c2b, |
1494 | aCB16b16c4b = dnnl_aCB16b16c4b, |
1495 | aCB16b32c4b = dnnl_aCB16b32c4b, |
1496 | aCB16b48c4b = dnnl_aCB16b48c4b, |
1497 | aCB16b64c4b = dnnl_aCB16b64c4b, |
1498 | |
1499 | format_tag_last = dnnl_format_tag_last, |
1500 | |
1501 | nCdhw16c = dnnl_nCdhw16c, |
1502 | nCdhw4c = dnnl_nCdhw4c, |
1503 | nCdhw8c = dnnl_nCdhw8c, |
1504 | nChw16c = dnnl_nChw16c, |
1505 | nChw4c = dnnl_nChw4c, |
1506 | nChw8c = dnnl_nChw8c, |
1507 | nCw16c = dnnl_nCw16c, |
1508 | nCw4c = dnnl_nCw4c, |
1509 | nCw8c = dnnl_nCw8c, |
1510 | NCw16n16c = dnnl_NCw16n16c, |
1511 | NChw16n16c = dnnl_NChw16n16c, |
1512 | NCdhw16n16c = dnnl_NCdhw16n16c, |
1513 | NCdhw32n32c = dnnl_NCdhw32n32c, |
1514 | NChw32n32c = dnnl_NChw32n32c, |
1515 | IOhw16i16o = dnnl_IOhw16i16o, |
1516 | OI16i16o = dnnl_OI16i16o, |
1517 | OI16i32o = dnnl_OI16i32o, |
1518 | OI16i64o = dnnl_OI16i64o, |
1519 | OI8i16o2i = dnnl_OI8i16o2i, |
1520 | OI8i32o2i = dnnl_OI8i32o2i, |
1521 | OI8i64o2i = dnnl_OI8i64o2i, |
1522 | OI4i16o4i = dnnl_OI4i16o4i, |
1523 | OI4i32o4i = dnnl_OI4i32o4i, |
1524 | OI4i64o4i = dnnl_OI4i64o4i, |
1525 | Ohwi32o = dnnl_Ohwi32o, |
1526 | IOdhw16i16o = dnnl_IOdhw16i16o, |
1527 | gIOhw16i16o = dnnl_gIOhw16i16o, |
1528 | gOhwi32o = dnnl_gOhwi32o, |
1529 | Goidhw16g = dnnl_Goidhw16g, |
1530 | IOw16o16i = dnnl_IOw16o16i, |
1531 | OIw16i16o = dnnl_OIw16i16o, |
1532 | OIw16i32o = dnnl_OIw16i32o, |
1533 | OIw16i64o = dnnl_OIw16i64o, |
1534 | IOw16i16o = dnnl_IOw16i16o, |
1535 | gIOw16i16o = dnnl_gIOw16i16o, |
1536 | OIw16o16i = dnnl_OIw16o16i, |
1537 | Oiw16o = dnnl_Oiw16o, |
1538 | OIw4i16o4i = dnnl_OIw4i16o4i, |
1539 | OIw4i32o4i = dnnl_OIw4i32o4i, |
1540 | OIw4i64o4i = dnnl_OIw4i64o4i, |
1541 | OIw2i8o4i = dnnl_OIw2i8o4i, |
1542 | OIw4i4o = dnnl_OIw4i4o, |
1543 | OIw4o4i = dnnl_OIw4o4i, |
1544 | Oiw4o = dnnl_Oiw4o, |
1545 | OIw8i16o2i = dnnl_OIw8i16o2i, |
1546 | OIw8i32o2i = dnnl_OIw8i32o2i, |
1547 | OIw8i64o2i = dnnl_OIw8i64o2i, |
1548 | OIw8i8o = dnnl_OIw8i8o, |
1549 | OIw8o16i2o = dnnl_OIw8o16i2o, |
1550 | OIw8o8i = dnnl_OIw8o8i, |
1551 | OIw8o4i = dnnl_OIw8o4i, |
1552 | OIw16i16o4i = dnnl_OIw16i16o4i, |
1553 | OIw16i32o4i = dnnl_OIw16i32o4i, |
1554 | OIw16i48o4i = dnnl_OIw16i48o4i, |
1555 | OIw16i64o4i = dnnl_OIw16i64o4i, |
1556 | OIw16i16o2i = dnnl_OIw16i16o2i, |
1557 | OIw16i32o2i = dnnl_OIw16i32o2i, |
1558 | OIw16i48o2i = dnnl_OIw16i48o2i, |
1559 | OIw16i64o2i = dnnl_OIw16i64o2i, |
1560 | OIw16o16i2o = dnnl_OIw16o16i2o, |
1561 | Owi16o = dnnl_Owi16o, |
1562 | OwI16o2i = dnnl_OwI16o2i, |
1563 | Iwo16i = dnnl_Iwo16i, |
1564 | IwO16i2o = dnnl_IwO16i2o, |
1565 | IwO16i4o = dnnl_IwO16i4o, |
1566 | Owi4o = dnnl_Owi4o, |
1567 | Owi8o = dnnl_Owi8o, |
1568 | IOhw16o16i = dnnl_IOhw16o16i, |
1569 | Ohwi16o = dnnl_Ohwi16o, |
1570 | OhwI16o2i = dnnl_OhwI16o2i, |
1571 | Ihwo16i = dnnl_Ihwo16i, |
1572 | IhwO16i2o = dnnl_IhwO16i2o, |
1573 | IhwO16i4o = dnnl_IhwO16i4o, |
1574 | Ohwi4o = dnnl_Ohwi4o, |
1575 | Ohwi8o = dnnl_Ohwi8o, |
1576 | OIhw16i16o = dnnl_OIhw16i16o, |
1577 | OIhw16i32o = dnnl_OIhw16i32o, |
1578 | OIhw16i64o = dnnl_OIhw16i64o, |
1579 | OIhw16o16i = dnnl_OIhw16o16i, |
1580 | Oihw16o = dnnl_Oihw16o, |
1581 | OIhw4i16o4i = dnnl_OIhw4i16o4i, |
1582 | OIhw4i32o4i = dnnl_OIhw4i32o4i, |
1583 | OIhw4i64o4i = dnnl_OIhw4i64o4i, |
1584 | OIhw4i4o = dnnl_OIhw4i4o, |
1585 | OIhw4o4i = dnnl_OIhw4o4i, |
1586 | Oihw4o = dnnl_Oihw4o, |
1587 | OIhw8i16o2i = dnnl_OIhw8i16o2i, |
1588 | OIhw8i32o2i = dnnl_OIhw8i32o2i, |
1589 | OIhw8i64o2i = dnnl_OIhw8i64o2i, |
1590 | OIhw8i8o = dnnl_OIhw8i8o, |
1591 | OIhw8o16i2o = dnnl_OIhw8o16i2o, |
1592 | OIhw8o8i = dnnl_OIhw8o8i, |
1593 | OIhw8o4i = dnnl_OIhw8o4i, |
1594 | OIhw2i8o4i = dnnl_OIhw2i8o4i, |
1595 | IOdhw16o16i = dnnl_IOdhw16o16i, |
1596 | Odhwi16o = dnnl_Odhwi16o, |
1597 | OdhwI16o2i = dnnl_OdhwI16o2i, |
1598 | Idhwo16i = dnnl_Idhwo16i, |
1599 | IdhwO16i2o = dnnl_IdhwO16i2o, |
1600 | IdhwO16i4o = dnnl_IdhwO16i4o, |
1601 | Odhwi4o = dnnl_Odhwi4o, |
1602 | Odhwi8o = dnnl_Odhwi8o, |
1603 | OIdhw16i16o = dnnl_OIdhw16i16o, |
1604 | OIdhw16i32o = dnnl_OIdhw16i32o, |
1605 | OIdhw16i64o = dnnl_OIdhw16i64o, |
1606 | OIdhw16o16i = dnnl_OIdhw16o16i, |
1607 | OIdhw16o16i2o = dnnl_OIdhw16o16i2o, |
1608 | Oidhw16o = dnnl_Oidhw16o, |
1609 | OIdhw4i4o = dnnl_OIdhw4i4o, |
1610 | OIdhw4o4i = dnnl_OIdhw4o4i, |
1611 | Oidhw4o = dnnl_Oidhw4o, |
1612 | OIdhw8i16o2i = dnnl_OIdhw8i16o2i, |
1613 | OIdhw8i32o2i = dnnl_OIdhw8i32o2i, |
1614 | OIdhw8i64o2i = dnnl_OIdhw8i64o2i, |
1615 | OIdhw4i16o4i = dnnl_OIdhw4i16o4i, |
1616 | OIdhw16i16o4i = dnnl_OIdhw16i16o4i, |
1617 | OIdhw16i32o4i = dnnl_OIdhw16i32o4i, |
1618 | OIdhw16i48o4i = dnnl_OIdhw16i48o4i, |
1619 | OIdhw16i64o4i = dnnl_OIdhw16i64o4i, |
1620 | OIdhw16i16o2i = dnnl_OIdhw16i16o2i, |
1621 | OIdhw16i32o2i = dnnl_OIdhw16i32o2i, |
1622 | OIdhw16i48o2i = dnnl_OIdhw16i48o2i, |
1623 | OIdhw16i64o2i = dnnl_OIdhw16i64o2i, |
1624 | OIdhw4i32o4i = dnnl_OIdhw4i32o4i, |
1625 | OIdhw4i64o4i = dnnl_OIdhw4i64o4i, |
1626 | OIdhw2i8o4i = dnnl_OIdhw2i8o4i, |
1627 | OIdhw8i8o = dnnl_OIdhw8i8o, |
1628 | OIdhw8o8i = dnnl_OIdhw8o8i, |
1629 | OIdhw8o4i = dnnl_OIdhw8o4i, |
1630 | gIOw16o16i = dnnl_gIOw16o16i, |
1631 | gOIw16i16o = dnnl_gOIw16i16o, |
1632 | gOIw16o16i = dnnl_gOIw16o16i, |
1633 | gOiw16o = dnnl_gOiw16o, |
1634 | gOIw4i16o4i = dnnl_gOIw4i16o4i, |
1635 | gOIw2i8o4i = dnnl_gOIw2i8o4i, |
1636 | gOIw4i4o = dnnl_gOIw4i4o, |
1637 | gOIw4o4i = dnnl_gOIw4o4i, |
1638 | gOiw4o = dnnl_gOiw4o, |
1639 | gOIw8i16o2i = dnnl_gOIw8i16o2i, |
1640 | gOIw8i8o = dnnl_gOIw8i8o, |
1641 | gOIw8o16i2o = dnnl_gOIw8o16i2o, |
1642 | gOIw8o8i = dnnl_gOIw8o8i, |
1643 | gOIw8o4i = dnnl_gOIw8o4i, |
1644 | gOIw16i16o4i = dnnl_gOIw16i16o4i, |
1645 | gOIw16i16o2i = dnnl_gOIw16i16o2i, |
1646 | gOIw16o16i2o = dnnl_gOIw16o16i2o, |
1647 | gOwi16o = dnnl_gOwi16o, |
1648 | gOwI16o2i = dnnl_gOwI16o2i, |
1649 | gIwo16i = dnnl_gIwo16i, |
1650 | gIwO16i2o = dnnl_gIwO16i2o, |
1651 | gIwO16i4o = dnnl_gIwO16i4o, |
1652 | gOwi4o = dnnl_gOwi4o, |
1653 | gOwi8o = dnnl_gOwi8o, |
1654 | Goiw8g = dnnl_Goiw8g, |
1655 | Goiw16g = dnnl_Goiw16g, |
1656 | gIOhw16o16i = dnnl_gIOhw16o16i, |
1657 | gOhwi16o = dnnl_gOhwi16o, |
1658 | gOhwI16o2i = dnnl_gOhwI16o2i, |
1659 | gIhwo16i = dnnl_gIhwo16i, |
1660 | gIhwO16i2o = dnnl_gIhwO16i2o, |
1661 | gIhwO16i4o = dnnl_gIhwO16i4o, |
1662 | gOhwi4o = dnnl_gOhwi4o, |
1663 | gOhwi8o = dnnl_gOhwi8o, |
1664 | Goihw16g = dnnl_Goihw16g, |
1665 | gOIhw16i16o = dnnl_gOIhw16i16o, |
1666 | gOIhw16o16i = dnnl_gOIhw16o16i, |
1667 | gOihw16o = dnnl_gOihw16o, |
1668 | gOIhw4i16o4i = dnnl_gOIhw4i16o4i, |
1669 | gOIhw2i8o4i = dnnl_gOIhw2i8o4i, |
1670 | gOIhw4i4o = dnnl_gOIhw4i4o, |
1671 | gOIhw4o4i = dnnl_gOIhw4o4i, |
1672 | gOihw4o = dnnl_gOihw4o, |
1673 | Goihw8g = dnnl_Goihw8g, |
1674 | gOIhw8i16o2i = dnnl_gOIhw8i16o2i, |
1675 | gOIhw8i8o = dnnl_gOIhw8i8o, |
1676 | gOIhw8o16i2o = dnnl_gOIhw8o16i2o, |
1677 | OIw4o8i8o4i = dnnl_OIw4o8i8o4i, |
1678 | OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i, |
1679 | OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i, |
1680 | OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i, |
1681 | gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i, |
1682 | gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i, |
1683 | gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i, |
1684 | gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i, |
1685 | OIhw16i16o4i = dnnl_OIhw16i16o4i, |
1686 | OIhw16i32o4i = dnnl_OIhw16i32o4i, |
1687 | OIhw16i48o4i = dnnl_OIhw16i48o4i, |
1688 | OIhw16i64o4i = dnnl_OIhw16i64o4i, |
1689 | OIhw16i16o2i = dnnl_OIhw16i16o2i, |
1690 | OIhw16i32o2i = dnnl_OIhw16i32o2i, |
1691 | OIhw16i48o2i = dnnl_OIhw16i48o2i, |
1692 | OIhw16i64o2i = dnnl_OIhw16i64o2i, |
1693 | OIhw16o16i2o = dnnl_OIhw16o16i2o, |
1694 | gOIhw16i16o4i = dnnl_gOIhw16i16o4i, |
1695 | gOIhw16i16o2i = dnnl_gOIhw16i16o2i, |
1696 | gOIhw16o16i2o = dnnl_gOIhw16o16i2o, |
1697 | gOIhw8o8i = dnnl_gOIhw8o8i, |
1698 | gOIhw8o4i = dnnl_gOIhw8o4i, |
1699 | gIOdhw16i16o = dnnl_gIOdhw16i16o, |
1700 | gIOdhw16o16i = dnnl_gIOdhw16o16i, |
1701 | gOdhwi16o = dnnl_gOdhwi16o, |
1702 | gOdhwI16o2i = dnnl_gOdhwI16o2i, |
1703 | gIdhwo16i = dnnl_gIdhwo16i, |
1704 | gIdhwO16i2o = dnnl_gIdhwO16i2o, |
1705 | gIdhwO16i4o = dnnl_gIdhwO16i4o, |
1706 | gOdhwi4o = dnnl_gOdhwi4o, |
1707 | gOdhwi8o = dnnl_gOdhwi8o, |
1708 | gOIdhw16i16o = dnnl_gOIdhw16i16o, |
1709 | gOIdhw16o16i = dnnl_gOIdhw16o16i, |
1710 | gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o, |
1711 | gOidhw16o = dnnl_gOidhw16o, |
1712 | gOIdhw4i4o = dnnl_gOIdhw4i4o, |
1713 | gOIdhw4o4i = dnnl_gOIdhw4o4i, |
1714 | gOidhw4o = dnnl_gOidhw4o, |
1715 | gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i, |
1716 | gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i, |
1717 | gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i, |
1718 | gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i, |
1719 | gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i, |
1720 | gOIdhw8i8o = dnnl_gOIdhw8i8o, |
1721 | gOIdhw8o8i = dnnl_gOIdhw8o8i, |
1722 | gOIdhw8o4i = dnnl_gOIdhw8o4i, |
1723 | gOIw2i4o2i = dnnl_gOIw2i4o2i, |
1724 | gOIhw2i4o2i = dnnl_gOIhw2i4o2i, |
1725 | gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i, |
1726 | gOIw2o4i2o = dnnl_gOIw2o4i2o, |
1727 | gOIhw2o4i2o = dnnl_gOIhw2o4i2o, |
1728 | gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o, |
1729 | gOIw4i8o2i = dnnl_gOIw4i8o2i, |
1730 | gOIhw4i8o2i = dnnl_gOIhw4i8o2i, |
1731 | gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i, |
1732 | gOIw4o8i2o = dnnl_gOIw4o8i2o, |
1733 | gOIhw4o8i2o = dnnl_gOIhw4o8i2o, |
1734 | gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o, |
1735 | ldOi32o = abDc32d, |
1736 | ldOI32o4i = abDC32d4c, |
1737 | ldgOi32o = abdEc32e, |
1738 | ldgOI32o2i = abdEC32e2c, |
1739 | ldgOI32o4i = abdEC32e4c, |
1740 | OwI16o4i = dnnl_OwI16o4i, |
1741 | OhwI16o4i = dnnl_OhwI16o4i, |
1742 | gOwI16o4i = dnnl_gOwI16o4i, |
1743 | gOhwI16o4i = dnnl_gOhwI16o4i, |
1744 | OdhwI16o4i = dnnl_OdhwI16o4i, |
1745 | gOdhwI16o4i = dnnl_gOdhwI16o4i, |
1746 | |
1747 | Owi32o = dnnl_Owi32o, |
1748 | OwI32o2i = dnnl_OwI32o2i, |
1749 | OwI32o4i = dnnl_OwI32o4i, |
1750 | Owi48o = dnnl_Owi48o, |
1751 | OwI48o2i = dnnl_OwI48o2i, |
1752 | OwI48o4i = dnnl_OwI48o4i, |
1753 | Owi64o = dnnl_Owi64o, |
1754 | OwI64o2i = dnnl_OwI64o2i, |
1755 | OwI64o4i = dnnl_OwI64o4i, |
1756 | Iwo32i = dnnl_Iwo32i, |
1757 | IwO32i2o = dnnl_IwO32i2o, |
1758 | IwO32i4o = dnnl_IwO32i4o, |
1759 | Iwo48i = dnnl_Iwo48i, |
1760 | IwO48i2o = dnnl_IwO48i2o, |
1761 | IwO48i4o = dnnl_IwO48i4o, |
1762 | Iwo64i = dnnl_Iwo64i, |
1763 | IwO64i2o = dnnl_IwO64i2o, |
1764 | IwO64i4o = dnnl_IwO64i4o, |
1765 | wIo2i = dnnl_wIo2i, |
1766 | wIo4i = dnnl_wIo4i, |
1767 | gOwi32o = dnnl_gOwi32o, |
1768 | gOwI32o2i = dnnl_gOwI32o2i, |
1769 | gOwI32o4i = dnnl_gOwI32o4i, |
1770 | gOwi48o = dnnl_gOwi48o, |
1771 | gOwI48o2i = dnnl_gOwI48o2i, |
1772 | gOwI48o4i = dnnl_gOwI48o4i, |
1773 | gOwi64o = dnnl_gOwi64o, |
1774 | gOwI64o2i = dnnl_gOwI64o2i, |
1775 | gOwI64o4i = dnnl_gOwI64o4i, |
1776 | gIwo32i = dnnl_gIwo32i, |
1777 | gIwO32i2o = dnnl_gIwO32i2o, |
1778 | gIwO32i4o = dnnl_gIwO32i4o, |
1779 | gIwo48i = dnnl_gIwo48i, |
1780 | gIwO48i2o = dnnl_gIwO48i2o, |
1781 | gIwO48i4o = dnnl_gIwO48i4o, |
1782 | gIwo64i = dnnl_gIwo64i, |
1783 | gIwO64i2o = dnnl_gIwO64i2o, |
1784 | gIwO64i4o = dnnl_gIwO64i4o, |
1785 | gwio = dnnl_gwio, |
1786 | gwIo2i = dnnl_gwIo2i, |
1787 | gwIo4i = dnnl_gwIo4i, |
1788 | OhwI32o = dnnl_OhwI32o, |
1789 | OhwI32o2i = dnnl_OhwI32o2i, |
1790 | OhwI32o4i = dnnl_OhwI32o4i, |
1791 | Ohwi48o = dnnl_Ohwi48o, |
1792 | OhwI48o2i = dnnl_OhwI48o2i, |
1793 | OhwI48o4i = dnnl_OhwI48o4i, |
1794 | Ohwi64o = dnnl_Ohwi64o, |
1795 | OhwI64o2i = dnnl_OhwI64o2i, |
1796 | OhwI64o4i = dnnl_OhwI64o4i, |
1797 | Ihwo32i = dnnl_Ihwo32i, |
1798 | IhwO32i2o = dnnl_IhwO32i2o, |
1799 | IhwO32i4o = dnnl_IhwO32i4o, |
1800 | Ihwo48i = dnnl_Ihwo48i, |
1801 | IhwO48i2o = dnnl_IhwO48i2o, |
1802 | IhwO48i4o = dnnl_IhwO48i4o, |
1803 | Ihwo64i = dnnl_Ihwo64i, |
1804 | IhwO64i2o = dnnl_IhwO64i2o, |
1805 | IhwO64i4o = dnnl_IhwO64i4o, |
1806 | hwIo2i = dnnl_hwIo2i, |
1807 | hwIo4i = dnnl_hwIo4i, |
1808 | gOhwI32o = dnnl_gOhwI32o, |
1809 | gOhwI32o2i = dnnl_gOhwI32o2i, |
1810 | gOhwI32o4i = dnnl_gOhwI32o4i, |
1811 | gOhwi48o = dnnl_gOhwi48o, |
1812 | gOhwI48o2i = dnnl_gOhwI48o2i, |
1813 | gOhwI48o4i = dnnl_gOhwI48o4i, |
1814 | gOhwi64o = dnnl_gOhwi64o, |
1815 | gOhwI64o2i = dnnl_gOhwI64o2i, |
1816 | gOhwI64o4i = dnnl_gOhwI64o4i, |
1817 | gIhwo32i = dnnl_gIhwo32i, |
1818 | gIhwO32i2o = dnnl_gIhwO32i2o, |
1819 | gIhwO32i4o = dnnl_gIhwO32i4o, |
1820 | gIhwo48i = dnnl_gIhwo48i, |
1821 | gIhwO48i2o = dnnl_gIhwO48i2o, |
1822 | gIhwO48i4o = dnnl_gIhwO48i4o, |
1823 | gIhwo64i = dnnl_gIhwo64i, |
1824 | gIhwO64i2o = dnnl_gIhwO64i2o, |
1825 | gIhwO64i4o = dnnl_gIhwO64i4o, |
1826 | ghwio = dnnl_ghwio, |
1827 | ghwIo2i = dnnl_ghwIo2i, |
1828 | ghwIo4i = dnnl_ghwIo4i, |
1829 | Odhwi32o = dnnl_Odhwi32o, |
1830 | OdhwI32o2i = dnnl_OdhwI32o2i, |
1831 | OdhwI32o4i = dnnl_OdhwI32o4i, |
1832 | Odhwi48o = dnnl_Odhwi48o, |
1833 | OdhwI48o2i = dnnl_OdhwI48o2i, |
1834 | OdhwI48o4i = dnnl_OdhwI48o4i, |
1835 | Odhwi64o = dnnl_Odhwi64o, |
1836 | OdhwI64o2i = dnnl_OdhwI64o2i, |
1837 | OdhwI64o4i = dnnl_OdhwI64o4i, |
1838 | Idhwo32i = dnnl_Idhwo32i, |
1839 | IdhwO32i2o = dnnl_IdhwO32i2o, |
1840 | IdhwO32i4o = dnnl_IdhwO32i4o, |
1841 | Idhwo48i = dnnl_Idhwo48i, |
1842 | IdhwO48i2o = dnnl_IdhwO48i2o, |
1843 | IdhwO48i4o = dnnl_IdhwO48i4o, |
1844 | Idhwo64i = dnnl_Idhwo64i, |
1845 | IdhwO64i2o = dnnl_IdhwO64i2o, |
1846 | IdhwO64i4o = dnnl_IdhwO64i4o, |
1847 | dhwIo2i = dnnl_dhwIo2i, |
1848 | dhwIo4i = dnnl_dhwIo4i, |
1849 | gOdhwi32o = dnnl_gOdhwi32o, |
1850 | gOdhwI32o2i = dnnl_gOdhwI32o2i, |
1851 | gOdhwI32o4i = dnnl_gOdhwI32o4i, |
1852 | gOdhwi48o = dnnl_gOdhwi48o, |
1853 | gOdhwI48o2i = dnnl_gOdhwI48o2i, |
1854 | gOdhwI48o4i = dnnl_gOdhwI48o4i, |
1855 | gOdhwi64o = dnnl_gOdhwi64o, |
1856 | gOdhwI64o2i = dnnl_gOdhwI64o2i, |
1857 | gOdhwI64o4i = dnnl_gOdhwI64o4i, |
1858 | gIdhwo32i = dnnl_gIdhwo32i, |
1859 | gIdhwO32i2o = dnnl_gIdhwO32i2o, |
1860 | gIdhwO32i4o = dnnl_gIdhwO32i4o, |
1861 | gIdhwo48i = dnnl_gIdhwo48i, |
1862 | gIdhwO48i2o = dnnl_gIdhwO48i2o, |
1863 | gIdhwO48i4o = dnnl_gIdhwO48i4o, |
1864 | gIdhwo64i = dnnl_gIdhwo64i, |
1865 | gIdhwO64i2o = dnnl_gIdhwO64i2o, |
1866 | gIdhwO64i4o = dnnl_gIdhwO64i4o, |
1867 | gdhwio = dnnl_gdhwio, |
1868 | gdhwIo2i = dnnl_gdhwIo2i, |
1869 | gdhwIo4i = dnnl_gdhwIo4i, |
1870 | ldIo32i = dnnl_ldIo32i, |
1871 | ldgIo32i = dnnl_ldgIo32i, |
1872 | ldgIO32i2o = dnnl_ldgIO32i2o, |
1873 | nCdhw32c = dnnl_nCdhw32c, |
1874 | nChw32c = dnnl_nChw32c, |
1875 | nCw32c = dnnl_nCw32c, |
1876 | NCw32n16c = dnnl_NCw32n16c, |
1877 | NChw32n16c = dnnl_NChw32n16c, |
1878 | NCdhw32n16c = dnnl_NCdhw32n16c, |
1879 | NCw32n32c = dnnl_NCw32n32c, |
1880 | OI16i16o4i = dnnl_OI16i16o4i, |
1881 | IOw8o16i2o = dnnl_IOw8o16i2o, |
1882 | IOhw8o16i2o = dnnl_IOhw8o16i2o, |
1883 | Owhi16o = dnnl_Owhi16o, |
1884 | OIdhw8o16i2o = dnnl_OIdhw8o16i2o, |
1885 | IOdhw8o16i2o = dnnl_IOdhw8o16i2o, |
1886 | Goiw4g = dnnl_Goiw4g, |
1887 | gIOw8o16i2o = dnnl_gIOw8o16i2o, |
1888 | Goiw32g = dnnl_Goiw32g, |
1889 | Goihw4g = dnnl_Goihw4g, |
1890 | gIOhw8o16i2o = dnnl_gIOhw8o16i2o, |
1891 | Goihw32g = dnnl_Goihw32g, |
1892 | gOwhi16o = dnnl_gOwhi16o, |
1893 | IOw4i8o8i4o = dnnl_IOw4i8o8i4o, |
1894 | IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o, |
1895 | IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o, |
1896 | gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o, |
1897 | gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o, |
1898 | gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o, |
1899 | gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o, |
1900 | gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o, |
1901 | Goidhw32g = dnnl_Goidhw32g, |
1902 | OI16i32o4i = dnnl_OI16i32o4i, |
1903 | OI16i48o4i = dnnl_OI16i48o4i, |
1904 | OI16i64o4i = dnnl_OI16i64o4i, |
1905 | OI16i16o2i = dnnl_OI16i16o2i, |
1906 | OI16i32o2i = dnnl_OI16i32o2i, |
1907 | OI16i48o2i = dnnl_OI16i48o2i, |
1908 | OI16i64o2i = dnnl_OI16i64o2i, |
1909 | aBdeC16c16b4c = dnnl_aBdeC16c16b4c, |
1910 | AcB16b16a2b = dnnl_AcB16b16a2b, |
1911 | aBdC16c16b2c = dnnl_aBdC16c16b2c, |
1912 | AcB16b16a4b = dnnl_AcB16b16a4b, |
1913 | aBdC16c16b4c = dnnl_aBdC16c16b4c, |
1914 | AcdB16b16a2b = dnnl_AcdB16b16a2b, |
1915 | aBdefC16c16b4c = dnnl_aBdefC16c16b4c, |
1916 | AcdeB16b16a4b = dnnl_AcdeB16b16a4b, |
1917 | AcB16b32a2b = dnnl_AcB16b32a2b, |
1918 | AcB16b32a4b = dnnl_AcB16b32a4b, |
1919 | AcB16b48a2b = dnnl_AcB16b48a2b, |
1920 | AcB16b48a4b = dnnl_AcB16b48a4b, |
1921 | AcB16b64a2b = dnnl_AcB16b64a2b, |
1922 | AcB16b64a4b = dnnl_AcB16b64a4b, |
1923 | aBdC16c32b2c = dnnl_aBdC16c32b2c, |
1924 | aBdC16c32b4c = dnnl_aBdC16c32b4c, |
1925 | aBdC16c48b2c = dnnl_aBdC16c48b2c, |
1926 | aBdC16c48b4c = dnnl_aBdC16c48b4c, |
1927 | aBdC16c64b2c = dnnl_aBdC16c64b2c, |
1928 | aBdC16c64b4c = dnnl_aBdC16c64b4c, |
1929 | AcdB16b32a2b = dnnl_AcdB16b32a2b, |
1930 | AcdB16b32a4b = dnnl_AcdB16b32a4b, |
1931 | AcdB16b48a2b = dnnl_AcdB16b48a2b, |
1932 | AcdB16b48a4b = dnnl_AcdB16b48a4b, |
1933 | AcdB16b64a2b = dnnl_AcdB16b64a2b, |
1934 | AcdB16b64a4b = dnnl_AcdB16b64a4b, |
1935 | aBdeC16c32b2c = dnnl_aBdeC16c32b2c, |
1936 | aBdeC16c32b4c = dnnl_aBdeC16c32b4c, |
1937 | aBdeC16c48b2c = dnnl_aBdeC16c48b2c, |
1938 | aBdeC16c48b4c = dnnl_aBdeC16c48b4c, |
1939 | aBdeC16c64b2c = dnnl_aBdeC16c64b2c, |
1940 | aBdeC16c64b4c = dnnl_aBdeC16c64b4c, |
1941 | AcdeB16b32a2b = dnnl_AcdeB16b32a2b, |
1942 | AcdeB16b32a4b = dnnl_AcdeB16b32a4b, |
1943 | AcdeB16b48a2b = dnnl_AcdeB16b48a2b, |
1944 | AcdeB16b48a4b = dnnl_AcdeB16b48a4b, |
1945 | AcdeB16b64a2b = dnnl_AcdeB16b64a2b, |
1946 | AcdeB16b64a4b = dnnl_AcdeB16b64a4b, |
1947 | aBdefC16c32b2c = dnnl_aBdefC16c32b2c, |
1948 | aBdefC16c32b4c = dnnl_aBdefC16c32b4c, |
1949 | aBdefC16c48b2c = dnnl_aBdefC16c48b2c, |
1950 | aBdefC16c48b4c = dnnl_aBdefC16c48b4c, |
1951 | aBdefC16c64b2c = dnnl_aBdefC16c64b2c, |
1952 | aBdefC16c64b4c = dnnl_aBdefC16c64b4c, |
1953 | OwI16i16o2i = dnnl_OwI16i16o2i, |
1954 | gOwI16i16o2i = dnnl_gOwI16i16o2i, |
1955 | OhwI16i16o2i = dnnl_OhwI16i16o2i, |
1956 | gOhwI16i16o2i = dnnl_gOhwI16i16o2i, |
1957 | OdhwI16i16o2i = dnnl_OdhwI16i16o2i, |
1958 | gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i, |
1959 | OwI16i16o4i = dnnl_OwI16i16o4i, |
1960 | gOwI16i16o4i = dnnl_gOwI16i16o4i, |
1961 | OhwI16i16o4i = dnnl_OhwI16i16o4i, |
1962 | gOhwI16i16o4i = dnnl_gOhwI16i16o4i, |
1963 | OdhwI16i16o4i = dnnl_OdhwI16i16o4i, |
1964 | gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i, |
1965 | OwI16i32o2i = dnnl_OwI16i32o2i, |
1966 | OwI16i32o4i = dnnl_OwI16i32o4i, |
1967 | OwI16i48o2i = dnnl_OwI16i48o2i, |
1968 | OwI16i48o4i = dnnl_OwI16i48o4i, |
1969 | OwI16i64o2i = dnnl_OwI16i64o2i, |
1970 | OwI16i64o4i = dnnl_OwI16i64o4i, |
1971 | gOwI16i32o2i = dnnl_gOwI16i32o2i, |
1972 | gOwI16i32o4i = dnnl_gOwI16i32o4i, |
1973 | gOwI16i48o2i = dnnl_gOwI16i48o2i, |
1974 | gOwI16i48o4i = dnnl_gOwI16i48o4i, |
1975 | gOwI16i64o2i = dnnl_gOwI16i64o2i, |
1976 | gOwI16i64o4i = dnnl_gOwI16i64o4i, |
1977 | OhwI16i32o2i = dnnl_OhwI16i32o2i, |
1978 | OhwI16i32o4i = dnnl_OhwI16i32o4i, |
1979 | OhwI16i48o2i = dnnl_OhwI16i48o2i, |
1980 | OhwI16i48o4i = dnnl_OhwI16i48o4i, |
1981 | OhwI16i64o2i = dnnl_OhwI16i64o2i, |
1982 | OhwI16i64o4i = dnnl_OhwI16i64o4i, |
1983 | gOhwI16i32o2i = dnnl_gOhwI16i32o2i, |
1984 | gOhwI16i32o4i = dnnl_gOhwI16i32o4i, |
1985 | gOhwI16i48o2i = dnnl_gOhwI16i48o2i, |
1986 | gOhwI16i48o4i = dnnl_gOhwI16i48o4i, |
1987 | gOhwI16i64o2i = dnnl_gOhwI16i64o2i, |
1988 | gOhwI16i64o4i = dnnl_gOhwI16i64o4i, |
1989 | OdhwI16i32o2i = dnnl_OdhwI16i32o2i, |
1990 | OdhwI16i32o4i = dnnl_OdhwI16i32o4i, |
1991 | OdhwI16i48o2i = dnnl_OdhwI16i48o2i, |
1992 | OdhwI16i48o4i = dnnl_OdhwI16i48o4i, |
1993 | OdhwI16i64o2i = dnnl_OdhwI16i64o2i, |
1994 | OdhwI16i64o4i = dnnl_OdhwI16i64o4i, |
1995 | IdhwO16o32i2o = dnnl_IdhwO16o32i2o, |
1996 | IdhwO16o32i4o = dnnl_IdhwO16o32i4o, |
1997 | IdhwO16o48i2o = dnnl_IdhwO16o48i2o, |
1998 | IdhwO16o48i4o = dnnl_IdhwO16o48i4o, |
1999 | IdhwO16o64i2o = dnnl_IdhwO16o64i2o, |
2000 | IdhwO16o64i4o = dnnl_IdhwO16o64i4o, |
2001 | gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i, |
2002 | gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i, |
2003 | gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i, |
2004 | gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i, |
2005 | gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i, |
2006 | gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i, |
2007 | gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o, |
2008 | gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o, |
2009 | gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o, |
2010 | gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o, |
2011 | gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o, |
2012 | gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o, |
2013 | IwO16o16i2o = dnnl_IwO16o16i2o, |
2014 | IwO16o16i4o = dnnl_IwO16o16i4o, |
2015 | IhwO16o16i2o = dnnl_IhwO16o16i2o, |
2016 | IhwO16o16i4o = dnnl_IhwO16o16i4o, |
2017 | IdhwO16o16i2o = dnnl_IdhwO16o16i2o, |
2018 | IdhwO16o16i4o = dnnl_IdhwO16o16i4o, |
2019 | gIwO16o16i2o = dnnl_gIwO16o16i2o, |
2020 | gIwO16o16i4o = dnnl_gIwO16o16i4o, |
2021 | gIhwO16o16i2o = dnnl_gIhwO16o16i2o, |
2022 | gIhwO16o16i4o = dnnl_gIhwO16o16i4o, |
2023 | gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o, |
2024 | gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o, |
2025 | IwO16o32i2o = dnnl_IwO16o32i2o, |
2026 | IwO16o32i4o = dnnl_IwO16o32i4o, |
2027 | IwO16o48i2o = dnnl_IwO16o48i2o, |
2028 | IwO16o48i4o = dnnl_IwO16o48i4o, |
2029 | IwO16o64i2o = dnnl_IwO16o64i2o, |
2030 | IwO16o64i4o = dnnl_IwO16o64i4o, |
2031 | gIwO16o32i2o = dnnl_gIwO16o32i2o, |
2032 | gIwO16o32i4o = dnnl_gIwO16o32i4o, |
2033 | gIwO16o48i2o = dnnl_gIwO16o48i2o, |
2034 | gIwO16o48i4o = dnnl_gIwO16o48i4o, |
2035 | gIwO16o64i2o = dnnl_gIwO16o64i2o, |
2036 | gIwO16o64i4o = dnnl_gIwO16o64i4o, |
2037 | IhwO16o32i2o = dnnl_IhwO16o32i2o, |
2038 | IhwO16o32i4o = dnnl_IhwO16o32i4o, |
2039 | IhwO16o48i2o = dnnl_IhwO16o48i2o, |
2040 | IhwO16o48i4o = dnnl_IhwO16o48i4o, |
2041 | IhwO16o64i2o = dnnl_IhwO16o64i2o, |
2042 | IhwO16o64i4o = dnnl_IhwO16o64i4o, |
2043 | gIhwO16o32i2o = dnnl_gIhwO16o32i2o, |
2044 | gIhwO16o32i4o = dnnl_gIhwO16o32i4o, |
2045 | gIhwO16o48i2o = dnnl_gIhwO16o48i2o, |
2046 | gIhwO16o48i4o = dnnl_gIhwO16o48i4o, |
2047 | gIhwO16o64i2o = dnnl_gIhwO16o64i2o, |
2048 | gIhwO16o64i4o = dnnl_gIhwO16o64i4o, |
2049 | aBdeC16c16b2c = dnnl_aBdeC16c16b2c, |
2050 | aBdefC16c16b2c = dnnl_aBdefC16c16b2c, |
2051 | AcdB16b16a4b = dnnl_AcdB16b16a4b, |
2052 | AcdeB16b16a2b = dnnl_AcdeB16b16a2b, |
2053 | hwioG16g = dnnl_hwioG16g, |
2054 | hwioG8g = dnnl_hwioG8g, |
2055 | ABc4a2b = dnnl_ABc4a2b, |
2056 | ABc8a2b = dnnl_ABc8a2b, |
2057 | ABcd4a2b = dnnl_ABcd4a2b, |
2058 | ABcde4a2b = dnnl_ABcde4a2b, |
2059 | ABcde8a2b = dnnl_ABcde8a2b, |
2060 | ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b, |
2061 | NCdhw40n32c = dnnl_NCdhw40n32c, |
2062 | NChw40n32c = dnnl_NChw40n32c, |
2063 | NCw40n32c = dnnl_NCw40n32c, |
2064 | OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i, |
2065 | OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i, |
2066 | OIw4o8i8o2i = dnnl_OIw4o8i8o2i, |
2067 | gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i, |
2068 | gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i, |
2069 | gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i, |
2070 | IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o, |
2071 | IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o, |
2072 | IOw4i8o8i2o = dnnl_IOw4i8o8i2o, |
2073 | gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o, |
2074 | gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o, |
2075 | gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o, |
2076 | aBCd8b2c = dnnl_aBCd8b2c, |
2077 | ABcde40a16b = dnnl_ABcde40a16b, |
2078 | ABcde40a32b = dnnl_ABcde40a32b, |
2079 | aBCde8b2c = dnnl_aBCde8b2c, |
2080 | ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b, |
2081 | ABc4a8b8a2b = dnnl_ABc4a8b8a2b, |
2082 | aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c, |
2083 | aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c, |
2084 | aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c, |
2085 | BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a, |
2086 | BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a, |
2087 | BAc4b8a8b2a = dnnl_BAc4b8a8b2a, |
2088 | aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b, |
2089 | aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b, |
2090 | aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b, |
2091 | aBCdef8b2c = dnnl_aBCdef8b2c, |
2092 | AB32a16b = dnnl_AB32a16b, |
2093 | AB32a32b = dnnl_AB32a32b, |
2094 | BA4b8a8b2a = dnnl_BA4b8a8b2a, |
2095 | BA4b8a8b4a = dnnl_BA4b8a8b4a, |
2096 | aBC32b16c = dnnl_aBC32b16c, |
2097 | aBC32b32c = dnnl_aBC32b32c, |
2098 | aCB4c8b8c2b = dnnl_aCB4c8b8c2b, |
2099 | aCB4c8b8c4b = dnnl_aCB4c8b8c4b, |
2100 | ABc2b8a16b4a = dnnl_ABc2b8a16b4a, |
2101 | ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a, |
2102 | ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a, |
2103 | ABc2a8b16a4b = dnnl_ABc2a8b16a4b, |
2104 | ABc2a8b16a2b = dnnl_ABc2a8b16a2b, |
2105 | ABc2b32a8b = dnnl_ABc2b32a8b, |
2106 | ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b, |
2107 | ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b, |
2108 | aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b, |
2109 | ABcd2b32a8b = dnnl_ABcd2b32a8b, |
2110 | aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b, |
2111 | ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b, |
2112 | ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b, |
2113 | aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b, |
2114 | ABcde2b32a8b = dnnl_ABcde2b32a8b, |
2115 | aBC2b8c16b2c = dnnl_aBC2b8c16b2c, |
2116 | aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c, |
2117 | aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c, |
2118 | aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c, |
2119 | BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a, |
2120 | BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a, |
2121 | BAc2b8a16b4a = dnnl_BAc2b8a16b4a, |
2122 | BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a, |
2123 | BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a, |
2124 | BAc2b8a16b2a = dnnl_BAc2b8a16b2a, |
2125 | aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b, |
2126 | aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b, |
2127 | aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b, |
2128 | aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c, |
2129 | aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c, |
2130 | NCdhw40n16c = dnnl_NCdhw40n16c, |
2131 | NCw40n16c = dnnl_NCw40n16c, |
2132 | NChw40n16c = dnnl_NChw40n16c, |
2133 | NCw2c32n8c = dnnl_NCw2c32n8c, |
2134 | NChw2c32n8c = dnnl_NChw2c32n8c, |
2135 | NCdhw2c32n8c = dnnl_NCdhw2c32n8c, |
2136 | OIw2i8o16i4o = dnnl_OIw2i8o16i4o, |
2137 | OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o, |
2138 | OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o, |
2139 | OIw2o8i16o4i = dnnl_OIw2o8i16o4i, |
2140 | OIw2o8i16o2i = dnnl_OIw2o8i16o2i, |
2141 | IOw2i8o16i4o = dnnl_IOw2i8o16i4o, |
2142 | IOw2i8o16i2o = dnnl_IOw2i8o16i2o, |
2143 | OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i, |
2144 | OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i, |
2145 | IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o, |
2146 | IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o, |
2147 | OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i, |
2148 | OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i, |
2149 | IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o, |
2150 | IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o, |
2151 | gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i, |
2152 | gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o, |
2153 | gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o, |
2154 | gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o, |
2155 | gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i, |
2156 | gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i, |
2157 | gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i, |
2158 | gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i, |
2159 | BA4b8a16b2a = dnnl_BA4b8a16b2a, |
2160 | BA4b8a16b4a = dnnl_BA4b8a16b4a, |
2161 | aCB4c8b16c2b = dnnl_aCB4c8b16c2b, |
2162 | aCB4c8b16c4b = dnnl_aCB4c8b16c4b, |
2163 | aCB16c2b = dnnl_aCB16c2b, |
2164 | aCB16c4b = dnnl_aCB16c4b, |
2165 | BA16b2a = dnnl_BA16b2a, |
2166 | BA16b4a = dnnl_BA16b4a, |
2167 | aBC16b16c = dnnl_aBC16b16c, |
2168 | aBC16b32c = dnnl_aBC16b32c, |
2169 | AB16a16b = dnnl_AB16a16b, |
2170 | AB16a32b = dnnl_AB16a32b, |
2171 | ABcde16a16b2a = dnnl_ABcde16a16b2a, |
2172 | aBCdef16b16c2b = dnnl_aBCdef16b16c2b, |
2173 | Acedb16a = dnnl_Acedb16a, |
2174 | aBdfec16b = dnnl_aBdfec16b, |
2175 | Odwhi16o = dnnl_Odwhi16o, |
2176 | gOdwhi16o = dnnl_gOdwhi16o, |
2177 | abdEC64e2c = dnnl_abdEC64e2c, |
2178 | abdEC64e4c = dnnl_abdEC64e4c, |
2179 | ldgOI64o2i = abdEC64e2c, |
2180 | ldgOI64o4i = abdEC64e4c, |
2181 | abCd4c = dnnl_abCd4c, |
2182 | abCde4c = dnnl_abCde4c, |
2183 | abCdef4c = dnnl_abCdef4c, |
2184 | abCde32c = dnnl_abCde32c, |
2185 | abCdef32c = dnnl_abCdef32c, |
2186 | aCdefB16b32c2b = dnnl_aCdefB16b32c2b, |
2187 | aCdefB16b32c4b = dnnl_aCdefB16b32c4b, |
2188 | aCdefB16b48c2b = dnnl_aCdefB16b48c2b, |
2189 | aCdefB16b48c4b = dnnl_aCdefB16b48c4b, |
2190 | aCdefB16b64c2b = dnnl_aCdefB16b64c2b, |
2191 | aCdefB16b64c4b = dnnl_aCdefB16b64c4b, |
2192 | BcdeA16a32b2a = dnnl_BcdeA16a32b2a, |
2193 | BcdeA16a32b4a = dnnl_BcdeA16a32b4a, |
2194 | BcdeA16a48b2a = dnnl_BcdeA16a48b2a, |
2195 | BcdeA16a48b4a = dnnl_BcdeA16a48b4a, |
2196 | BcdeA16a64b2a = dnnl_BcdeA16a64b2a, |
2197 | BcdeA16a64b4a = dnnl_BcdeA16a64b4a, |
2198 | aCdefb32c = dnnl_aCdefb32c, |
2199 | aCdefB32c2b = dnnl_aCdefB32c2b, |
2200 | aCdefB32c4b = dnnl_aCdefB32c4b, |
2201 | aCdefb48c = dnnl_aCdefb48c, |
2202 | aCdefB48c2b = dnnl_aCdefB48c2b, |
2203 | aCdefB48c4b = dnnl_aCdefB48c4b, |
2204 | aCdefb64c = dnnl_aCdefb64c, |
2205 | aCdefB64c2b = dnnl_aCdefB64c2b, |
2206 | aCdefB64c4b = dnnl_aCdefB64c4b, |
2207 | Bcdea32b = dnnl_Bcdea32b, |
2208 | BcdeA32b2a = dnnl_BcdeA32b2a, |
2209 | BcdeA32b4a = dnnl_BcdeA32b4a, |
2210 | Bcdea48b = dnnl_Bcdea48b, |
2211 | BcdeA48b2a = dnnl_BcdeA48b2a, |
2212 | BcdeA48b4a = dnnl_BcdeA48b4a, |
2213 | Bcdea64b = dnnl_Bcdea64b, |
2214 | BcdeA64b2a = dnnl_BcdeA64b2a, |
2215 | BcdeA64b4a = dnnl_BcdeA64b4a, |
2216 | Bca32b = dnnl_Bca32b, |
2217 | BcA32b2a = dnnl_BcA32b2a, |
2218 | BcA32b4a = dnnl_BcA32b4a, |
2219 | Bca48b = dnnl_Bca48b, |
2220 | BcA48b2a = dnnl_BcA48b2a, |
2221 | BcA48b4a = dnnl_BcA48b4a, |
2222 | Bca64b = dnnl_Bca64b, |
2223 | BcA64b2a = dnnl_BcA64b2a, |
2224 | BcA64b4a = dnnl_BcA64b4a, |
2225 | aCdb32c = dnnl_aCdb32c, |
2226 | aCdB32c2b = dnnl_aCdB32c2b, |
2227 | aCdB32c4b = dnnl_aCdB32c4b, |
2228 | aCdb48c = dnnl_aCdb48c, |
2229 | aCdB48c2b = dnnl_aCdB48c2b, |
2230 | aCdB48c4b = dnnl_aCdB48c4b, |
2231 | aCdb64c = dnnl_aCdb64c, |
2232 | aCdB64c2b = dnnl_aCdB64c2b, |
2233 | aCdB64c4b = dnnl_aCdB64c4b, |
2234 | BcA16a16b2a = dnnl_BcA16a16b2a, |
2235 | BcA16a16b4a = dnnl_BcA16a16b4a, |
2236 | BcdA16a16b2a = dnnl_BcdA16a16b2a, |
2237 | BcdA16a16b4a = dnnl_BcdA16a16b4a, |
2238 | BcdeA16a16b2a = dnnl_BcdeA16a16b2a, |
2239 | BcdeA16a16b4a = dnnl_BcdeA16a16b4a, |
2240 | aCdB16b16c2b = dnnl_aCdB16b16c2b, |
2241 | aCdB16b16c4b = dnnl_aCdB16b16c4b, |
2242 | aCdeB16b16c2b = dnnl_aCdeB16b16c2b, |
2243 | aCdeB16b16c4b = dnnl_aCdeB16b16c4b, |
2244 | aCdefB16b16c2b = dnnl_aCdefB16b16c2b, |
2245 | aCdefB16b16c4b = dnnl_aCdefB16b16c4b, |
2246 | BcA16a32b2a = dnnl_BcA16a32b2a, |
2247 | BcA16a32b4a = dnnl_BcA16a32b4a, |
2248 | BcA16a48b2a = dnnl_BcA16a48b2a, |
2249 | BcA16a48b4a = dnnl_BcA16a48b4a, |
2250 | BcA16a64b2a = dnnl_BcA16a64b2a, |
2251 | BcA16a64b4a = dnnl_BcA16a64b4a, |
2252 | aCdB16b32c2b = dnnl_aCdB16b32c2b, |
2253 | aCdB16b32c4b = dnnl_aCdB16b32c4b, |
2254 | aCdB16b48c2b = dnnl_aCdB16b48c2b, |
2255 | aCdB16b48c4b = dnnl_aCdB16b48c4b, |
2256 | aCdB16b64c2b = dnnl_aCdB16b64c2b, |
2257 | aCdB16b64c4b = dnnl_aCdB16b64c4b, |
2258 | BcdA16a32b2a = dnnl_BcdA16a32b2a, |
2259 | BcdA16a32b4a = dnnl_BcdA16a32b4a, |
2260 | BcdA16a48b2a = dnnl_BcdA16a48b2a, |
2261 | BcdA16a48b4a = dnnl_BcdA16a48b4a, |
2262 | BcdA16a64b2a = dnnl_BcdA16a64b2a, |
2263 | BcdA16a64b4a = dnnl_BcdA16a64b4a, |
2264 | aCdeB16b32c2b = dnnl_aCdeB16b32c2b, |
2265 | aCdeB16b32c4b = dnnl_aCdeB16b32c4b, |
2266 | aCdeB16b48c2b = dnnl_aCdeB16b48c2b, |
2267 | aCdeB16b48c4b = dnnl_aCdeB16b48c4b, |
2268 | aCdeB16b64c2b = dnnl_aCdeB16b64c2b, |
2269 | aCdeB16b64c4b = dnnl_aCdeB16b64c4b, |
2270 | Bca16b = dnnl_Bca16b, |
2271 | BcA16b2a = dnnl_BcA16b2a, |
2272 | BcA16b4a = dnnl_BcA16b4a, |
2273 | Bcda16b = dnnl_Bcda16b, |
2274 | BcdA16b2a = dnnl_BcdA16b2a, |
2275 | BcdA16b4a = dnnl_BcdA16b4a, |
2276 | Bcdea16b = dnnl_Bcdea16b, |
2277 | BcdeA16b2a = dnnl_BcdeA16b2a, |
2278 | BcdeA16b4a = dnnl_BcdeA16b4a, |
2279 | aCdb16c = dnnl_aCdb16c, |
2280 | aCdB16c2b = dnnl_aCdB16c2b, |
2281 | aCdB16c4b = dnnl_aCdB16c4b, |
2282 | aCdeb16c = dnnl_aCdeb16c, |
2283 | aCdeB16c2b = dnnl_aCdeB16c2b, |
2284 | aCdeB16c4b = dnnl_aCdeB16c4b, |
2285 | aCdefb16c = dnnl_aCdefb16c, |
2286 | aCdefB16c2b = dnnl_aCdefB16c2b, |
2287 | aCdefB16c4b = dnnl_aCdefB16c4b, |
2288 | Bcda32b = dnnl_Bcda32b, |
2289 | BcdA32b2a = dnnl_BcdA32b2a, |
2290 | BcdA32b4a = dnnl_BcdA32b4a, |
2291 | Bcda48b = dnnl_Bcda48b, |
2292 | BcdA48b2a = dnnl_BcdA48b2a, |
2293 | BcdA48b4a = dnnl_BcdA48b4a, |
2294 | Bcda64b = dnnl_Bcda64b, |
2295 | BcdA64b2a = dnnl_BcdA64b2a, |
2296 | BcdA64b4a = dnnl_BcdA64b4a, |
2297 | aCdeb32c = dnnl_aCdeb32c, |
2298 | aCdeB32c2b = dnnl_aCdeB32c2b, |
2299 | aCdeB32c4b = dnnl_aCdeB32c4b, |
2300 | aCdeb48c = dnnl_aCdeb48c, |
2301 | aCdeB48c2b = dnnl_aCdeB48c2b, |
2302 | aCdeB48c4b = dnnl_aCdeB48c4b, |
2303 | aCdeb64c = dnnl_aCdeb64c, |
2304 | aCdeB64c2b = dnnl_aCdeB64c2b, |
2305 | aCdeB64c4b = dnnl_aCdeB64c4b, |
2306 | NChw16n32c = dnnl_NChw16n32c, |
2307 | goIw4i = dnnl_goIw4i, |
2308 | goIw32i = dnnl_goIw32i, |
2309 | goIhw4i = dnnl_goIhw4i, |
2310 | goIhw32i = dnnl_goIhw32i, |
2311 | goIdhw4i = dnnl_goIdhw4i, |
2312 | goIdhw32i = dnnl_goIdhw32i, |
2313 | cab = dnnl_cab, |
2314 | cdab = dnnl_cdab, |
2315 | cdeab = dnnl_cdeab, |
2316 | woi = dnnl_woi, |
2317 | hwoi = dnnl_hwoi, |
2318 | dhwoi = dnnl_dhwoi, |
2319 | }; |
2320 | |
2321 | /// A memory descriptor. |
2322 | struct desc : public handle<dnnl_memory_desc_t> { |
2323 | using handle<dnnl_memory_desc_t>::handle; |
2324 | |
2325 | friend struct memory; |
2326 | |
2327 | /// Constructs a zero (empty) memory descriptor. Such a memory |
2328 | /// descriptor can be used to indicate absence of an argument. |
2329 | desc() { |
2330 | dnnl_memory_desc_t zero_md = nullptr; |
2331 | error::wrap_c_api( |
2332 | dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr, |
2333 | dnnl_data_type_undef, dnnl_format_tag_undef), |
2334 | "could not create a zero memory descriptor" ); |
2335 | reset(zero_md); |
2336 | } |
2337 | |
2338 | /// Constructs a memory descriptor. |
2339 | /// |
2340 | /// @note |
2341 | /// The logical order of dimensions corresponds to the `abc...` |
2342 | /// format tag, and the physical meaning of the dimensions depends |
2343 | /// both on the primitive that would operate on this memory and |
2344 | /// the operation context. |
2345 | /// |
2346 | /// @param adims Tensor dimensions. |
2347 | /// @param adata_type Data precision/type. |
2348 | /// @param aformat_tag Memory format tag. |
2349 | /// @param allow_empty A flag signifying whether construction is |
2350 | /// allowed to fail without throwing an exception. In this case a |
2351 | /// zero memory descriptor will be constructed. This flag is |
2352 | /// optional and defaults to false. |
2353 | desc(const dims &adims, data_type adata_type, format_tag aformat_tag, |
2354 | bool allow_empty = false) { |
2355 | validate_dims(adims); |
2356 | dnnl_memory_desc_t md = nullptr; |
2357 | dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md, |
2358 | (int)adims.size(), adims.data(), convert_to_c(adata_type), |
2359 | convert_to_c(aformat_tag)); |
2360 | if (!allow_empty) |
2361 | error::wrap_c_api(status, |
2362 | "could not construct a memory descriptor using a " |
2363 | "format tag" ); |
2364 | reset(md); |
2365 | } |
2366 | |
2367 | /// Constructs a memory descriptor by strides. |
2368 | /// |
2369 | /// @note |
2370 | /// The logical order of dimensions corresponds to the `abc...` |
2371 | /// format tag, and the physical meaning of the dimensions depends |
2372 | /// both on the primitive that would operate on this memory and |
2373 | /// the operation context. |
2374 | /// |
2375 | /// @param adims Tensor dimensions. |
2376 | /// @param adata_type Data precision/type. |
2377 | /// @param strides Strides for each dimension. |
2378 | /// @param allow_empty A flag signifying whether construction is |
2379 | /// allowed to fail without throwing an exception. In this case a |
2380 | /// zero memory descriptor will be constructed. This flag is |
2381 | /// optional and defaults to false. |
2382 | desc(const dims &adims, data_type adata_type, const dims &strides, |
2383 | bool allow_empty = false) { |
2384 | validate_dims(adims); |
2385 | if (!strides.empty()) validate_dims(strides, (int)adims.size()); |
2386 | dnnl_memory_desc_t md = nullptr; |
2387 | dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md, |
2388 | (int)adims.size(), adims.data(), convert_to_c(adata_type), |
2389 | strides.empty() ? nullptr : &strides[0]); |
2390 | if (!allow_empty) |
2391 | error::wrap_c_api(status, |
2392 | "could not construct a memory descriptor using " |
2393 | "strides" ); |
2394 | reset(md); |
2395 | } |
2396 | |
2397 | /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t |
2398 | /// handle. The resulting handle is not weak and the C handle will be |
2399 | /// destroyed during the destruction of the C++ object. |
2400 | /// |
2401 | /// @param md The C API memory descriptor. |
2402 | desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {} |
2403 | |
2404 | /// Constructs a memory descriptor for a region inside an area |
2405 | /// described by this memory descriptor. |
2406 | // |
2407 | /// @param adims Sizes of the region. |
2408 | /// @param offsets Offsets to the region from the encompassing |
2409 | /// memory object in each dimension. |
2410 | /// @param allow_empty A flag signifying whether construction is |
2411 | /// allowed to fail without throwing an exception. In this case a |
2412 | /// zero memory descriptor will be returned. This flag is optional |
2413 | /// and defaults to false. |
2414 | /// @returns A memory descriptor for the region. |
2415 | desc submemory_desc(const dims &adims, const dims &offsets, |
2416 | bool allow_empty = false) const { |
2417 | validate_dims(adims, get_ndims()); |
2418 | validate_dims(offsets, get_ndims()); |
2419 | dnnl_memory_desc_t sub_md = nullptr; |
2420 | dnnl_status_t status = dnnl_memory_desc_create_submemory( |
2421 | &sub_md, get(), adims.data(), offsets.data()); |
2422 | if (!allow_empty) |
2423 | error::wrap_c_api(status, "could not construct a sub-memory" ); |
2424 | return desc(sub_md); |
2425 | } |
2426 | |
2427 | /// Constructs a memory descriptor by reshaping an existing one. The |
2428 | /// new memory descriptor inherits the data type. This operation is |
2429 | /// valid only for memory descriptors that have format_kind set to |
2430 | /// #dnnl::memory::format_kind::blocked or |
2431 | /// #dnnl::memory::format_kind::any. |
2432 | /// |
2433 | /// The operation ensures that the transformation of the physical memory |
2434 | /// format corresponds to the transformation of the logical dimensions. |
2435 | /// If such transformation is impossible, the function either throws an |
2436 | /// exception (default) or returns a zero memory descriptor depending on |
2437 | /// the `allow_empty` flag. |
2438 | /// |
2439 | /// The reshape operation can be described as a combination of the |
2440 | /// following basic operations: |
2441 | /// 1. Add a dimension of size `1`. This is always possible. |
2442 | /// 2. Remove a dimension of size `1`. This is possible only if the |
2443 | /// dimension has no padding (i.e. |
2444 | /// `padded_dims[dim] == dims[dim] && dims[dim] == 1`). |
2445 | /// 3. Split a dimension into multiple ones. This is possible only if |
2446 | /// the product of all tensor dimensions stays constant and the |
2447 | /// dimension being split does not have padding (i.e. |
2448 | /// `padded_dims[dim] = dims[dim]`). |
2449 | /// 4. Join multiple consecutive dimensions into a single one. As in |
2450 | /// the cases above, this requires that the dimensions do not have |
2451 | /// padding and that the memory format is such that in physical |
2452 | /// memory these dimensions are dense and have the same order as |
2453 | /// their logical counterparts. This also assumes that these |
2454 | /// dimensions are not blocked. |
2455 | /// - Here, 'dense' means: |
2456 | /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`; |
2457 | /// - And 'same order' means: |
2458 | /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`. |
2459 | /// |
2460 | /// @warning |
2461 | /// Some combinations of physical memory layout and/or offsets or |
2462 | /// dimensions may result in a failure to make a reshape. |
2463 | /// |
2464 | /// @param adims New dimensions. The product of dimensions must |
2465 | /// remain constant. |
2466 | /// @param allow_empty A flag signifying whether construction is |
2467 | /// allowed to fail without throwing an exception. In this case a |
2468 | /// zero memory descriptor will be returned. This flag is optional |
2469 | /// and defaults to false. |
2470 | /// @returns A new memory descriptor with new dimensions. |
2471 | desc reshape(const dims &adims, bool allow_empty = false) const { |
2472 | if (get_ndims()) validate_dims(adims, 1); |
2473 | dnnl_memory_desc_t out_md = nullptr; |
2474 | dnnl_status_t status = dnnl_memory_desc_reshape( |
2475 | &out_md, get(), (int)adims.size(), adims.data()); |
2476 | if (!allow_empty) |
2477 | error::wrap_c_api( |
2478 | status, "could not reshape a memory descriptor" ); |
2479 | return desc(out_md); |
2480 | } |
2481 | |
2482 | /// Constructs a memory descriptor by permuting axes in an existing |
2483 | /// one. |
2484 | /// |
2485 | /// The physical memory layout representation is adjusted accordingly |
2486 | /// to maintain the consistency between the logical and physical parts |
2487 | /// of the memory descriptor. The new memory descriptor inherits the |
2488 | /// data type. |
2489 | /// |
2490 | /// The new memory descriptor inherits the data type. This operation is |
2491 | /// valid only for memory descriptors that have format_kind set to |
2492 | /// #dnnl::memory::format_kind::blocked or |
2493 | /// #dnnl::memory::format_kind::any. |
2494 | /// |
2495 | /// The logical axes will be permuted in the following manner: |
2496 | /// @code |
2497 | /// for (i = 0; i < get_ndims(); i++) |
2498 | /// new_desc.dims()[permutation[i]] = dims()[i]; |
2499 | /// @endcode |
2500 | /// |
2501 | /// Example: |
2502 | /// @code |
2503 | /// std::vector<int> permutation = {1, 0}; // swap the first and |
2504 | /// // the second axes |
2505 | /// dnnl::memory::desc in_md( |
2506 | /// {2, 3}, data_type, memory::format_tag::ab); |
2507 | /// dnnl::memory::desc expect_out_md( |
2508 | /// {3, 2}, data_type, memory::format_tag::ba); |
2509 | /// |
2510 | /// assert(in_md.permute_axes(permutation) == expect_out_md); |
2511 | /// @endcode |
2512 | /// |
2513 | /// @param permutation Axes permutation. |
2514 | /// @param allow_empty A flag signifying whether construction is |
2515 | /// allowed to fail without throwing an exception. In this case a |
2516 | /// zero memory descriptor will be returned. This flag is optional |
2517 | /// and defaults to false. |
2518 | /// @returns A new memory descriptor with new dimensions. |
2519 | desc permute_axes(const std::vector<int> &permutation, |
2520 | bool allow_empty = false) const { |
2521 | validate_dims(permutation, get_ndims()); |
2522 | dnnl_memory_desc_t out_md = nullptr; |
2523 | dnnl_status_t status = dnnl_memory_desc_permute_axes( |
2524 | &out_md, get(), permutation.data()); |
2525 | if (!allow_empty) |
2526 | error::wrap_c_api(status, |
2527 | "could not permute axes of a memory descriptor" ); |
2528 | return desc(out_md); |
2529 | } |
2530 | |
2531 | /// Returns a number of dimensions of the memory descriptor. |
2532 | /// |
2533 | /// @returns A number of dimensions. |
2534 | int get_ndims() const { return query_s32(query::ndims_s32); } |
2535 | |
2536 | /// Returns padded dimensions of the memory descriptor. |
2537 | /// |
2538 | /// @returns A copy of the padded dimensions vector. |
2539 | memory::dims get_padded_dims() const { |
2540 | return query_dims(query::padded_dims); |
2541 | } |
2542 | |
2543 | /// Returns padded offsets of the memory descriptor. |
2544 | /// |
2545 | /// @returns A copy of the padded offsets vector. |
2546 | memory::dims get_padded_offsets() const { |
2547 | return query_dims(query::padded_offsets); |
2548 | } |
2549 | |
2550 | /// Returns a submemory offset of the memory descriptor. |
2551 | /// |
2552 | /// @returns A submemory offset. |
2553 | memory::dim get_submemory_offset() const { |
2554 | dnnl_dim_t submemory_offset; |
2555 | dnnl_status_t status = dnnl_memory_desc_query( |
2556 | get(), dnnl_query_submemory_offset_s64, &submemory_offset); |
2557 | return status == dnnl_success ? submemory_offset : 0; |
2558 | } |
2559 | |
2560 | /// Returns strides of the memory descriptor. |
2561 | /// |
2562 | /// @note |
2563 | /// This API is only applicable to memory descriptors with format |
2564 | /// kind #dnnl_blocked. |
2565 | /// |
2566 | /// @returns A copy of the strides vector. |
2567 | /// @returns An empty #dnnl::memory::dims if the memory descriptor |
2568 | /// does not have strides. |
2569 | memory::dims get_strides() const { return query_dims(query::strides); } |
2570 | |
2571 | /// Returns a number of inner blocks of the memory descriptor. |
2572 | /// |
2573 | /// @note |
2574 | /// This API is only applicable to memory descriptors with format |
2575 | /// kind #dnnl_blocked. |
2576 | /// |
2577 | /// @returns A number of inner blocks. |
2578 | int get_inner_nblks() const { |
2579 | return query_s32(query::inner_nblks_s32); |
2580 | } |
2581 | |
2582 | /// Returns inner blocks of the memory descriptor. |
2583 | /// |
2584 | /// @note |
2585 | /// This API is only applicable to memory descriptors with format |
2586 | /// kind #dnnl_blocked. |
2587 | /// |
2588 | /// @returns A copy of the inner blocks vector. |
2589 | /// @returns An empty #dnnl::memory::dims if the memory descriptor |
2590 | /// does not have inner blocks. |
2591 | memory::dims get_inner_blks() const { |
2592 | return query_dims(query::inner_blks); |
2593 | } |
2594 | |
2595 | /// Returns inner indices of the memory descriptor. |
2596 | /// |
2597 | /// @note |
2598 | /// This API is only applicable to memory descriptors with format |
2599 | /// kind #dnnl_blocked. |
2600 | /// |
2601 | /// @returns A copy of the inner indices vector. |
2602 | /// @returns An empty #dnnl::memory::dims if the memory descriptor |
2603 | /// does not have inner indices. |
2604 | memory::dims get_inner_idxs() const { |
2605 | return query_dims(query::inner_idxs); |
2606 | } |
2607 | |
2608 | /// Returns the format kind of the memory descriptor. |
2609 | /// |
2610 | /// @returns the format kind. |
2611 | memory::format_kind get_format_kind() const { |
2612 | dnnl_format_kind_t format_kind; |
2613 | dnnl_status_t status = dnnl_memory_desc_query( |
2614 | get(), dnnl_query_format_kind, &format_kind); |
2615 | return status == dnnl_success |
2616 | ? static_cast<dnnl::memory::format_kind>(format_kind) |
2617 | : dnnl::memory::format_kind::undef; |
2618 | } |
2619 | |
2620 | /// Returns the data type of the memory descriptor. |
2621 | /// |
2622 | /// @returns The data type. |
2623 | memory::data_type get_data_type() const { |
2624 | dnnl_data_type_t data_type; |
2625 | dnnl_status_t status = dnnl_memory_desc_query( |
2626 | get(), dnnl_query_data_type, &data_type); |
2627 | return status == dnnl_success |
2628 | ? static_cast<dnnl::memory::data_type>(data_type) |
2629 | : dnnl::memory::data_type::undef; |
2630 | } |
2631 | |
2632 | /// Returns dimensions of the memory descriptor. |
2633 | /// |
2634 | /// Potentially expensive due to the data copy involved. |
2635 | /// @returns A copy of the dimensions vector. |
2636 | memory::dims get_dims() const { return query_dims(query::dims); } |
2637 | |
2638 | /// Returns size of the memory descriptor in bytes. |
2639 | /// @returns The number of bytes required to allocate a memory buffer |
2640 | /// for the memory object described by this memory descriptor |
2641 | /// including the padding area. |
2642 | size_t get_size() const { return dnnl_memory_desc_get_size(get()); } |
2643 | |
2644 | /// Checks whether the memory descriptor is zero (empty). |
2645 | /// @returns @c true if the memory descriptor describes an empty |
2646 | /// memory and @c false otherwise. |
2647 | bool is_zero() const { return get_ndims() == 0; } |
2648 | |
2649 | /// An equality operator. |
2650 | /// @param other Another memory descriptor. |
2651 | /// @returns Whether this and the other memory descriptors have |
2652 | /// the same format tag, dimensions, strides, blocking, etc. |
2653 | bool operator==(const desc &other) const { |
2654 | return dnnl_memory_desc_equal(get(), other.get()) != 0; |
2655 | } |
2656 | |
2657 | /// An inequality operator. |
2658 | /// @param other Another memory descriptor. |
2659 | /// @returns Whether this and the other memory descriptors describe |
2660 | /// different memory. |
2661 | bool operator!=(const desc &other) const { return !operator==(other); } |
2662 | |
2663 | private: |
2664 | int query_s32(query what) const { |
2665 | int res; |
2666 | dnnl_status_t status = dnnl_memory_desc_query( |
2667 | get(), dnnl::convert_to_c(what), &res); |
2668 | return status == dnnl_success ? res : 0; |
2669 | } |
2670 | |
2671 | memory::dims query_dims(query what) const { |
2672 | dnnl_dims_t *c_dims; |
2673 | dnnl_status_t status = dnnl_memory_desc_query( |
2674 | get(), dnnl::convert_to_c(what), &c_dims); |
2675 | |
2676 | const int ndims |
2677 | = (what == query::inner_idxs || what == query::inner_blks) |
2678 | ? get_inner_nblks() |
2679 | : get_ndims(); |
2680 | |
2681 | return status == dnnl_success |
2682 | ? memory::dims(*c_dims, *c_dims + ndims) |
2683 | : memory::dims {}; |
2684 | } |
2685 | }; |
2686 | |
2687 | /// Default constructor. |
2688 | /// |
2689 | /// Constructs an empty memory object, which can be used to indicate |
2690 | /// absence of a parameter. |
memory() = default; // No C memory object is created; the handle stays empty.
2692 | |
2693 | /// Constructs a memory object. |
2694 | /// |
2695 | /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory |
2696 | /// object will have the underlying buffer set. In this case, the buffer |
2697 | /// will be initialized as if #dnnl::memory::set_data_handle() had been |
2698 | /// called. |
2699 | /// |
2700 | /// @sa memory::set_data_handle() |
2701 | /// |
2702 | /// @param md Memory descriptor. |
2703 | /// @param aengine Engine to store the data on. |
2704 | /// @param handle Handle of the memory buffer to use. |
2705 | /// - A pointer to the user-allocated buffer. In this case the library |
2706 | /// doesn't own the buffer. |
2707 | /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to |
2708 | /// allocate the buffer for the memory object. In this case the |
2709 | /// library owns the buffer. |
2710 | /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying |
2711 | /// buffer. |
2712 | memory(const desc &md, const engine &aengine, void *handle) { |
2713 | dnnl_memory_t result; |
2714 | error::wrap_c_api( |
2715 | dnnl_memory_create(&result, md.get(), aengine.get(), handle), |
2716 | "could not create a memory object" ); |
2717 | reset(result); |
2718 | } |
2719 | |
2720 | /// Constructs a memory object. |
2721 | /// |
2722 | /// The underlying buffer for the memory will be allocated by the library. |
2723 | /// |
2724 | /// @param md Memory descriptor. |
2725 | /// @param aengine Engine to store the data on. |
2726 | memory(const desc &md, const engine &aengine) |
2727 | : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {} |
2728 | |
2729 | /// Returns the associated memory descriptor. |
2730 | desc get_desc() const { |
2731 | const_dnnl_memory_desc_t cdesc; |
2732 | error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc), |
2733 | "could not get a memory descriptor from a memory object" ); |
2734 | dnnl_memory_desc_t cloned_md = nullptr; |
2735 | error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), |
2736 | "could not clone a memory descriptor" ); |
2737 | return desc(cloned_md); |
2738 | } |
2739 | |
2740 | /// Returns the associated engine. |
2741 | engine get_engine() const { |
2742 | dnnl_engine_t c_engine; |
2743 | error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine), |
2744 | "could not get an engine from a memory object" ); |
2745 | return engine(c_engine, true); |
2746 | } |
2747 | |
2748 | /// Returns the underlying memory buffer. |
2749 | /// |
2750 | /// On the CPU engine, or when using USM, this is a pointer to the |
2751 | /// allocated memory. |
2752 | void *get_data_handle() const { |
2753 | void *handle; |
2754 | error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle), |
2755 | "could not get a native handle from a memory object" ); |
2756 | return handle; |
2757 | } |
2758 | |
2759 | /// Sets the underlying memory buffer. |
2760 | /// |
2761 | /// @param handle Memory buffer to use. On the CPU engine or when USM is |
2762 | /// used, the memory buffer is a pointer to the actual data. For OpenCL |
2763 | /// it is a cl_mem. It must have at least |
2764 | /// #dnnl::memory::desc::get_size() bytes allocated. |
2765 | void set_data_handle(void *handle) const { |
2766 | error::wrap_c_api(dnnl_memory_set_data_handle(get(), handle), |
2767 | "could not set native handle of a memory object" ); |
2768 | } |
2769 | |
2770 | /// Maps a memory object and returns a host-side pointer to a memory |
2771 | /// buffer with a copy of its contents. |
2772 | /// |
2773 | /// Mapping enables read/write directly from/to the memory contents for |
2774 | /// engines that do not support direct memory access. |
2775 | /// |
2776 | /// Mapping is an exclusive operation - a memory object cannot be used in |
2777 | /// other operations until it is unmapped via #dnnl::memory::unmap_data() |
2778 | /// call. |
2779 | /// |
2780 | /// @note |
2781 | /// Any primitives working with the memory should be completed before |
2782 | /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the |
2783 | /// corresponding execution stream. |
2784 | /// |
2785 | /// @note |
2786 | /// The map_data and unmap_data functions are provided mainly for |
2787 | /// debug and testing purposes and their performance may be suboptimal. |
2788 | /// |
2789 | /// @tparam T Data type to return a pointer to. |
2790 | /// @returns Pointer to the mapped memory. |
2791 | template <typename T = void> |
2792 | T *map_data() const { |
2793 | void *mapped_ptr; |
2794 | error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr), |
2795 | "could not map memory object data" ); |
2796 | return static_cast<T *>(mapped_ptr); |
2797 | } |
2798 | |
2799 | /// Unmaps a memory object and writes back any changes made to the |
2800 | /// previously mapped memory buffer. |
2801 | /// |
2802 | /// @note |
2803 | /// The map_data and unmap_data functions are provided mainly for |
2804 | /// debug and testing purposes and their performance may be |
2805 | /// suboptimal. |
2806 | /// |
2807 | /// @param mapped_ptr A pointer previously returned by |
2808 | /// #dnnl::memory::map_data(). |
2809 | void unmap_data(void *mapped_ptr) const { |
2810 | error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr), |
2811 | "could not unmap memory object data" ); |
2812 | } |
2813 | |
2814 | static dnnl_data_type_t convert_to_c(data_type adata_type) { |
2815 | return static_cast<dnnl_data_type_t>(adata_type); |
2816 | } |
2817 | static dnnl_format_tag_t convert_to_c(format_tag format) { |
2818 | return static_cast<dnnl_format_tag_t>(format); |
2819 | } |
2820 | }; |
2821 | |
2822 | inline bool operator==(dnnl_data_type_t a, memory::data_type b) { |
2823 | return a == memory::convert_to_c(b); |
2824 | } |
2825 | inline bool operator!=(dnnl_data_type_t a, memory::data_type b) { |
2826 | return !(a == b); |
2827 | } |
2828 | inline bool operator==(memory::data_type a, dnnl_data_type_t b) { |
2829 | return b == a; |
2830 | } |
2831 | inline bool operator!=(memory::data_type a, dnnl_data_type_t b) { |
2832 | return !(a == b); |
2833 | } |
2834 | |
2835 | inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) { |
2836 | return a == memory::convert_to_c(b); |
2837 | } |
2838 | inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) { |
2839 | return !(a == b); |
2840 | } |
2841 | inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) { |
2842 | return b == a; |
2843 | } |
2844 | inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) { |
2845 | return !(a == b); |
2846 | } |
2847 | |
2848 | /// @} dnnl_api_memory |
2849 | |
2850 | /// @addtogroup dnnl_api_primitives |
2851 | /// @{ |
2852 | /// @addtogroup dnnl_api_attributes Attributes |
2853 | /// |
2854 | /// A container for parameters that extend primitives behavior. |
2855 | /// |
2856 | /// @{ |
2857 | |
2858 | /// @cond DO_NOT_DOCUMENT_THIS |
2859 | template <> |
2860 | struct handle_traits<dnnl_post_ops_t> { |
2861 | static dnnl_status_t destructor(dnnl_post_ops_t p) { |
2862 | return dnnl_post_ops_destroy(p); |
2863 | } |
2864 | }; |
2865 | /// @endcond |
2866 | |
2867 | /// Post-ops. |
2868 | /// |
2869 | /// Post-ops are computations executed after the main primitive computations |
2870 | /// and are attached to the primitive via primitive attributes. |
2871 | /// |
2872 | /// @sa @ref dev_guide_attributes_post_ops |
2873 | /// |
2874 | struct post_ops : public handle<dnnl_post_ops_t> { |
2875 | using handle<dnnl_post_ops_t>::handle; |
2876 | |
2877 | /// Constructs an empty sequence of post-ops. |
2878 | post_ops() { |
2879 | dnnl_post_ops_t result; |
2880 | error::wrap_c_api( |
2881 | dnnl_post_ops_create(&result), "could not create post-ops" ); |
2882 | reset(result); |
2883 | } |
2884 | |
2885 | /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t |
2886 | /// handle. The resulting handle is not weak and the C handle will be |
2887 | /// destroyed during the destruction of the C++ object. |
2888 | /// |
2889 | /// @param post_ops The C API post-ops primitive attribute. |
2890 | post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {} |
2891 | |
2892 | /// Returns the number of post-ops entries. |
2893 | int len() const { return dnnl_post_ops_len(get()); } |
2894 | |
2895 | /// Returns the primitive kind of post-op at entry with a certain index. |
2896 | /// @param index Index of the post-op to return the kind for. |
2897 | /// @returns Primitive kind of the post-op at the specified index. |
2898 | primitive::kind kind(int index) const { |
2899 | error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments, |
2900 | "post-ops index is out of range" ); |
2901 | return static_cast<primitive::kind>( |
2902 | dnnl_post_ops_get_kind(get(), index)); |
2903 | } |
2904 | |
2905 | /// Appends an accumulation (sum) post-op. Prior to accumulating the |
/// result, the previous value will be reduced by zero point
2907 | /// @p zero_point and multiplied by a scaling factor @p scale. |
2908 | /// |
2909 | /// The kind of this post-op is #dnnl::primitive::kind::sum. |
2910 | /// |
/// This feature may improve performance for cases like dequantizing the
2912 | /// asymmetrically quantized sum's src1 tensor to f32 domain before |
2913 | /// performing the sum operation by subtracting @p zero_point before the |
2914 | /// scaling. |
2915 | /// |
2916 | /// In the simplest case when the accumulation is the only post-op, |
2917 | /// the computations will be `dst[:] := scale * (dst[:] - zero_point) + |
2918 | /// op(...)` instead of `dst[:] := op(...)`. |
2919 | /// |
2920 | /// If @p data_type is specified, the original dst tensor will be |
2921 | /// reinterpreted as a tensor with the provided data type. Because it is a |
2922 | /// reinterpretation, data_type and dst data type should have the same size. |
2923 | /// As a result, computations will be `dst[:] <- scale * |
2924 | /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of |
2925 | /// `dst[:] <- op(...)`. |
2926 | /// |
2927 | /// @note |
2928 | /// This post-op executes in-place and does not change the |
2929 | /// destination layout. |
2930 | /// |
2931 | /// @param scale Scaling factor. |
2932 | /// @param zero_point Zero point. |
2933 | /// @param data_type Data type. |
2934 | void append_sum(float scale = 1.f, int32_t zero_point = 0, |
2935 | memory::data_type data_type = memory::data_type::undef) { |
2936 | error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point, |
2937 | memory::convert_to_c(data_type)), |
2938 | "could not append a sum post-op" ); |
2939 | } |
2940 | |
2941 | /// Returns the parameters of an accumulation (sum) post-op. |
2942 | /// |
2943 | /// @param index Index of the sum post-op. |
2944 | /// @param scale Scaling factor of the sum post-op. |
2945 | void get_params_sum(int index, float &scale) const { |
2946 | error::wrap_c_api(dnnl_post_ops_get_params_sum( |
2947 | get(), index, &scale, nullptr, nullptr), |
2948 | "could not get parameters of a sum post-op" ); |
2949 | } |
2950 | |
2951 | /// Returns the parameters of an accumulation (sum) post-op. |
2952 | /// |
2953 | /// @param index Index of the sum post-op. |
2954 | /// @param scale Scaling factor of the sum post-op. |
2955 | /// @param data_type Data type of the sum post-op. |
2956 | void get_params_sum( |
2957 | int index, float &scale, memory::data_type &data_type) const { |
2958 | dnnl_data_type_t c_data_type; |
2959 | error::wrap_c_api(dnnl_post_ops_get_params_sum( |
2960 | get(), index, &scale, nullptr, &c_data_type), |
2961 | "could not get parameters of a sum post-op" ); |
2962 | data_type = static_cast<memory::data_type>(c_data_type); |
2963 | } |
2964 | |
2965 | /// Returns the parameters of an accumulation (sum) post-op. |
2966 | /// |
2967 | /// @param index Index of the sum post-op. |
2968 | /// @param scale Scaling factor of the sum post-op. |
2969 | /// @param zero_point Single scalar int32_t value of zeropoint. |
2970 | /// @param data_type Data type of the sum post-op. |
2971 | void get_params_sum(int index, float &scale, int32_t &zero_point, |
2972 | memory::data_type &data_type) const { |
2973 | dnnl_data_type_t c_data_type; |
2974 | error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale, |
2975 | &zero_point, &c_data_type), |
2976 | "could not get parameters of a sum post-op" ); |
2977 | data_type = static_cast<memory::data_type>(c_data_type); |
2978 | } |
2979 | |
2980 | /// Appends an elementwise post-op. |
2981 | /// |
2982 | /// The kind of this post-op is #dnnl::primitive::kind::eltwise. |
2983 | /// |
2984 | /// In the simplest case when the elementwise is the only post-op, the |
2985 | /// computations would be `dst[:] := eltwise_op (op(...))` instead |
/// of `dst[:] := op(...)`, where eltwise_op is configured with the given
2987 | /// parameters. |
2988 | /// |
2989 | /// @param aalgorithm Elementwise algorithm. |
2990 | /// @param alpha Alpha parameter for the elementwise algorithm. |
2991 | /// @param beta Beta parameter for the elementwise algorithm. |
2992 | void append_eltwise(algorithm aalgorithm, float alpha, float beta) { |
2993 | error::wrap_c_api(dnnl_post_ops_append_eltwise( |
2994 | get(), convert_to_c(aalgorithm), alpha, beta), |
2995 | "could not append an elementwise post-op" ); |
2996 | } |
2997 | |
2998 | /// Returns parameters of an elementwise post-op. |
2999 | /// |
3000 | /// @param index Index of the post-op. |
3001 | /// @param aalgorithm Output elementwise algorithm kind. |
3002 | /// @param alpha Output alpha parameter for the elementwise algorithm. |
3003 | /// @param beta Output beta parameter for the elementwise algorithm. |
3004 | void get_params_eltwise( |
3005 | int index, algorithm &aalgorithm, float &alpha, float &beta) const { |
3006 | dnnl_alg_kind_t c_alg; |
3007 | error::wrap_c_api(dnnl_post_ops_get_params_eltwise( |
3008 | get(), index, &c_alg, &alpha, &beta), |
3009 | "could not get parameters of an elementwise post-op" ); |
3010 | aalgorithm = static_cast<dnnl::algorithm>(c_alg); |
3011 | } |
3012 | |
3013 | /// Appends a depthwise post-op convolution. |
3014 | /// |
3015 | /// This post-op can only be fused with a 2D 1x1 convolution (convolution |
3016 | /// with weights spatial dimension equal to 1 i.e., kh=kw=1). |
3017 | /// |
/// The kind of this post-op is #dnnl::primitive::kind::convolution.
3019 | /// |
3020 | /// The number of outputs for primitive remain same as before. The output |
3021 | /// spatial size can be derived as below: |
3022 | /// |
3023 | /// output_height = ceil(output_height_1x1_convolution, stride) |
3024 | /// output_width = ceil(output_width_1x1_convolution, stride) |
3025 | /// |
3026 | /// See @ref dev_guide_attributes_post_ops_depthwise and |
3027 | /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info. |
3028 | /// |
3029 | /// @param weights_data_type Weights data type of depthwise post-op |
3030 | /// @param bias_data_type Bias data type of depthwise post-op |
3031 | /// @param dst_data_type Output data type of depthwise post-op |
3032 | /// @param kernel_size Size of kernel of depthwise post-op |
3033 | /// @param stride_size Size of stride of depthwise post-op |
3034 | /// @param padding_l_size Size of left and top paddings of depthwise post-op |
3035 | void append_dw(memory::data_type weights_data_type, |
3036 | memory::data_type bias_data_type, memory::data_type dst_data_type, |
3037 | memory::dim kernel_size, memory::dim stride_size, |
3038 | memory::dim padding_l_size) { |
3039 | |
3040 | error::wrap_c_api(dnnl_post_ops_append_dw(get(), |
3041 | memory::convert_to_c(weights_data_type), |
3042 | memory::convert_to_c(bias_data_type), |
3043 | memory::convert_to_c(dst_data_type), |
3044 | kernel_size, stride_size, padding_l_size), |
3045 | "could not append depthwise post-op" ); |
3046 | } |
3047 | |
/// Returns the parameters of a depthwise post-op.
3049 | /// |
/// @param index Index of the depthwise post-op.
3051 | /// @param weights_data_type Weights data type of depthwise post-op |
3052 | /// @param bias_data_type Bias data type of depthwise post-op |
3053 | /// @param dst_data_type Output data type of depthwise post-op |
3054 | /// @param kernel_size Size of kernel of depthwise post-op |
3055 | /// @param stride_size Size of stride of depthwise post-op |
3056 | /// @param padding_l_size Size of left and top paddings of depthwise post-op |
3057 | void get_params_dw(int index, memory::data_type &weights_data_type, |
3058 | memory::data_type &bias_data_type, memory::data_type &dst_data_type, |
3059 | memory::dim &kernel_size, memory::dim &stride_size, |
3060 | memory::dim &padding_l_size) const { |
3061 | |
3062 | dnnl_data_type_t c_weights_data_type; |
3063 | dnnl_data_type_t c_bias_data_type; |
3064 | dnnl_data_type_t c_dst_data_type; |
3065 | dnnl_dim_t c_kernel_size; |
3066 | dnnl_dim_t c_stride_size; |
3067 | dnnl_dim_t c_padding_l_size; |
3068 | error::wrap_c_api( |
3069 | dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type, |
3070 | &c_bias_data_type, &c_dst_data_type, &c_kernel_size, |
3071 | &c_stride_size, &c_padding_l_size), |
3072 | "could not get parameters of depthwise post-op" ); |
3073 | |
3074 | weights_data_type = static_cast<memory::data_type>(c_weights_data_type); |
3075 | bias_data_type = static_cast<memory::data_type>(c_bias_data_type); |
3076 | dst_data_type = static_cast<memory::data_type>(c_dst_data_type); |
3077 | kernel_size = c_kernel_size; |
3078 | stride_size = c_stride_size; |
3079 | padding_l_size = c_padding_l_size; |
3080 | } |
3081 | |
3082 | /// Appends a binary post-op. |
3083 | /// |
3084 | /// The kind of this post operation is #dnnl_binary. |
3085 | /// |
3086 | /// In the simplest case when the binary is the only post operation, the |
3087 | /// computations would be: |
3088 | /// |
3089 | /// dst[:] <- binary_op (dst[:], another_input[:]) |
3090 | /// |
3091 | /// where binary_op is configured with the given parameters. binary_op |
3092 | /// supports broadcast semantics for a second operand. |
3093 | /// |
3094 | /// @param aalgorithm Binary algorithm for the post-op. |
3095 | /// @param src1_desc Memory descriptor of a second operand. |
3096 | void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) { |
3097 | error::wrap_c_api(dnnl_post_ops_append_binary(get(), |
3098 | convert_to_c(aalgorithm), src1_desc.get()), |
3099 | "could not append a binary post-op" ); |
3100 | } |
3101 | |
3102 | /// Returns the parameters of a binary post-op. |
3103 | /// |
3104 | /// @param index Index of the binary post-op. |
3105 | /// @param aalgorithm Output binary algorithm kind. |
3106 | /// @param src1_desc Output memory descriptor of a second operand. |
3107 | void get_params_binary( |
3108 | int index, algorithm &aalgorithm, memory::desc &src1_desc) const { |
3109 | dnnl_alg_kind_t c_alg; |
3110 | const_dnnl_memory_desc_t cdesc; |
3111 | error::wrap_c_api( |
3112 | dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc), |
3113 | "could not get parameters of a binary post-op" ); |
3114 | aalgorithm = static_cast<dnnl::algorithm>(c_alg); |
3115 | dnnl_memory_desc_t cloned_md = nullptr; |
3116 | error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), |
3117 | "could not clone a memory descriptor" ); |
3118 | src1_desc = memory::desc(cloned_md); |
3119 | } |
3120 | |
3121 | /// Appends a prelu forward post-op. |
3122 | /// |
3123 | /// The kind of this post-op is #dnnl::primitive::kind::prelu. |
3124 | /// |
3125 | /// The post-op can be defined as: |
3126 | /// |
3127 | /// dst[:] <- prelu(dst[:], weights[:]) |
3128 | /// prelu: |
3129 | /// dst[:] <- dst[:] if dst[:] > 0 |
3130 | /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0 |
3131 | /// |
3132 | /// |
3133 | /// Example usage: |
3134 | /// @code |
3135 | /// int mb = 32, oc = 32, |
3136 | /// oh = 14, ow = 14; // convolution output params |
3137 | /// // unique weights per output channel |
3138 | /// vector<float> weights = { ... }; |
3139 | /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... |
3140 | /// |
3141 | /// // construct a convolution descriptor |
3142 | /// dnnl::convolution::desc conv_d; |
3143 | /// |
3144 | /// dnnl::primitive_attr attr; |
3145 | /// attr.append_prelu(1 << oc_dim); |
3146 | /// |
3147 | /// dnnl::primitive_desc conv_pd(conv_d, attr, engine); |
3148 | /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data()); |
3149 | /// |
3150 | /// std::unordered_map<int, memory> conv_args; |
3151 | /// |
3152 | /// conv_args.insert( |
3153 | /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights}) |
3154 | /// @endcode |
3155 | /// |
3156 | /// @note |
3157 | /// The order of dimensions does not depend on how elements are laid |
3158 | /// out in memory. For example: |
3159 | /// - for a 2D CNN activations tensor the order is always (n, c) |
3160 | /// - for a 4D CNN activations tensor the order is always (n, c, h, w) |
3161 | /// - for a 5D CNN weights tensor the order is always |
3162 | /// (g, oc, ic, kh, kw) |
3163 | /// |
3164 | /// Prelu weights tensor is passed in runtime execution phase. Prelu |
3165 | /// weights tensor data type is implicitly assumed as f32 using plain |
3166 | /// layout (a, ab, acb, acdb, acdeb). |
3167 | /// |
3168 | /// @param mask Defines the correspondence between the output tensor |
3169 | /// dimensions and the prelu weights tensor. The set i-th bit indicates |
3170 | /// that a dedicated weights value is used for each index along that |
3171 | /// dimension. Set the mask to 0 to use a common weights value |
3172 | /// for the whole output tensor. |
3173 | void append_prelu(int mask) { |
3174 | error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask), |
3175 | "could not append a prelu post-op" ); |
3176 | } |
3177 | |
3178 | /// Returns the parameters of a prelu post-op. |
3179 | /// |
3180 | /// @param index Index of the prelu post-op. |
3181 | /// @param mask Weights mask of prelu post-op. |
3182 | void get_params_prelu(int index, int &mask) const { |
3183 | error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask), |
3184 | "could not get parameters of a binary post-op" ); |
3185 | } |
3186 | }; |
3187 | |
3188 | /// @cond DO_NOT_DOCUMENT_THIS |
3189 | template <> |
3190 | struct handle_traits<dnnl_primitive_attr_t> { |
3191 | static dnnl_status_t destructor(dnnl_primitive_attr_t p) { |
3192 | return dnnl_primitive_attr_destroy(p); |
3193 | } |
3194 | }; |
3195 | /// @endcond |
3196 | |
3197 | /// Primitive attributes. |
3198 | /// |
3199 | /// @sa @ref dev_guide_attributes |
3200 | struct primitive_attr : public handle<dnnl_primitive_attr_t> { |
3201 | using handle<dnnl_primitive_attr_t>::handle; |
3202 | |
3203 | /// Constructs default (empty) primitive attributes. |
3204 | primitive_attr() { |
3205 | dnnl_primitive_attr_t result; |
3206 | error::wrap_c_api(dnnl_primitive_attr_create(&result), |
3207 | "could not create primitive attribute" ); |
3208 | reset(result); |
3209 | } |
3210 | |
3211 | /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t |
3212 | /// handle. The resulting handle is not weak and the C handle will be |
3213 | /// destroyed during the destruction of the C++ object. |
3214 | /// |
3215 | /// @param attr The C API primitive attributes. |
3216 | primitive_attr(dnnl_primitive_attr_t attr) |
3217 | : handle<dnnl_primitive_attr_t>(attr) {} |
3218 | |
3219 | /// Returns the fpmath mode |
3220 | fpmath_mode get_fpmath_mode() const { |
3221 | dnnl_fpmath_mode_t result; |
3222 | error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result), |
3223 | "could not get fpmath mode primitive attribute" ); |
3224 | return fpmath_mode(result); |
3225 | } |
3226 | |
3227 | /// Sets fpmath mode. |
3228 | /// |
3229 | /// @param mode Specified fpmath mode. |
3230 | void set_fpmath_mode(fpmath_mode mode) { |
3231 | error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode( |
3232 | get(), dnnl::convert_to_c(mode)), |
3233 | "could not set fpmath mode primitive attribute" ); |
3234 | } |
3235 | |
3236 | /// Returns the scratchpad mode. |
3237 | scratchpad_mode get_scratchpad_mode() const { |
3238 | dnnl_scratchpad_mode_t result; |
3239 | error::wrap_c_api( |
3240 | dnnl_primitive_attr_get_scratchpad_mode(get(), &result), |
3241 | "could not get scratchpad mode primitive attribute" ); |
3242 | return scratchpad_mode(result); |
3243 | } |
3244 | |
3245 | /// Sets scratchpad mode. |
3246 | /// |
3247 | /// @param mode Specified scratchpad mode. |
3248 | void set_scratchpad_mode(scratchpad_mode mode) { |
3249 | error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode( |
3250 | get(), dnnl::convert_to_c(mode)), |
3251 | "could not set scratchpad mode primitive attribute" ); |
3252 | } |
3253 | |
3254 | /// Sets scaling factors for primitive operations for a given memory |
3255 | /// argument. The scaling factors must be passed at execution time |
3256 | /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg. |
3257 | /// |
3258 | /// @sa dnnl_primitive_attr_set_scales_mask |
3259 | /// |
3260 | /// @param arg Parameter argument index as passed to the |
3261 | /// primitive::execute() call. |
3262 | /// @param mask Scaling factors correspondence mask that defines the |
3263 | /// correspondence between the tensor dimensions and the @p scales |
3264 | /// vector. The set i-th bit indicates that a dedicated scaling factor |
3265 | /// is used for each index along that dimension. Set the mask to 0 to |
3266 | /// use a common scaling factor for the whole output tensor. |
3267 | void set_scales_mask(int arg, int mask) { |
3268 | error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask), |
3269 | "could not set scales primitive attribute" ); |
3270 | } |
3271 | |
3272 | /// Sets zero points for primitive operations for a given memory argument. |
3273 | /// The zero points must be passed at execution time as an argument with |
3274 | /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg. |
3275 | /// |
3276 | /// @sa dnnl_primitive_attr_set_zero_points_mask |
3277 | /// |
3278 | /// @param arg Parameter argument index as passed to the |
3279 | /// primitive::execute() call. |
3280 | /// @param mask Zero point correspondence mask that defines the |
3281 | /// correspondence between the tensor dimensions and the @p |
3282 | /// zero_points vector. The set i-th bit indicates that a dedicated |
3283 | /// zero point is used for each index along that dimension. Set the |
3284 | /// mask to 0 to use a common zero point for the whole output tensor. |
3285 | void set_zero_points_mask(int arg, int mask) { |
3286 | error::wrap_c_api( |
3287 | dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask), |
3288 | "could not set zero points primitive attribute" ); |
3289 | } |
3290 | |
3291 | /// Returns post-ops previously set via set_post_ops(). |
3292 | /// |
3293 | /// @returns Post-ops. |
3294 | const post_ops get_post_ops() const { |
3295 | const_dnnl_post_ops_t const_c_post_ops; |
3296 | error::wrap_c_api( |
3297 | dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops), |
3298 | "could not get post-ops primitive attribute" ); |
3299 | dnnl_post_ops_t c_post_ops; |
3300 | error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops), |
3301 | "could not clone post-ops primitive attribute" ); |
3302 | return post_ops(c_post_ops); |
3303 | } |
3304 | |
3305 | /// Sets post-ops. |
3306 | /// |
3307 | /// @note |
3308 | /// There is no way to check whether the post-ops would be supported |
3309 | /// by the target primitive. Any error will be reported |
3310 | /// by the respective primitive descriptor constructor. |
3311 | /// |
3312 | /// @param ops Post-ops object to copy post-ops from. |
3313 | void set_post_ops(const post_ops ops) { |
3314 | error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()), |
3315 | "could not set post-ops primitive attribute" ); |
3316 | } |
3317 | |
3318 | /// Sets quantization scale and shift parameters for RNN data tensors. |
3319 | /// |
3320 | /// For performance reasons, the low-precision configuration of the RNN |
3321 | /// primitives expect input activations to have the unsigned 8-bit integer |
3322 | /// data type. The scale and shift parameters are used to quantize |
3323 | /// floating-point data to unsigned integer and must be passed to the RNN |
3324 | /// primitive using attributes. |
3325 | /// |
3326 | /// The quantization formula is `scale * data + shift`. |
3327 | /// |
3328 | /// Example usage: |
3329 | /// @code |
3330 | /// // RNN parameters |
3331 | /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; |
3332 | /// // Activations quantization parameters |
3333 | /// float scale = 63.f, shift = 64.f; |
3334 | /// |
3335 | /// primitive_attr attr; |
3336 | /// |
3337 | /// // Set scale and shift for int8 quantization of activation |
3338 | /// attr.set_rnn_data_qparams(scale, shift); |
3339 | /// |
3340 | /// // Create an RNN primitive descriptor. |
3341 | /// vanilla_rnn_forward::primitive_desc rnn_d( |
3342 | /// engine, /* arguments */, attr); |
3343 | /// @endcode |
3344 | /// |
3345 | /// @note |
3346 | /// Quantization scale and shift are common for src_layer, src_iter, |
3347 | /// dst_iter, and dst_layer. |
3348 | /// |
3349 | /// @param scale The value to scale the data by. |
3350 | /// @param shift The value to shift the data by. |
3351 | void set_rnn_data_qparams(float scale, float shift) { |
3352 | error::wrap_c_api( |
3353 | dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift), |
3354 | "could not set RNN data quantization parameters primitive " |
3355 | "attribute" ); |
3356 | } |
3357 | |
3358 | /// Returns the quantization scale and shift parameters for RNN data |
3359 | /// tensors. |
3360 | /// |
3361 | /// @note |
3362 | /// Quantization scale and shift are common for src_layer, src_iter, |
3363 | /// dst_iter, and dst_layer. |
3364 | /// |
3365 | /// @param scale The value to scale the data by. |
3366 | /// @param shift The value to shift the data by. |
3367 | void get_rnn_data_qparams(float &scale, float &shift) { |
3368 | float c_scale, c_shift; |
3369 | error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams( |
3370 | get(), &c_scale, &c_shift), |
3371 | "could not set RNN data quantization parameters primitive " |
3372 | "attribute" ); |
3373 | scale = c_scale; |
3374 | shift = c_shift; |
3375 | } |
3376 | |
3377 | /// Sets quantization scaling factors for RNN weights tensors. The |
3378 | /// low-precision configuration of the RNN primitives expect input weights |
3379 | /// to use the signed 8-bit integer data type. The scaling factors are |
3380 | /// used to quantize floating-point data to signed integer and must be |
3381 | /// passed to RNN primitives using attributes. |
3382 | /// |
3383 | /// @note |
3384 | /// The dimension order is always native and does not depend on the |
3385 | /// actual layout used. For example, five-dimensional weights always |
3386 | /// have (l, d, i, g, o) logical dimension ordering. |
3387 | /// |
3388 | /// @note |
3389 | /// Quantization scales are common for weights_layer and |
3390 | /// weights_iteration |
3391 | /// |
3392 | /// @param mask Scaling factors correspondence mask that defines the |
3393 | /// correspondence between the output tensor dimensions and the @p |
3394 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
3395 | /// factor should be used each index along that dimension. Set the |
3396 | /// mask to 0 to use a common scaling factor for the whole output |
3397 | /// tensor. |
3398 | /// @param scales Constant vector of output scaling factors. The following |
3399 | /// equality must hold: |
3400 | /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ |
3401 | /// Violations can only be detected when the attributes are used to |
3402 | /// create a primitive descriptor. |
3403 | void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) { |
3404 | error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(), |
3405 | (int)scales.size(), mask, scales.data()), |
3406 | "could not set RNN weights quantization parameters primitive " |
3407 | "attribute" ); |
3408 | } |
3409 | |
3410 | /// Returns the quantization scaling factors for RNN projection weights |
3411 | /// tensors. |
3412 | /// |
3413 | /// @note |
3414 | /// The dimension order is always native and does not depend on the |
3415 | /// actual layout used. For example, five-dimensional weights always |
3416 | /// have (l, d, i, g, o) logical dimension ordering. |
3417 | /// |
3418 | /// @param mask Scaling factors correspondence mask that defines the |
3419 | /// correspondence between the output tensor dimensions and the @p |
3420 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
3421 | /// factor should be used each index along that dimension. Set the |
3422 | /// mask to 0 to use a common scaling factor for the whole output |
3423 | /// tensor. |
3424 | /// @param scales Constant vector of output scaling factors. The following |
3425 | /// equality must hold: |
3426 | /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ |
3427 | /// Violations can only be detected when the attributes are used to |
3428 | /// create a primitive descriptor. |
3429 | void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) { |
3430 | dnnl_dim_t count; |
3431 | int c_mask; |
3432 | const float *c_scales; |
3433 | error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams( |
3434 | get(), &count, &c_mask, &c_scales), |
3435 | "could not get primitive RNN weights quantization " |
3436 | "parameters attributes" ); |
3437 | scales.resize(count); |
3438 | |
3439 | mask = c_mask; |
3440 | for (dnnl_dim_t c = 0; c < count; c++) |
3441 | scales[c] = c_scales[c]; |
3442 | } |
3443 | |
3444 | /// Sets quantization scaling factors for RNN projection weights tensors. |
3445 | // The low-precision configuration of the RNN primitives expect input |
3446 | // weights to use the signed 8-bit integer data type. The scaling factors |
3447 | // are used to quantize floating-point data to signed integer and must be |
3448 | /// passed to RNN primitives using attributes. |
3449 | /// |
3450 | /// @note |
3451 | /// The dimension order is always native and does not depend on the |
3452 | /// actual layout used. For example, five-dimensional weights always |
3453 | /// have (l, d, i, g, o) logical dimension ordering. |
3454 | /// |
3455 | /// @note |
3456 | /// Quantization scales are common for weights_layer and |
3457 | /// weights_iteration |
3458 | /// |
3459 | /// @param mask Scaling factors correspondence mask that defines the |
3460 | /// correspondence between the output tensor dimensions and the @p |
3461 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
3462 | /// factor should be used each index along that dimension. Set the |
3463 | /// mask to 0 to use a common scaling factor for the whole output |
3464 | /// tensor. |
3465 | /// @param scales Constant vector of output scaling factors. The following |
3466 | /// equality must hold: |
3467 | /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ |
3468 | /// Violations can only be detected when the attributes are used to |
3469 | /// create a primitive descriptor. |
3470 | void set_rnn_weights_projection_qparams( |
3471 | int mask, const std::vector<float> &scales) { |
3472 | error::wrap_c_api( |
3473 | dnnl_primitive_attr_set_rnn_weights_projection_qparams( |
3474 | get(), (int)scales.size(), mask, scales.data()), |
3475 | "could not set primitive RNN weights projection quantization " |
3476 | "parameters attributes" ); |
3477 | } |
3478 | |
3479 | /// Returns the quantization scaling factors for RNN projection weights |
3480 | /// tensors. |
3481 | /// |
3482 | /// @note |
3483 | /// The dimension order is always native and does not depend on the |
3484 | /// actual layout used. For example, five-dimensional weights always |
3485 | /// have (l, d, i, g, o) logical dimension ordering. |
3486 | /// |
3487 | /// @param mask Scaling factors correspondence mask that defines the |
3488 | /// correspondence between the output tensor dimensions and the @p |
3489 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
3490 | /// factor should be used each index along that dimension. Set the |
3491 | /// mask to 0 to use a common scaling factor for the whole output |
3492 | /// tensor. |
3493 | /// @param scales Constant vector of output scaling factors. The following |
3494 | /// equality must hold: |
3495 | /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ |
3496 | /// Violations can only be detected when the attributes are used to |
3497 | /// create a primitive descriptor. |
3498 | void get_rnn_weights_projection_qparams( |
3499 | int &mask, std::vector<float> &scales) { |
3500 | dnnl_dim_t count; |
3501 | int c_mask; |
3502 | const float *c_scales; |
3503 | error::wrap_c_api( |
3504 | dnnl_primitive_attr_get_rnn_weights_projection_qparams( |
3505 | get(), &count, &c_mask, &c_scales), |
3506 | "could not get primitive RNN weights projection quantization " |
3507 | "parameters attributes" ); |
3508 | scales.resize(count); |
3509 | |
3510 | mask = c_mask; |
3511 | for (dnnl_dim_t c = 0; c < count; c++) |
3512 | scales[c] = c_scales[c]; |
3513 | } |
3514 | }; |
3515 | |
3516 | /// @} dnnl_api_attributes |
3517 | |
3518 | /// @addtogroup dnnl_api_primitives_common |
3519 | /// @{ |
3520 | |
3521 | /// Base class for all primitive descriptors. |
3522 | struct primitive_desc_base : public handle<dnnl_primitive_desc_t> { |
// Reuse the constructors of handle<dnnl_primitive_desc_t>, such as the one
// that takes ownership of an existing C API handle.
using handle<dnnl_primitive_desc_t>::handle;

/// Default constructor. Produces an empty object.
primitive_desc_base() = default;
3527 | |
3528 | /// Returns the engine of the primitive descriptor. |
3529 | /// @returns The engine of the primitive descriptor. |
3530 | engine get_engine() const { return query_engine(query::engine); } |
3531 | |
3532 | /// Returns implementation name. |
3533 | /// @returns The implementation name. |
3534 | const char *impl_info_str() const { |
3535 | const char *res; |
3536 | error::wrap_c_api(dnnl_primitive_desc_query( |
3537 | get(), dnnl_query_impl_info_str, 0, &res), |
3538 | "could not retrieve implementation info string from a " |
3539 | "primitive descriptor" ); |
3540 | return res; |
3541 | } |
3542 | |
3543 | /// Returns a memory::dim value (same as int64_t). |
3544 | /// @param what The value to query. |
3545 | /// @returns The result of the query. |
3546 | memory::dim query_s64(query what) const { |
3547 | memory::dim res; |
3548 | dnnl_status_t status = dnnl_primitive_desc_query( |
3549 | get(), dnnl::convert_to_c(what), 0, &res); |
3550 | return status == dnnl_success ? res : 0; |
3551 | } |
3552 | |
3553 | /// Returns strides. |
3554 | /// @returns Strides. |
3555 | /// @returns An empty #dnnl::memory::dims if the primitive does not have |
3556 | /// a strides parameter. |
3557 | memory::dims get_strides() const { return query_dims(query::strides); } |
3558 | |
3559 | /// Returns dilations. |
3560 | /// @returns Dilations. |
3561 | /// @returns An empty #dnnl::memory::dims if the primitive does not have |
3562 | /// a dilations parameter. |
3563 | memory::dims get_dilations() const { return query_dims(query::dilations); } |
3564 | |
3565 | /// Returns a left padding. |
3566 | /// @returns A left padding. |
3567 | /// @returns An empty #dnnl::memory::dims if the primitive does not have |
3568 | /// a left padding parameter. |
3569 | memory::dims get_padding_l() const { return query_dims(query::padding_l); } |
3570 | |
3571 | /// Returns a right padding. |
3572 | /// @returns A right padding. |
3573 | /// @returns An empty #dnnl::memory::dims if the primitive does not have |
3574 | /// a right padding parameter. |
3575 | memory::dims get_padding_r() const { return query_dims(query::padding_r); } |
3576 | |
3577 | /// Returns an epsilon. |
3578 | /// @returns An epsilon. |
3579 | /// @returns Zero if the primitive does not have an epsilon parameter. |
3580 | float get_epsilon() const { return query_f32(query::epsilon_f32); } |
3581 | |
3582 | /// Returns flags. |
3583 | /// @tparam T Flags enumeration type. |
3584 | /// @returns Flags. |
3585 | /// @returns Zero if the primitive does not have a flags parameter. |
3586 | template <typename T = unsigned> |
3587 | T get_flags() const { |
3588 | unsigned res; |
3589 | dnnl_status_t status |
3590 | = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res); |
3591 | return static_cast<T>(status == dnnl_success ? res : 0x0U); |
3592 | } |
3593 | |
3594 | /// Returns an algorithm kind. |
3595 | /// @returns An algorithm kind. |
3596 | /// @returns #dnnl::algorithm::undef if the primitive does not have an |
3597 | /// algorithm parameter. |
3598 | dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); } |
3599 | |
3600 | /// Returns an alpha. |
3601 | /// @returns An alpha. |
3602 | /// @returns Zero if the primitive does not have an alpha parameter. |
3603 | float get_alpha() const { return query_f32(query::alpha_f32); } |
3604 | |
3605 | /// Returns a beta. |
3606 | /// @returns A beta. |
3607 | /// @returns Zero if the primitive does not have a beta parameter. |
3608 | float get_beta() const { return query_f32(query::beta_f32); } |
3609 | |
3610 | /// Returns an axis. |
3611 | /// @returns An axis. |
3612 | /// @returns A negative number if the primitive does not have an axis |
3613 | /// parameter. |
3614 | int get_axis() const { |
3615 | int res; |
3616 | dnnl_status_t status = dnnl_primitive_desc_query( |
3617 | get(), dnnl_query_axis_s32, 0, &res); |
3618 | return status == dnnl_success ? res : -1; |
3619 | } |
3620 | |
3621 | /// Returns an LRN local size parameter. |
3622 | /// @returns An LRN local size parameter. |
3623 | /// @returns Zero if the primitive does not have an LRN local size |
3624 | /// parameter. |
3625 | memory::dim get_local_size() const { |
3626 | return query_s64(query::local_size_s64); |
3627 | } |
3628 | |
3629 | /// Returns an LRN K parameter. |
3630 | /// @returns An LRN K parameter. |
3631 | /// @returns Zero if the primitive does not have an LRN K parameter. |
3632 | float get_k() const { return query_f32(query::k_f32); } |
3633 | |
3634 | /// Returns a reduction P parameter. |
3635 | /// @returns A reduction P parameter. |
3636 | /// @returns Zero if the primitive does not have a reduction P parameter. |
3637 | float get_p() const { return query_f32(query::p_f32); } |
3638 | |
3639 | /// Returns a resampling factors parameters. |
3640 | /// @returns A vector of factors. |
3641 | /// @returns An empty vector if the primitive does not have a resampling |
3642 | /// factors parameter. |
3643 | std::vector<float> get_factors() const { |
3644 | float *factors; |
3645 | dnnl_status_t status = dnnl_primitive_desc_query( |
3646 | get(), dnnl_query_factors, 0, &factors); |
3647 | |
3648 | const bool is_backward = get_prop_kind() != prop_kind::forward_training |
3649 | && get_prop_kind() != prop_kind::forward_inference; |
3650 | const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(), |
3651 | is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0); |
3652 | |
3653 | int ndims; |
3654 | error::wrap_c_api( |
3655 | dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims), |
3656 | "could not query ndims from a memory descriptor" ); |
3657 | |
3658 | return status == dnnl_success |
3659 | ? std::vector<float>(factors, factors + (ndims - 2)) |
3660 | : std::vector<float> {}; |
3661 | } |
3662 | |
3663 | /// Returns an RNN cell kind parameter. |
3664 | /// @returns An RNN cell kind parameter. |
3665 | /// @returns #dnnl::algorithm::undef if the primitive does not have an |
3666 | /// RNN cell kind parameter. |
3667 | dnnl::algorithm get_cell_kind() const { |
3668 | return query_alg(query::cell_kind); |
3669 | } |
3670 | |
3671 | /// Returns an RNN direction parameter. |
3672 | /// @returns An RNN direction parameter. |
3673 | /// @returns #dnnl::rnn_direction::undef if the primitive does not have |
3674 | /// an RNN direction parameter. |
3675 | dnnl::rnn_direction get_direction() const { |
3676 | dnnl_rnn_direction_t direction; |
3677 | dnnl_status_t status = dnnl_primitive_desc_query( |
3678 | get(), dnnl_query_direction, 0, &direction); |
3679 | return status == dnnl_success |
3680 | ? static_cast<dnnl::rnn_direction>(direction) |
3681 | : dnnl::rnn_direction::undef; |
3682 | } |
3683 | |
3684 | /// Returns an RNN activation kind parameter. |
3685 | /// @returns An RNN activation kind parameter. |
3686 | /// @returns #dnnl::algorithm::undef if the primitive does not have an |
3687 | /// RNN activation kind parameter. |
3688 | dnnl::algorithm get_activation_kind() const { |
3689 | return query_alg(query::activation_kind); |
3690 | } |
3691 | |
3692 | /// Returns a pooling kernel parameter. |
3693 | /// @returns A pooling kernel parameter. |
3694 | /// @returns An empty #dnnl::memory::dims if the primitive does not have |
3695 | /// a pooling kernel parameter. |
3696 | memory::dims get_kernel() const { return query_dims(query::kernel); } |
3697 | |
3698 | /// Returns a shuffle group size parameter. |
3699 | /// @returns A shuffle group size parameter. |
3700 | /// @returns Zero if the primitive does not have a shuffle group size |
3701 | /// parameter. |
3702 | memory::dim get_group_size() const { |
3703 | return query_s64(query::group_size_s64); |
3704 | } |
3705 | |
3706 | /// Returns a propagation kind. |
3707 | /// @returns A propagation kind. |
3708 | /// @returns #dnnl::prop_kind::undef if the primitive does not have |
3709 | /// a propagation parameter. |
3710 | dnnl::prop_kind get_prop_kind() const { |
3711 | dnnl_prop_kind_t prop_kind; |
3712 | dnnl_status_t status = dnnl_primitive_desc_query( |
3713 | get(), dnnl_query_prop_kind, 0, &prop_kind); |
3714 | return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind) |
3715 | : dnnl::prop_kind::undef; |
3716 | } |
3717 | |
3718 | /// Returns a memory descriptor. |
3719 | /// |
3720 | /// @note |
3721 | /// There are also convenience methods |
3722 | /// #dnnl::primitive_desc_base::src_desc(), |
3723 | /// #dnnl::primitive_desc_base::dst_desc(), and others. |
3724 | /// |
3725 | /// @param what The kind of parameter to query; can be |
3726 | /// #dnnl::query::src_md, #dnnl::query::dst_md, etc. |
3727 | /// @param idx Index of the parameter. For example, convolution bias can |
3728 | /// be queried with what = #dnnl::query::weights_md and idx = 1. |
3729 | /// @returns The requested memory descriptor. |
3730 | /// @returns A zero memory descriptor if the primitive does not have a |
3731 | /// parameter of the specified kind or index. |
3732 | memory::desc query_md(query what, int idx = 0) const { |
3733 | std::vector<query> valid_q {query::src_md, query::diff_src_md, |
3734 | query::weights_md, query::diff_weights_md, query::dst_md, |
3735 | query::diff_dst_md, query::workspace_md, query::scratchpad_md, |
3736 | query::exec_arg_md}; |
3737 | if (!std::any_of(valid_q.cbegin(), valid_q.cend(), |
3738 | [=](query q) { return what == q; })) |
3739 | DNNL_THROW_ERROR(dnnl_invalid_arguments, |
3740 | "memory descriptor query is invalid" ); |
3741 | |
3742 | const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( |
3743 | get(), dnnl::convert_to_c(what), idx); |
3744 | if (!cdesc) return memory::desc(); |
3745 | |
3746 | dnnl_memory_desc_t cloned_md = nullptr; |
3747 | error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), |
3748 | "could not clone a memory descriptor" ); |
3749 | |
3750 | return memory::desc(cloned_md); |
3751 | } |
3752 | |
3753 | /// Returns a source memory descriptor. |
3754 | /// @param idx Source index. |
3755 | /// @returns Source memory descriptor. |
3756 | /// @returns A zero memory descriptor if the primitive does not have a |
3757 | /// source parameter with index @p idx. |
3758 | memory::desc src_desc(int idx) const { |
3759 | return query_md(query::src_md, idx); |
3760 | } |
3761 | |
3762 | /// Returns a destination memory descriptor. |
3763 | /// @param idx Destination index. |
3764 | /// @returns Destination memory descriptor. |
3765 | /// @returns A zero memory descriptor if the primitive does not have a |
3766 | /// destination parameter with index @p idx. |
3767 | memory::desc dst_desc(int idx) const { |
3768 | return query_md(query::dst_md, idx); |
3769 | } |
3770 | |
3771 | /// Returns a weights memory descriptor. |
3772 | /// @param idx Weights index. |
3773 | /// @returns Weights memory descriptor. |
3774 | /// @returns A zero memory descriptor if the primitive does not have a |
3775 | /// weights parameter with index @p idx. |
3776 | memory::desc weights_desc(int idx) const { |
3777 | return query_md(query::weights_md, idx); |
3778 | } |
3779 | |
3780 | /// Returns a diff source memory descriptor. |
3781 | /// @param idx Diff source index. |
3782 | /// @returns Diff source memory descriptor. |
3783 | /// @returns A zero memory descriptor if the primitive does not have a |
3784 | /// diff source parameter with index @p idx. |
3785 | memory::desc diff_src_desc(int idx) const { |
3786 | return query_md(query::diff_src_md, idx); |
3787 | } |
3788 | |
3789 | /// Returns a diff destination memory descriptor. |
3790 | /// @param idx Diff destination index. |
3791 | /// @returns Diff destination memory descriptor. |
3792 | /// @returns A zero memory descriptor if the primitive does not have a |
3793 | /// diff destination parameter with index @p idx. |
3794 | memory::desc diff_dst_desc(int idx) const { |
3795 | return query_md(query::diff_dst_md, idx); |
3796 | } |
3797 | |
3798 | /// Returns a diff weights memory descriptor. |
3799 | /// @param idx Diff weights index. |
3800 | /// @returns Diff weights memory descriptor. |
3801 | /// @returns A zero memory descriptor if the primitive does not have a |
3802 | /// diff weights parameter with index @p idx. |
3803 | memory::desc diff_weights_desc(int idx) const { |
3804 | return query_md(query::diff_weights_md, idx); |
3805 | } |
3806 | |
3807 | // Separate versions without the index argument for documentation |
3808 | // purposes. |
3809 | |
3810 | /// Returns a source memory descriptor. |
3811 | /// @returns Source memory descriptor. |
3812 | /// @returns A zero memory descriptor if the primitive does not have a |
3813 | /// source parameter. |
3814 | memory::desc src_desc() const { return src_desc(0); } |
3815 | |
3816 | /// Returns a destination memory descriptor. |
3817 | /// @returns Destination memory descriptor. |
3818 | /// @returns A zero memory descriptor if the primitive does not have a |
3819 | /// destination parameter. |
3820 | memory::desc dst_desc() const { return dst_desc(0); } |
3821 | |
3822 | /// Returns a weights memory descriptor. |
3823 | /// @returns Weights memory descriptor. |
3824 | /// @returns A zero memory descriptor if the primitive does not have a |
3825 | /// weights parameter. |
3826 | memory::desc weights_desc() const { return weights_desc(0); } |
3827 | |
3828 | /// Returns a diff source memory descriptor. |
3829 | /// @returns Diff source memory descriptor. |
3830 | /// @returns A zero memory descriptor if the primitive does not have a |
3831 | /// diff source memory with. |
3832 | memory::desc diff_src_desc() const { return diff_src_desc(0); } |
3833 | |
3834 | /// Returns a diff destination memory descriptor. |
3835 | /// @returns Diff destination memory descriptor. |
3836 | /// @returns A zero memory descriptor if the primitive does not have a |
3837 | /// diff destination parameter. |
3838 | memory::desc diff_dst_desc() const { return diff_dst_desc(0); } |
3839 | |
3840 | /// Returns a diff weights memory descriptor. |
3841 | /// @returns Diff weights memory descriptor. |
3842 | /// @returns A zero memory descriptor if the primitive does not have a |
3843 | /// diff weights parameter. |
3844 | memory::desc diff_weights_desc() const { return diff_weights_desc(0); } |
3845 | |
3846 | /// Returns the workspace memory descriptor. |
3847 | /// @returns Workspace memory descriptor. |
3848 | /// @returns A zero memory descriptor if the primitive does not require |
3849 | /// workspace parameter. |
3850 | memory::desc workspace_desc() const { |
3851 | return query_md(query::workspace_md, 0); |
3852 | } |
3853 | |
3854 | /// Returns the scratchpad memory descriptor. |
3855 | /// @returns scratchpad memory descriptor. |
3856 | /// @returns A zero memory descriptor if the primitive does not require |
3857 | /// scratchpad parameter. |
3858 | /// @sa @ref dev_guide_attributes_scratchpad |
3859 | memory::desc scratchpad_desc() const { |
3860 | return query_md(query::scratchpad_md, 0); |
3861 | } |
3862 | |
3863 | /// Returns the engine on which the scratchpad memory is located. |
3864 | /// @returns The engine on which the scratchpad memory is located. |
3865 | engine scratchpad_engine() const { |
3866 | dnnl_engine_t c_engine; |
3867 | error::wrap_c_api(dnnl_primitive_desc_query(get(), |
3868 | dnnl::convert_to_c(query::scratchpad_engine), |
3869 | 0, &c_engine), |
3870 | "could not retrieve scratchpad engine from a primitive " |
3871 | "descriptor" ); |
3872 | return engine(c_engine, true); |
3873 | } |
3874 | |
3875 | /// Returns the primitive attributes. |
3876 | /// @returns The primitive attributes. |
3877 | primitive_attr get_primitive_attr() const { |
3878 | const_dnnl_primitive_attr_t const_c_attr; |
3879 | error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr), |
3880 | "could not get attributes from a primitive descriptor" ); |
3881 | dnnl_primitive_attr_t c_attr; |
3882 | error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr), |
3883 | "could not clone primitive attributes" ); |
3884 | return primitive_attr(c_attr); |
3885 | } |
3886 | |
3887 | /// Returns the kind of the primitive descriptor. |
3888 | /// @returns The kind of the primitive descriptor. |
3889 | dnnl::primitive::kind get_kind() const { |
3890 | dnnl_primitive_kind_t kind; |
3891 | error::wrap_c_api(dnnl_primitive_desc_query(get(), |
3892 | dnnl_query_primitive_kind, 0, (void *)&kind), |
3893 | "could not get primitive kind from a primitive descriptor" ); |
3894 | return static_cast<dnnl::primitive::kind>(kind); |
3895 | } |
3896 | |
3897 | /// Returns the cache blob ID of the primitive descriptor. |
3898 | /// @returns The cache blob ID of the primitive descriptor. |
3899 | std::vector<uint8_t> get_cache_blob_id() const { |
3900 | dnnl_dim_t count; |
3901 | const uint8_t *c_id; |
3902 | error::wrap_c_api( |
3903 | dnnl_primitive_desc_query(get(), |
3904 | dnnl::convert_to_c(query::cache_blob_id_size_s64), 0, |
3905 | (void *)&count), |
3906 | "could not get size of cache blob ID from a primitive " |
3907 | "descriptor" ); |
3908 | error::wrap_c_api(dnnl_primitive_desc_query(get(), |
3909 | dnnl::convert_to_c(query::cache_blob_id), 0, |
3910 | (void **)&c_id), |
3911 | "could not get cache blob ID from a primitive descriptor" ); |
3912 | std::vector<uint8_t> id(c_id, c_id + count); |
3913 | return id; |
3914 | } |
3915 | |
3916 | protected: |
3917 | /// Returns a float value. |
3918 | /// @param what The value to query. |
3919 | /// @returns The result of the query. |
3920 | /// @returns Zero if the primitive doesn't support the query. |
3921 | float query_f32(query what) const { |
3922 | float res; |
3923 | dnnl_status_t status = dnnl_primitive_desc_query( |
3924 | get(), dnnl::convert_to_c(what), 0, &res); |
3925 | return status == dnnl_success ? res : 0.0f; |
3926 | } |
3927 | |
3928 | /// Returns an #dnnl::algorithm value. |
3929 | /// @param what The value to query. |
3930 | /// @returns The result of the query. |
3931 | /// @returns #dnnl::algorithm::undef if the primitive doesn't support |
3932 | /// the query. |
3933 | algorithm query_alg(query what) const { |
3934 | dnnl_alg_kind_t res; |
3935 | dnnl_status_t status = dnnl_primitive_desc_query( |
3936 | get(), dnnl::convert_to_c(what), 0, &res); |
3937 | return status == dnnl_success ? static_cast<dnnl::algorithm>(res) |
3938 | : algorithm::undef; |
3939 | } |
3940 | |
3941 | /// Returns a memory::dims value. |
3942 | /// @param what The value to query. |
3943 | /// @returns The result of the query. |
3944 | /// @returns An empty #dnnl::memory::dims if the primitive doesn't support |
3945 | /// the query. |
3946 | memory::dims query_dims(query what) const { |
3947 | const bool is_backward = get_prop_kind() != prop_kind::forward_training |
3948 | && get_prop_kind() != prop_kind::forward_inference; |
3949 | const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(), |
3950 | is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0); |
3951 | |
3952 | int nspatial_dims = 0; |
3953 | if (md) { |
3954 | int ndims; |
3955 | error::wrap_c_api( |
3956 | dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims), |
3957 | "could not query ndims from a memory descriptor" ); |
3958 | nspatial_dims = ndims - 2; |
3959 | } |
3960 | |
3961 | dnnl_dims_t *c_dims; |
3962 | dnnl_status_t status = dnnl_primitive_desc_query( |
3963 | get(), dnnl::convert_to_c(what), 0, &c_dims); |
3964 | return status == dnnl_success |
3965 | ? memory::dims(*c_dims, *c_dims + nspatial_dims) |
3966 | : memory::dims {}; |
3967 | } |
3968 | |
3969 | /// Returns an #dnnl::engine value. |
3970 | /// @param what The value to query. |
3971 | /// @returns The result of the query. |
3972 | /// @returns A weak handle to the engine that the primitive descriptor was |
3973 | /// created with. |
3974 | engine query_engine(query what) const { |
3975 | dnnl_engine_t c_engine; |
3976 | error::wrap_c_api(dnnl_primitive_desc_query(get(), |
3977 | dnnl::convert_to_c(what), 0, &c_engine), |
3978 | "could not get an engine from a primitive_desc" ); |
3979 | return engine(c_engine, true); |
3980 | } |
3981 | |
3982 | /// Resets the value of the handle to a clone of a C API primitive |
3983 | /// descriptor. |
3984 | /// @param pd A C API primitive descriptor to clone. |
3985 | void reset_with_clone(const_dnnl_primitive_desc_t pd) { |
3986 | dnnl_primitive_desc_t new_pd; |
3987 | error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd), |
3988 | "could not clone a primitive descriptor" ); |
3989 | reset(new_pd); |
3990 | } |
3991 | |
3992 | /// Constructs a primitive descriptor base object from a clone of a C API |
3993 | /// primitive descriptor after verifying that it is what the caller |
3994 | /// expects. |
3995 | /// |
3996 | /// @note |
3997 | /// The @p prim_kind should map to a primitive that does not have |
3998 | /// different values of propagation kind (e.g. #dnnl::binary). |
3999 | /// @note |
4000 | /// Primitive descriptor base constructed this way does not support |
4001 | /// next_impl() (will throw). |
4002 | /// |
4003 | /// @param pd C API primitive descriptor to clone. |
4004 | /// @param prim_kind Expected primitive kind. |
4005 | primitive_desc_base( |
4006 | dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind) |
4007 | : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {} |
4008 | |
4009 | /// Constructs a primitive descriptor base object from a clone of a C API |
4010 | /// primitive descriptor after verifying that it is what the caller |
4011 | /// expects. |
4012 | /// |
4013 | /// @note |
4014 | /// Primitive descriptor base constructed this way does not support |
4015 | /// next_impl() (will throw). |
4016 | /// |
4017 | /// @param pd C API primitive descriptor to clone. |
4018 | /// @param prim_kind Expected primitive kind. |
4019 | /// @param aprop_kind Expected propagation kind. |
4020 | primitive_desc_base(dnnl_primitive_desc_t pd, |
4021 | dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind) |
4022 | : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {} |
4023 | |
4024 | /// Constructs a primitive descriptor base object from a clone of a C API |
4025 | /// primitive descriptor after verifying that it is what the caller |
4026 | /// expects. |
4027 | /// |
4028 | /// @note |
4029 | /// Primitive descriptor base constructed this way does not support |
4030 | /// next_impl() (will throw). |
4031 | /// |
4032 | /// @param pd C API primitive descriptor to clone. |
4033 | /// @param prim_kind Expected primitive kind. |
4034 | /// @param prop_kind1 Expected propagation kind (option 1). |
4035 | /// @param prop_kind2 Expected propagation kind (option 2). This value is |
4036 | /// checked if the check with @p prop_kind1 fails. |
4037 | primitive_desc_base(dnnl_primitive_desc_t pd, |
4038 | dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, |
4039 | dnnl::prop_kind prop_kind2) { |
4040 | // It is OK to pass an empty primitive descriptor |
4041 | if (pd == nullptr) return; |
4042 | |
4043 | dnnl_status_t rc; |
4044 | |
4045 | dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind); |
4046 | dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1); |
4047 | dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2); |
4048 | |
4049 | // Check that primitive kind matches |
4050 | dnnl_primitive_kind_t pd_kind; |
4051 | rc = dnnl_primitive_desc_query( |
4052 | pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind); |
4053 | error::wrap_c_api( |
4054 | rc, "could not get primitive kind from a primitive descriptor" ); |
4055 | if (pd_kind != c_prim_kind) |
4056 | DNNL_THROW_ERROR(dnnl_invalid_arguments, |
4057 | "primitive descriptor operation kind mismatch" ); |
4058 | |
4059 | // Check that propagation kind matches |
4060 | dnnl_prop_kind_t pd_prop_kind; |
4061 | rc = dnnl_primitive_desc_query( |
4062 | pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind); |
4063 | |
4064 | // Something went wrong |
4065 | if (rc != dnnl_success && rc != dnnl_unimplemented) |
4066 | DNNL_THROW_ERROR(dnnl_invalid_arguments, |
4067 | "could not get propagation kind from the primitive " |
4068 | "descriptor" ); |
4069 | |
4070 | // Everything is fine |
4071 | if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef) |
4072 | || (rc == dnnl_success |
4073 | && (pd_prop_kind == c_prop_kind1 |
4074 | || pd_prop_kind == c_prop_kind2))) { |
4075 | reset_with_clone(pd); |
4076 | return; |
4077 | } |
4078 | |
4079 | // We could get the propagation kind but there is a mismatch |
4080 | DNNL_THROW_ERROR(dnnl_invalid_arguments, |
4081 | "primitive descriptor propagation kind mismatch" ); |
4082 | } |
4083 | |
4084 | /// Returns a constant reference to a static instance of default constructed |
4085 | /// primitive attributes |
4086 | static const primitive_attr &default_attr() { |
4087 | static const primitive_attr attr; |
4088 | return attr; |
4089 | } |
4090 | |
4091 | const_dnnl_memory_desc_t optional_arg(const memory::desc *md) { |
4092 | return md ? md->get() : nullptr; |
4093 | } |
4094 | |
4095 | const dnnl_dim_t *optional_arg(const memory::dims *dims) { |
4096 | return dims ? dims->data() : nullptr; |
4097 | } |
4098 | |
4099 | const float *optional_arg(const std::vector<float> *arg) { |
4100 | return arg ? arg->data() : nullptr; |
4101 | } |
4102 | |
4103 | using base = primitive_desc_base; |
4104 | }; |
4105 | |
4106 | /// @} dnnl_api_primitives_common |
4107 | |
4108 | /// @addtogroup dnnl_api_reorder Reorder |
4109 | /// |
4110 | /// A primitive to copy data between two memory objects. This primitive is |
4111 | /// typically used to change the way the data is laid out in memory. |
4112 | /// |
4113 | /// @sa @ref dev_guide_reorder in developer guide |
4114 | /// |
4115 | /// @{ |
4116 | |
4117 | /// Reorder primitive. |
4118 | struct reorder : public primitive { |
4119 | /// Primitive descriptor for a reorder primitive. |
4120 | struct primitive_desc : public primitive_desc_base { |
4121 | using primitive_desc_base::primitive_desc_base; |
4122 | |
4123 | /// Default constructor. Produces an empty object. |
4124 | primitive_desc() = default; |
4125 | |
4126 | /// Constructs a primitive descriptor for reorder primitive. |
4127 | /// |
4128 | /// @note |
4129 | /// If @p allow_empty is true, the constructor does not throw if a |
4130 | /// primitive descriptor cannot be created. |
4131 | /// |
4132 | /// @param src_engine Engine on which the source memory object will be |
4133 | /// located. |
4134 | /// @param src_md Source memory descriptor. |
4135 | /// @param dst_engine Engine on which the destination memory object |
4136 | /// will be located. |
4137 | /// @param dst_md Destination memory descriptor. |
4138 | /// @param attr Primitive attributes to use. Attributes are optional |
4139 | /// and default to empty attributes. |
4140 | /// @param allow_empty A flag signifying whether construction is allowed |
4141 | /// to fail without throwing an exception. In this case an empty |
4142 | /// object will be produced. This flag is optional and defaults to |
4143 | /// false. |
4144 | primitive_desc(const engine &src_engine, const memory::desc &src_md, |
4145 | const engine &dst_engine, const memory::desc &dst_md, |
4146 | const primitive_attr &attr = default_attr(), |
4147 | bool allow_empty = false) { |
4148 | dnnl_primitive_desc_t result; |
4149 | dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result, |
4150 | src_md.get(), src_engine.get(), dst_md.get(), |
4151 | dst_engine.get(), attr.get()); |
4152 | if (!allow_empty) |
4153 | error::wrap_c_api(status, |
4154 | "could not create a primitive descriptor for a reorder " |
4155 | "primitive" ); |
4156 | reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); |
4157 | } |
4158 | |
4159 | /// Constructs a primitive descriptor for reorder primitive. |
4160 | /// |
4161 | /// @param src Source memory object. It is used to obtain the source |
4162 | /// memory descriptor and engine. |
4163 | /// @param dst Destination memory object. It is used to obtain the |
4164 | /// destination memory descriptor and engine. |
4165 | /// @param attr Primitive attributes to use. Attributes are optional |
4166 | /// and default to empty attributes. |
4167 | /// @param allow_empty A flag signifying whether construction is allowed |
4168 | /// to fail without throwing an exception. In this case an empty |
4169 | /// object will be produced. This flag is optional and defaults to |
4170 | /// false. |
4171 | primitive_desc(const memory &src, const memory &dst, |
4172 | const primitive_attr &attr = default_attr(), |
4173 | bool allow_empty = false) { |
4174 | dnnl_primitive_desc_t result; |
4175 | auto src_md = src.get_desc(); |
4176 | auto dst_md = dst.get_desc(); |
4177 | dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result, |
4178 | src_md.get(), src.get_engine().get(), dst_md.get(), |
4179 | dst.get_engine().get(), attr.get()); |
4180 | if (!allow_empty) |
4181 | error::wrap_c_api(status, |
4182 | "could not create a primitive descriptor for a reorder " |
4183 | "primitive" ); |
4184 | reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); |
4185 | } |
4186 | |
4187 | /// Constructs a primitive descriptor for reorder primitive from a C |
4188 | /// API primitive descriptor which must have a matching kind. |
4189 | /// |
4190 | /// @param pd C API primitive descriptor for reorder primitive. |
4191 | primitive_desc(dnnl_primitive_desc_t pd) |
4192 | : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {} |
4193 | |
4194 | /// Returns the engine on which the source memory is allocated. |
4195 | /// @returns The engine on which the source memory is allocated. |
4196 | engine get_src_engine() const { |
4197 | return query_engine(dnnl::query::reorder_src_engine); |
4198 | } |
4199 | |
4200 | /// Returns the engine on which the destination memory is allocated. |
4201 | /// @returns The engine on which the destination memory is allocated. |
4202 | engine get_dst_engine() const { |
4203 | return query_engine(dnnl::query::reorder_dst_engine); |
4204 | } |
4205 | |
4206 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
4207 | memory::desc src_desc() const { return base::src_desc(0); } |
4208 | |
4209 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
4210 | memory::desc dst_desc() const { return base::dst_desc(0); } |
4211 | }; |
4212 | |
4213 | /// Default constructor. Produces an empty object. |
4214 | reorder() = default; |
4215 | |
4216 | /// Constructs a reorder primitive. |
4217 | /// @param pd Primitive descriptor for reorder primitive. |
4218 | reorder(const primitive_desc &pd) : primitive(pd.get()) {} |
4219 | |
4220 | /// Constructs a reorder primitive from a cache blob. |
4221 | /// @param pd Primitive descriptor for reorder primitive. |
4222 | /// @param cache_blob Cache blob. |
4223 | reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
4224 | : primitive(pd.get(), cache_blob) {} |
4225 | |
4226 | /// Constructs a reorder primitive that would reorder data between memory |
4227 | /// objects having the same memory descriptors as memory objects @p src and |
4228 | /// @p dst. |
4229 | /// |
4230 | /// @param src Source memory object. |
4231 | /// @param dst Destination memory object. |
4232 | /// @param attr Primitive attributes to use (optional). |
4233 | reorder(const memory &src, const memory &dst, |
4234 | const primitive_attr &attr = primitive_attr()) |
4235 | : primitive(primitive_desc(src, dst, attr).get()) {} |
4236 | |
4237 | using primitive::execute; |
4238 | |
4239 | /// Executes the reorder primitive. |
4240 | /// |
4241 | /// @param astream Stream object. The stream must belong to the same engine |
4242 | /// as the primitive. |
4243 | /// @param src Source memory object. |
4244 | /// @param dst Destination memory object. |
4245 | void execute(const stream &astream, memory &src, memory &dst) const { |
4246 | primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}}); |
4247 | } |
4248 | }; |
4249 | |
4250 | /// @} dnnl_api_reorder |
4251 | |
4252 | /// @addtogroup dnnl_api_concat Concat |
4253 | /// |
4254 | /// A primitive to concatenate data by arbitrary dimension. |
4255 | /// |
4256 | /// @sa @ref dev_guide_concat in developer guide |
4257 | /// |
4258 | /// @{ |
4259 | |
4260 | /// @cond DO_NOT_DOCUMENT_THIS |
4261 | inline std::vector<const_dnnl_memory_desc_t> convert_to_c( |
4262 | const std::vector<memory::desc> &mds) { |
4263 | std::vector<const_dnnl_memory_desc_t> c_mds; |
4264 | c_mds.reserve(mds.size()); |
4265 | for (const auto &md : mds) |
4266 | c_mds.push_back(md.get()); |
4267 | return c_mds; |
4268 | } |
4269 | /// @endcond |
4270 | |
4271 | /// Tensor concatenation (concat) primitive. |
4272 | struct concat : public primitive { |
4273 | /// Primitive descriptor for a concat primitive. |
4274 | struct primitive_desc : public primitive_desc_base { |
4275 | using primitive_desc_base::primitive_desc_base; |
4276 | |
4277 | /// Default constructor. Produces an empty object. |
4278 | primitive_desc() = default; |
4279 | |
4280 | /// Constructs a primitive descriptor for an out-of-place concatenation |
4281 | /// primitive. |
4282 | /// |
4283 | /// @param aengine Engine to perform the operation on. |
4284 | /// @param dst Destination memory descriptor. |
4285 | /// @param concat_dimension Source tensors will be concatenated over |
4286 | /// dimension with this index. Note that order of dimensions does |
4287 | /// not depend on memory format. |
4288 | /// @param srcs Vector of source memory descriptors. |
4289 | /// @param attr Primitive attributes to use. Attributes are optional |
4290 | /// and default to empty attributes. |
4291 | primitive_desc(const engine &aengine, const memory::desc &dst, |
4292 | int concat_dimension, const std::vector<memory::desc> &srcs, |
4293 | const primitive_attr &attr = default_attr()) { |
4294 | auto c_srcs = convert_to_c(srcs); |
4295 | |
4296 | dnnl_primitive_desc_t result; |
4297 | error::wrap_c_api( |
4298 | dnnl_concat_primitive_desc_create(&result, aengine.get(), |
4299 | dst.get(), (int)c_srcs.size(), concat_dimension, |
4300 | c_srcs.data(), attr.get()), |
4301 | "could not create a primitive descriptor for a concat " |
4302 | "primitive" ); |
4303 | reset(result); |
4304 | } |
4305 | |
4306 | /// Constructs a primitive descriptor for an out-of-place concatenation |
4307 | /// primitive. |
4308 | /// |
4309 | /// This version derives the destination memory descriptor |
4310 | /// automatically. |
4311 | /// |
4312 | /// @param aengine Engine to perform the operation on. |
4313 | /// @param concat_dimension Source tensors will be concatenated over |
4314 | /// dimension with this index. Note that order of dimensions does |
4315 | /// not depend on memory format. |
4316 | /// @param srcs Vector of source memory descriptors. |
4317 | /// @param attr Primitive attributes to use. Attributes are optional |
4318 | /// and default to empty attributes. |
4319 | primitive_desc(const engine &aengine, int concat_dimension, |
4320 | const std::vector<memory::desc> &srcs, |
4321 | const primitive_attr &attr = default_attr()) { |
4322 | auto c_api_srcs = convert_to_c(srcs); |
4323 | |
4324 | dnnl_primitive_desc_t result; |
4325 | error::wrap_c_api( |
4326 | dnnl_concat_primitive_desc_create(&result, aengine.get(), |
4327 | nullptr, (int)c_api_srcs.size(), concat_dimension, |
4328 | c_api_srcs.data(), attr.get()), |
4329 | "could not create a primitive descriptor for a concat " |
4330 | "primitive" ); |
4331 | reset(result); |
4332 | } |
4333 | |
4334 | /// Constructs a primitive descriptor for concat primitive from a C |
4335 | /// API primitive descriptor which must have a matching kind. |
4336 | /// |
4337 | /// @param pd C API primitive descriptor for concat primitive. |
4338 | primitive_desc(dnnl_primitive_desc_t pd) |
4339 | : primitive_desc_base(pd, dnnl::primitive::kind::concat) {} |
4340 | |
4341 | /// @copydoc dnnl::primitive_desc_base::src_desc(int)const |
4342 | memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } |
4343 | |
4344 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
4345 | memory::desc dst_desc() const { return base::dst_desc(0); } |
4346 | }; |
4347 | |
4348 | /// Default constructor. Produces an empty object. |
4349 | concat() = default; |
4350 | |
4351 | /// Constructs a concatenation primitive. |
4352 | /// @param pd Primitive descriptor for concatenation primitive. |
4353 | concat(const primitive_desc &pd) : primitive(pd.get()) {} |
4354 | |
4355 | /// Constructs a concatenation primitive from a cache blob. |
4356 | /// @param pd Primitive descriptor for concatenation primitive. |
4357 | /// @param cache_blob Cache blob. |
4358 | concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
4359 | : primitive(pd.get(), cache_blob) {} |
4360 | }; |
4361 | |
4362 | /// @} dnnl_api_concat |
4363 | |
4364 | /// @addtogroup dnnl_api_sum Sum |
4365 | /// |
4366 | /// A primitive to sum multiple tensors. |
4367 | /// |
4368 | /// @sa @ref dev_guide_sum in developer guide |
4369 | /// |
4370 | /// @{ |
4371 | |
4372 | /// Out-of-place summation (sum) primitive. |
4373 | struct sum : public primitive { |
4374 | /// Primitive descriptor for a sum primitive. |
4375 | struct primitive_desc : public primitive_desc_base { |
4376 | using primitive_desc_base::primitive_desc_base; |
4377 | |
4378 | /// Default constructor. Produces an empty object. |
4379 | primitive_desc() = default; |
4380 | |
4381 | /// Constructs a primitive descriptor for a sum primitive. |
4382 | /// |
4383 | /// @param aengine Engine to perform the operation on. |
4384 | /// @param dst Destination memory descriptor. |
4385 | /// @param scales Vector of scales to multiply data in each source |
4386 | /// memory by. |
4387 | /// @param srcs Vector of source memory descriptors. |
4388 | /// @param attr Primitive attributes to use. Attributes are optional |
4389 | /// and default to empty attributes. |
4390 | primitive_desc(const engine &aengine, const memory::desc &dst, |
4391 | const std::vector<float> &scales, |
4392 | const std::vector<memory::desc> &srcs, |
4393 | const primitive_attr &attr = default_attr()) { |
4394 | validate_container_size(scales, |
4395 | "counts of scales and sources are not equal" , |
4396 | (int)srcs.size(), (int)srcs.size()); |
4397 | |
4398 | auto c_api_srcs = convert_to_c(srcs); |
4399 | |
4400 | dnnl_primitive_desc_t result; |
4401 | error::wrap_c_api( |
4402 | dnnl_sum_primitive_desc_create(&result, aengine.get(), |
4403 | dst.get(), (int)c_api_srcs.size(), scales.data(), |
4404 | c_api_srcs.data(), attr.get()), |
4405 | "could not create a primitive descriptor for a sum " |
4406 | "primitive" ); |
4407 | reset(result); |
4408 | } |
4409 | |
4410 | /// Constructs a primitive descriptor for a sum primitive. |
4411 | /// |
4412 | /// This version derives the destination memory descriptor |
4413 | /// automatically. |
4414 | /// |
4415 | /// @param aengine Engine on which to perform the operation. |
4416 | /// @param scales Vector of scales by which to multiply data in each |
4417 | /// source memory object. |
4418 | /// @param srcs Vector of source memory descriptors. |
4419 | /// @param attr Primitive attributes to use. Attributes are optional |
4420 | /// and default to empty attributes. |
4421 | primitive_desc(const engine &aengine, const std::vector<float> &scales, |
4422 | const std::vector<memory::desc> &srcs, |
4423 | const primitive_attr &attr = default_attr()) { |
4424 | validate_container_size(scales, |
4425 | "counts of scales and sources are not equal" , |
4426 | (int)srcs.size(), (int)srcs.size()); |
4427 | |
4428 | auto c_api_srcs = convert_to_c(srcs); |
4429 | dnnl_primitive_desc_t result; |
4430 | error::wrap_c_api( |
4431 | dnnl_sum_primitive_desc_create(&result, aengine.get(), |
4432 | nullptr, (int)c_api_srcs.size(), scales.data(), |
4433 | c_api_srcs.data(), attr.get()), |
4434 | "could not create a primitive descriptor for a sum " |
4435 | "primitive" ); |
4436 | reset(result); |
4437 | } |
4438 | |
4439 | /// Constructs a primitive descriptor for sum primitive from a C API |
4440 | /// primitive descriptor which must have a matching kind. |
4441 | /// |
4442 | /// @param pd C API primitive descriptor for reorder primitive. |
4443 | primitive_desc(dnnl_primitive_desc_t pd) |
4444 | : primitive_desc_base(pd, dnnl::primitive::kind::sum) {} |
4445 | |
4446 | /// @copydoc dnnl::primitive_desc_base::src_desc(int)const |
4447 | memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } |
4448 | |
4449 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
4450 | memory::desc dst_desc() const { return base::dst_desc(0); } |
4451 | }; |
4452 | |
4453 | /// Default constructor. Produces an empty object. |
4454 | sum() = default; |
4455 | |
4456 | /// Constructs a sum primitive. |
4457 | /// @param pd Primitive descriptor for sum primitive. |
4458 | sum(const primitive_desc &pd) : primitive(pd.get()) {} |
4459 | |
4460 | /// Constructs a sum primitive from a cache blob. |
4461 | /// @param pd Primitive descriptor for sum primitive. |
4462 | /// @param cache_blob Cache blob. |
4463 | sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
4464 | : primitive(pd.get(), cache_blob) {} |
4465 | }; |
4466 | |
4467 | /// @} dnnl_api_sum |
4468 | |
4469 | /// @addtogroup dnnl_api_primitives_common |
4470 | /// @{ |
4471 | |
4472 | /// A base class for descriptors of all primitives that support iteration |
4473 | /// over multiple implementations. |
4474 | struct primitive_desc : public primitive_desc_base { |
4475 | using primitive_desc_base::primitive_desc_base; |
4476 | |
4477 | primitive_desc() = default; |
4478 | |
4479 | /// Changes the primitive descriptor to point to the next available |
4480 | /// implementation. |
4481 | /// |
4482 | /// @returns @c true on success and @c false if the last available |
4483 | /// implementation has already been reached. In the latter case, the |
4484 | /// primitive descriptor itself is kept unchanged. |
4485 | bool next_impl() { |
4486 | dnnl_status_t status = dnnl_primitive_desc_next_impl(get()); |
4487 | if (status == dnnl_last_impl_reached) return false; |
4488 | error::wrap_c_api(status, "last available implementation is reached" ); |
4489 | return true; |
4490 | } |
4491 | }; |
4492 | |
4493 | /// @} dnnl_api_primitives_common |
4494 | |
4495 | /// @addtogroup dnnl_api_convolution Convolution |
4496 | /// |
4497 | /// A primitive to perform 1D, 2D or 3D convolution. Supported variants are |
4498 | /// forward propagation, backward propagation, and weights gradient with or |
4499 | /// without bias. |
4500 | /// |
4501 | /// @sa @ref dev_guide_convolution in developer guide |
4502 | /// |
4503 | /// @{ |
4504 | |
4505 | /// Convolution forward propagation primitive. |
4506 | struct convolution_forward : public primitive { |
4507 | /// Primitive descriptor for a convolution forward propagation primitive. |
4508 | struct primitive_desc : public dnnl::primitive_desc { |
4509 | /// Default constructor. Produces an empty object. |
4510 | primitive_desc() = default; |
4511 | |
4512 | /// Constructs a primitive descriptor for a convolution forward |
4513 | /// propagation primitive with bias. |
4514 | /// |
4515 | /// @note |
4516 | /// All the memory descriptors may be initialized with the |
4517 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4518 | /// |
4519 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
4520 | /// for spatial dimensions only and hence must have the same number of |
4521 | /// elements as there are spatial dimensions. The order of values is |
4522 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
4523 | /// and 2D tensors), and width. |
4524 | /// |
4525 | /// @param aengine Engine to use. |
4526 | /// @param aprop_kind Propagation kind. Possible values are |
4527 | /// #dnnl::prop_kind::forward_training, and |
4528 | /// #dnnl::prop_kind::forward_inference. |
4529 | /// @param aalgorithm Convolution algorithm. Possible values are |
4530 | /// #dnnl::algorithm::convolution_direct, |
4531 | /// #dnnl::algorithm::convolution_winograd, and |
4532 | /// #dnnl::algorithm::convolution_auto. |
4533 | /// @param src_desc Source memory descriptor. |
4534 | /// @param weights_desc Weights memory descriptor. |
4535 | /// @param bias_desc Bias memory descriptor. Passing zero memory |
4536 | /// descriptor disables the bias term. |
4537 | /// @param dst_desc Destination memory descriptor. |
4538 | /// @param strides Strides for each spatial dimension. |
4539 | /// @param padding_l Vector of padding values for low indices for each |
4540 | /// spatial dimension `([[front,] top,] left)`. |
4541 | /// @param padding_r Vector of padding values for high indices for |
4542 | /// each spatial dimension `([[back,] bottom,] right)`. |
4543 | /// @param attr Primitive attributes to use. Attributes are optional |
4544 | /// and default to empty attributes. |
4545 | /// @param allow_empty A flag signifying whether construction is |
4546 | /// allowed to fail without throwing an exception. In this case an |
4547 | /// empty object will be produced. This flag is optional and |
4548 | /// defaults to false. |
4549 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
4550 | algorithm aalgorithm, const memory::desc &src_desc, |
4551 | const memory::desc &weights_desc, const memory::desc &bias_desc, |
4552 | const memory::desc &dst_desc, const memory::dims &strides, |
4553 | const memory::dims &padding_l, const memory::dims &padding_r, |
4554 | const primitive_attr &attr = default_attr(), |
4555 | bool allow_empty = false) |
4556 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
4557 | weights_desc, &bias_desc, dst_desc, strides, nullptr, |
4558 | padding_l, padding_r, attr, allow_empty) {} |
4559 | |
4560 | /// Constructs a primitive descriptor for a convolution forward |
4561 | /// propagation primitive without bias. |
4562 | /// |
4563 | /// @note |
4564 | /// All the memory descriptors may be initialized with the |
4565 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4566 | /// |
4567 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
4568 | /// for spatial dimensions only and hence must have the same number of |
4569 | /// elements as there are spatial dimensions. The order of values is |
4570 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
4571 | /// and 2D tensors), and width. |
4572 | /// |
4573 | /// @param aengine Engine to use. |
4574 | /// @param aprop_kind Propagation kind. Possible values are |
4575 | /// #dnnl::prop_kind::forward_training, and |
4576 | /// #dnnl::prop_kind::forward_inference. |
4577 | /// @param aalgorithm Convolution algorithm. Possible values are |
4578 | /// #dnnl::algorithm::convolution_direct, |
4579 | /// #dnnl::algorithm::convolution_winograd, and |
4580 | /// #dnnl::algorithm::convolution_auto. |
4581 | /// @param src_desc Source memory descriptor. |
4582 | /// @param weights_desc Weights memory descriptor. |
4583 | /// @param dst_desc Destination memory descriptor. |
4584 | /// @param strides Strides for each spatial dimension. |
4585 | /// @param padding_l Vector of padding values for low indices for each |
4586 | /// spatial dimension `([[front,] top,] left)`. |
4587 | /// @param padding_r Vector of padding values for high indices for |
4588 | /// each spatial dimension `([[back,] bottom,] right)`. |
4589 | /// @param attr Primitive attributes to use. Attributes are optional |
4590 | /// and default to empty attributes. |
4591 | /// @param allow_empty A flag signifying whether construction is |
4592 | /// allowed to fail without throwing an exception. In this case an |
4593 | /// empty object will be produced. This flag is optional and |
4594 | /// defaults to false. |
4595 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
4596 | algorithm aalgorithm, const memory::desc &src_desc, |
4597 | const memory::desc &weights_desc, const memory::desc &dst_desc, |
4598 | const memory::dims &strides, const memory::dims &padding_l, |
4599 | const memory::dims &padding_r, |
4600 | const primitive_attr &attr = default_attr(), |
4601 | bool allow_empty = false) |
4602 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
4603 | weights_desc, nullptr, dst_desc, strides, nullptr, |
4604 | padding_l, padding_r, attr, allow_empty) {} |
4605 | |
4606 | /// Constructs a primitive descriptor for a convolution forward |
4607 | /// propagation primitive with bias. |
4608 | /// |
4609 | /// @note |
4610 | /// All the memory descriptors may be initialized with the |
4611 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4612 | /// |
4613 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
4614 | /// contain values for spatial dimensions only and hence must have the |
4615 | /// same number of elements as there are spatial dimensions. The order |
4616 | /// of values is the same as in the tensor: depth (for 3D tensors), |
4617 | /// height (for 3D and 2D tensors), and width. |
4618 | /// |
4619 | /// @param aengine Engine to use. |
4620 | /// @param aprop_kind Propagation kind. Possible values are |
4621 | /// #dnnl::prop_kind::forward_training, and |
4622 | /// #dnnl::prop_kind::forward_inference. |
4623 | /// @param aalgorithm Convolution algorithm. Possible values are |
4624 | /// #dnnl::algorithm::convolution_direct, |
4625 | /// #dnnl::algorithm::convolution_winograd, and |
4626 | /// #dnnl::algorithm::convolution_auto. |
4627 | /// @param src_desc Source memory descriptor. |
4628 | /// @param weights_desc Weights memory descriptor. |
4629 | /// @param bias_desc Bias memory descriptor. Passing zero memory |
4630 | /// descriptor disables the bias term. |
4631 | /// @param dst_desc Destination memory descriptor. |
4632 | /// @param strides Strides for each spatial dimension. |
4633 | /// @param dilates Dilations for each spatial dimension. A zero value |
4634 | /// means no dilation in the corresponding dimension. |
4635 | /// @param padding_l Vector of padding values for low indices for each |
4636 | /// spatial dimension `([[front,] top,] left)`. |
4637 | /// @param padding_r Vector of padding values for high indices for |
4638 | /// each spatial dimension `([[back,] bottom,] right)`. |
4639 | /// @param attr Primitive attributes to use. Attributes are optional |
4640 | /// and default to empty attributes. |
4641 | /// @param allow_empty A flag signifying whether construction is |
4642 | /// allowed to fail without throwing an exception. In this case an |
4643 | /// empty object will be produced. This flag is optional and |
4644 | /// defaults to false. |
4645 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
4646 | algorithm aalgorithm, const memory::desc &src_desc, |
4647 | const memory::desc &weights_desc, const memory::desc &bias_desc, |
4648 | const memory::desc &dst_desc, const memory::dims &strides, |
4649 | const memory::dims &dilates, const memory::dims &padding_l, |
4650 | const memory::dims &padding_r, |
4651 | const primitive_attr &attr = default_attr(), |
4652 | bool allow_empty = false) |
4653 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
4654 | weights_desc, &bias_desc, dst_desc, strides, &dilates, |
4655 | padding_l, padding_r, attr, allow_empty) {} |
4656 | |
4657 | /// Constructs a primitive descriptor for a convolution forward |
4658 | /// propagation primitive without bias. |
4659 | /// |
4660 | /// @note |
4661 | /// All the memory descriptors may be initialized with the |
4662 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4663 | /// |
4664 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
4665 | /// contain values for spatial dimensions only and hence must have the |
4666 | /// same number of elements as there are spatial dimensions. The order |
4667 | /// of values is the same as in the tensor: depth (for 3D tensors), |
4668 | /// height (for 3D and 2D tensors), and width. |
4669 | /// |
4670 | /// @param aengine Engine to use. |
4671 | /// @param aprop_kind Propagation kind. Possible values are |
4672 | /// #dnnl::prop_kind::forward_training, and |
4673 | /// #dnnl::prop_kind::forward_inference. |
4674 | /// @param aalgorithm Convolution algorithm. Possible values are |
4675 | /// #dnnl::algorithm::convolution_direct, |
4676 | /// #dnnl::algorithm::convolution_winograd, and |
4677 | /// #dnnl::algorithm::convolution_auto. |
4678 | /// @param src_desc Source memory descriptor. |
4679 | /// @param weights_desc Weights memory descriptor. |
4680 | /// @param dst_desc Destination memory descriptor. |
4681 | /// @param strides Strides for each spatial dimension. |
4682 | /// @param dilates Dilations for each spatial dimension. A zero value |
4683 | /// means no dilation in the corresponding dimension. |
4684 | /// @param padding_l Vector of padding values for low indices for each |
4685 | /// spatial dimension `([[front,] top,] left)`. |
4686 | /// @param padding_r Vector of padding values for high indices for |
4687 | /// each spatial dimension `([[back,] bottom,] right)`. |
4688 | /// @param attr Primitive attributes to use. Attributes are optional |
4689 | /// and default to empty attributes. |
4690 | /// @param allow_empty A flag signifying whether construction is |
4691 | /// allowed to fail without throwing an exception. In this case an |
4692 | /// empty object will be produced. This flag is optional and |
4693 | /// defaults to false. |
4694 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
4695 | algorithm aalgorithm, const memory::desc &src_desc, |
4696 | const memory::desc &weights_desc, const memory::desc &dst_desc, |
4697 | const memory::dims &strides, const memory::dims &dilates, |
4698 | const memory::dims &padding_l, const memory::dims &padding_r, |
4699 | const primitive_attr &attr = default_attr(), |
4700 | bool allow_empty = false) |
4701 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
4702 | weights_desc, nullptr, dst_desc, strides, &dilates, |
4703 | padding_l, padding_r, attr, allow_empty) {} |
4704 | |
4705 | /// Constructs a primitive descriptor for a convolution forward |
4706 | /// propagation primitive from a C API primitive descriptor that must |
4707 | /// have a matching kind. |
4708 | /// |
4709 | /// @param pd C API primitive descriptor for a convolution forward |
4710 | /// propagation primitive. |
4711 | primitive_desc(dnnl_primitive_desc_t pd) |
4712 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, |
4713 | dnnl::prop_kind::forward_training, |
4714 | dnnl::prop_kind::forward_inference) {} |
4715 | |
4716 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
4717 | memory::desc src_desc() const { return base::src_desc(0); } |
4718 | |
4719 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
4720 | memory::desc weights_desc() const { return base::weights_desc(0); } |
4721 | |
4722 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
4723 | memory::desc dst_desc() const { return base::dst_desc(0); } |
4724 | |
4725 | /// Returns the bias memory descriptor. |
4726 | /// @returns The bias memory descriptor. |
4727 | /// @returns A zero memory descriptor of the primitive does not have a |
4728 | /// bias parameter. |
4729 | memory::desc bias_desc() const { return base::weights_desc(1); } |
4730 | |
4731 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
4732 | algorithm get_algorithm() const { return base::get_algorithm(); } |
4733 | |
4734 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
4735 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
4736 | |
4737 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
4738 | memory::dims get_strides() const { return base::get_strides(); } |
4739 | |
4740 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
4741 | memory::dims get_dilations() const { return base::get_dilations(); } |
4742 | |
4743 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
4744 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
4745 | |
4746 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
4747 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
4748 | |
4749 | private: |
4750 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
4751 | algorithm aalgorithm, const memory::desc &src_desc, |
4752 | const memory::desc &weights_desc, const memory::desc *bias_desc, |
4753 | const memory::desc &dst_desc, const memory::dims &strides, |
4754 | const memory::dims *dilates, const memory::dims &padding_l, |
4755 | const memory::dims &padding_r, const primitive_attr &attr, |
4756 | bool allow_empty) { |
4757 | |
4758 | memory::validate_dims(strides, src_desc.get_ndims() - 2); |
4759 | memory::validate_dims(padding_l, src_desc.get_ndims() - 2); |
4760 | memory::validate_dims(padding_r, src_desc.get_ndims() - 2); |
4761 | |
4762 | if (dilates) |
4763 | memory::validate_dims(*dilates, src_desc.get_ndims() - 2); |
4764 | |
4765 | dnnl_primitive_desc_t pd = nullptr; |
4766 | dnnl_status_t status |
4767 | = dnnl_convolution_forward_primitive_desc_create(&pd, |
4768 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
4769 | convert_to_c(aalgorithm), src_desc.get(), |
4770 | weights_desc.get(), optional_arg(bias_desc), |
4771 | dst_desc.get(), &strides[0], optional_arg(dilates), |
4772 | &padding_l[0], &padding_r[0], attr.get()); |
4773 | if (!allow_empty) |
4774 | error::wrap_c_api(status, |
4775 | "could not create a primitive descriptor for a " |
4776 | "convolution forward propagation primitive" ); |
4777 | reset(pd); |
4778 | } |
4779 | }; |
4780 | |
4781 | /// Default constructor. Produces an empty object. |
4782 | convolution_forward() = default; |
4783 | |
4784 | /// Constructs a convolution forward propagation primitive. |
4785 | /// @param pd Primitive descriptor for a convolution forward propagation |
4786 | /// primitive. |
4787 | convolution_forward(const primitive_desc &pd) : primitive(pd) {} |
4788 | |
4789 | /// Constructs a convolution forward propagation primitive from a cache |
4790 | /// blob. |
4791 | /// @param pd Primitive descriptor for a convolution forward propagation |
4792 | /// primitive. |
4793 | /// @param cache_blob Cache blob. |
4794 | convolution_forward( |
4795 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
4796 | : primitive(pd, cache_blob) {} |
4797 | }; |
4798 | |
4799 | /// Convolution backward propagation primitive. |
4800 | struct convolution_backward_data : public primitive { |
4801 | /// Primitive descriptor for a convolution backward propagation primitive. |
4802 | struct primitive_desc : public dnnl::primitive_desc { |
4803 | /// Default constructor. Produces an empty object. |
4804 | primitive_desc() = default; |
4805 | |
4806 | /// Constructs a primitive descriptor for a convolution backward |
4807 | /// propagation primitive. |
4808 | /// |
4809 | /// @note |
4810 | /// All the memory descriptors may be initialized with the |
4811 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4812 | /// |
4813 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
4814 | /// for spatial dimensions only and hence must have the same number of |
4815 | /// elements as there are spatial dimensions. The order of values is |
4816 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
4817 | /// and 2D tensors), and width. |
4818 | /// |
4819 | /// @param aengine Engine to use. |
4820 | /// @param aalgorithm Convolution algorithm. Possible values are |
4821 | /// #dnnl::algorithm::convolution_direct, |
4822 | /// #dnnl::algorithm::convolution_winograd, and |
4823 | /// #dnnl::algorithm::convolution_auto. |
4824 | /// @param diff_src_desc Diff source memory descriptor. |
4825 | /// @param weights_desc Weights memory descriptor. |
4826 | /// @param diff_dst_desc Diff destination memory descriptor. |
4827 | /// @param strides Strides for each spatial dimension. |
4828 | /// @param padding_l Vector of padding values for low indices for each |
4829 | /// spatial dimension `([[front,] top,] left)`. |
4830 | /// @param padding_r Vector of padding values for high indices for |
4831 | /// each spatial dimension `([[back,] bottom,] right)`. |
4832 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
4833 | /// forward propagation primitive. It is used as a hint for |
4834 | /// deciding which memory format to use. |
4835 | /// @param attr Primitive attributes to use. Attributes are optional |
4836 | /// and default to empty attributes. |
4837 | /// @param allow_empty A flag signifying whether construction is |
4838 | /// allowed to fail without throwing an exception. In this case an |
4839 | /// empty object will be produced. This flag is optional and |
4840 | /// defaults to false. |
4841 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
4842 | const memory::desc &diff_src_desc, |
4843 | const memory::desc &weights_desc, |
4844 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
4845 | const memory::dims &padding_l, const memory::dims &padding_r, |
4846 | const convolution_forward::primitive_desc &hint_fwd_pd, |
4847 | const primitive_attr &attr = default_attr(), |
4848 | bool allow_empty = false) |
4849 | : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, |
4850 | diff_dst_desc, strides, nullptr, padding_l, padding_r, |
4851 | hint_fwd_pd, attr, allow_empty) {} |
4852 | |
4853 | /// Constructs a primitive descriptor for a convolution backward |
4854 | /// propagation primitive. |
4855 | /// |
4856 | /// @note |
4857 | /// All the memory descriptors may be initialized with the |
4858 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
4859 | /// |
4860 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
4861 | /// contain values for spatial dimensions only and hence must have the |
4862 | /// same number of elements as there are spatial dimensions. The order |
4863 | /// of values is the same as in the tensor: depth (for 3D tensors), |
4864 | /// height (for 3D and 2D tensors), and width. |
4865 | /// |
4866 | /// @param aengine Engine to use. |
4867 | /// @param aalgorithm Convolution algorithm. Possible values are |
4868 | /// #dnnl::algorithm::convolution_direct, |
4869 | /// #dnnl::algorithm::convolution_winograd, and |
4870 | /// #dnnl::algorithm::convolution_auto. |
4871 | /// @param diff_src_desc Diff source memory descriptor. |
4872 | /// @param weights_desc Weights memory descriptor. |
4873 | /// @param diff_dst_desc Diff destination memory descriptor. |
4874 | /// @param strides Strides for each spatial dimension. |
4875 | /// @param dilates Dilations for each spatial dimension. A zero value |
4876 | /// means no dilation in the corresponding dimension. |
4877 | /// @param padding_l Vector of padding values for low indices for each |
4878 | /// spatial dimension `([[front,] top,] left)`. |
4879 | /// @param padding_r Vector of padding values for high indices for |
4880 | /// each spatial dimension `([[back,] bottom,] right)`. |
4881 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
4882 | /// forward propagation primitive. It is used as a hint for |
4883 | /// deciding which memory format to use. |
4884 | /// @param attr Primitive attributes to use. Attributes are optional |
4885 | /// and default to empty attributes. |
4886 | /// @param allow_empty A flag signifying whether construction is |
4887 | /// allowed to fail without throwing an exception. In this case an |
4888 | /// empty object will be produced. This flag is optional and |
4889 | /// defaults to false. |
4890 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
4891 | const memory::desc &diff_src_desc, |
4892 | const memory::desc &weights_desc, |
4893 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
4894 | const memory::dims &dilates, const memory::dims &padding_l, |
4895 | const memory::dims &padding_r, |
4896 | const convolution_forward::primitive_desc &hint_fwd_pd, |
4897 | const primitive_attr &attr = default_attr(), |
4898 | bool allow_empty = false) |
4899 | : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, |
4900 | diff_dst_desc, strides, &dilates, padding_l, padding_r, |
4901 | hint_fwd_pd, attr, allow_empty) {} |
4902 | |
4903 | /// Constructs a primitive descriptor for a convolution backward |
4904 | /// propagation primitive from a C API primitive descriptor that must |
4905 | /// have a matching kind. |
4906 | /// |
4907 | /// @param pd C API primitive descriptor for a convolution backward |
4908 | /// propagation primitive. |
4909 | primitive_desc(dnnl_primitive_desc_t pd) |
4910 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, |
4911 | dnnl::prop_kind::backward_data) {} |
4912 | |
4913 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
4914 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
4915 | |
4916 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
4917 | memory::desc weights_desc() const { return base::weights_desc(0); } |
4918 | |
4919 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
4920 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
4921 | |
4922 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
4923 | algorithm get_algorithm() const { return base::get_algorithm(); } |
4924 | |
4925 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
4926 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
4927 | |
4928 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
4929 | memory::dims get_strides() const { return base::get_strides(); } |
4930 | |
4931 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
4932 | memory::dims get_dilations() const { return base::get_dilations(); } |
4933 | |
4934 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
4935 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
4936 | |
4937 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
4938 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
4939 | |
4940 | private: |
4941 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
4942 | const memory::desc &diff_src_desc, |
4943 | const memory::desc &weights_desc, |
4944 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
4945 | const memory::dims *dilates, const memory::dims &padding_l, |
4946 | const memory::dims &padding_r, |
4947 | const convolution_forward::primitive_desc &hint_fwd_pd, |
4948 | const primitive_attr &attr, bool allow_empty) { |
4949 | |
4950 | memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); |
4951 | memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); |
4952 | memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); |
4953 | |
4954 | if (dilates) |
4955 | memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2); |
4956 | |
4957 | dnnl_primitive_desc_t pd = nullptr; |
4958 | dnnl_status_t status |
4959 | = dnnl_convolution_backward_data_primitive_desc_create(&pd, |
4960 | aengine.get(), convert_to_c(aalgorithm), |
4961 | diff_src_desc.get(), weights_desc.get(), |
4962 | diff_dst_desc.get(), &strides[0], |
4963 | optional_arg(dilates), &padding_l[0], &padding_r[0], |
4964 | hint_fwd_pd.get(), attr.get()); |
4965 | if (!allow_empty) |
4966 | error::wrap_c_api(status, |
4967 | "could not create a primitive descriptor for a " |
4968 | "convolution backward propagation primitive" ); |
4969 | reset(pd); |
4970 | } |
4971 | }; |
4972 | |
4973 | /// Default constructor. Produces an empty object. |
4974 | convolution_backward_data() = default; |
4975 | |
4976 | /// Constructs a convolution backward propagation primitive. |
4977 | /// @param pd Primitive descriptor for a convolution backward propagation |
4978 | /// primitive. |
4979 | convolution_backward_data(const primitive_desc &pd) : primitive(pd) {} |
4980 | |
4981 | /// Constructs a convolution backward propagation primitive from a cache |
4982 | /// blob. |
4983 | /// @param pd Primitive descriptor for a convolution backward propagation |
4984 | /// primitive. |
4985 | /// @param cache_blob Cache blob. |
4986 | convolution_backward_data( |
4987 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
4988 | : primitive(pd, cache_blob) {} |
4989 | }; |
4990 | |
4991 | /// Convolution weights gradient primitive. |
4992 | struct convolution_backward_weights : public primitive { |
4993 | /// Primitive descriptor for a convolution weights gradient primitive. |
4994 | struct primitive_desc : public dnnl::primitive_desc { |
/// Default constructor. Produces an empty object.
/// Kept as `= default` so the compiler-generated constructor, with its
/// deduced exception specification, is used.
primitive_desc() = default;
4997 | |
4998 | /// Constructs a primitive descriptor for a convolution weights gradient |
4999 | /// primitive with bias. |
5000 | /// |
5001 | /// @note |
5002 | /// All the memory descriptors may be initialized with the |
5003 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5004 | /// |
5005 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5006 | /// for spatial dimensions only and hence must have the same number of |
5007 | /// elements as there are spatial dimensions. The order of values is |
5008 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5009 | /// and 2D tensors), and width. |
5010 | /// |
5011 | /// @param aengine Engine to use. |
5012 | /// @param aalgorithm Convolution algorithm. Possible values are |
5013 | /// #dnnl::algorithm::convolution_direct, |
5014 | /// #dnnl::algorithm::convolution_winograd, and |
5015 | /// #dnnl::algorithm::convolution_auto. |
5016 | /// @param src_desc Source memory descriptor. |
5017 | /// @param diff_weights_desc Diff weights memory descriptor. |
5018 | /// @param diff_bias_desc Diff bias memory descriptor. Passing zero |
5019 | /// memory descriptor disables the bias term. |
5020 | /// @param diff_dst_desc Diff destination memory descriptor. |
5021 | /// @param strides Strides for each spatial dimension. |
5022 | /// @param padding_l Vector of padding values for low indices for each |
5023 | /// spatial dimension `([[front,] top,] left)`. |
5024 | /// @param padding_r Vector of padding values for high indices for |
5025 | /// each spatial dimension `([[back,] bottom,] right)`. |
5026 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
5027 | /// forward propagation primitive. It is used as a hint for |
5028 | /// deciding which memory format to use. |
5029 | /// @param attr Primitive attributes to use. Attributes are optional |
5030 | /// and default to empty attributes. |
5031 | /// @param allow_empty A flag signifying whether construction is |
5032 | /// allowed to fail without throwing an exception. In this case an |
5033 | /// empty object will be produced. This flag is optional and |
5034 | /// defaults to false. |
5035 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5036 | const memory::desc &src_desc, |
5037 | const memory::desc &diff_weights_desc, |
5038 | const memory::desc &diff_bias_desc, |
5039 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5040 | const memory::dims &padding_l, const memory::dims &padding_r, |
5041 | const convolution_forward::primitive_desc &hint_fwd_pd, |
5042 | const primitive_attr &attr = default_attr(), |
5043 | bool allow_empty = false) |
5044 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5045 | &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l, |
5046 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5047 | |
5048 | /// Constructs a primitive descriptor for a convolution weights gradient |
5049 | /// primitive without bias. |
5050 | /// |
5051 | /// @note |
5052 | /// All the memory descriptors may be initialized with the |
5053 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5054 | /// |
5055 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5056 | /// for spatial dimensions only and hence must have the same number of |
5057 | /// elements as there are spatial dimensions. The order of values is |
5058 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5059 | /// and 2D tensors), and width. |
5060 | /// |
5061 | /// @param aengine Engine to use. |
5062 | /// @param aalgorithm Convolution algorithm. Possible values are |
5063 | /// #dnnl::algorithm::convolution_direct, |
5064 | /// #dnnl::algorithm::convolution_winograd, and |
5065 | /// #dnnl::algorithm::convolution_auto. |
5066 | /// @param src_desc Source memory descriptor. |
5067 | /// @param diff_weights_desc Diff weights memory descriptor. |
5068 | /// @param diff_dst_desc Diff destination memory descriptor. |
5069 | /// @param strides Strides for each spatial dimension. |
5070 | /// @param padding_l Vector of padding values for low indices for each |
5071 | /// spatial dimension `([[front,] top,] left)`. |
5072 | /// @param padding_r Vector of padding values for high indices for |
5073 | /// each spatial dimension `([[back,] bottom,] right)`. |
5074 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
5075 | /// forward propagation primitive. It is used as a hint for |
5076 | /// deciding which memory format to use. |
5077 | /// @param attr Primitive attributes to use. Attributes are optional |
5078 | /// and default to empty attributes. |
5079 | /// @param allow_empty A flag signifying whether construction is |
5080 | /// allowed to fail without throwing an exception. In this case an |
5081 | /// empty object will be produced. This flag is optional and |
5082 | /// defaults to false. |
5083 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5084 | const memory::desc &src_desc, |
5085 | const memory::desc &diff_weights_desc, |
5086 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5087 | const memory::dims &padding_l, const memory::dims &padding_r, |
5088 | const convolution_forward::primitive_desc &hint_fwd_pd, |
5089 | const primitive_attr &attr = default_attr(), |
5090 | bool allow_empty = false) |
5091 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5092 | nullptr, diff_dst_desc, strides, nullptr, padding_l, |
5093 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5094 | |
5095 | /// Constructs a primitive descriptor for a convolution weights |
5096 | /// gradient primitive with bias. |
5097 | /// |
5098 | /// @note |
5099 | /// All the memory descriptors may be initialized with the |
5100 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5101 | /// |
5102 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5103 | /// contain values for spatial dimensions only and hence must have the |
5104 | /// same number of elements as there are spatial dimensions. The order |
5105 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5106 | /// height (for 3D and 2D tensors), and width. |
5107 | /// |
5108 | /// @param aengine Engine to use. |
5109 | /// @param aalgorithm Convolution algorithm. Possible values are |
5110 | /// #dnnl::algorithm::convolution_direct, |
5111 | /// #dnnl::algorithm::convolution_winograd, and |
5112 | /// #dnnl::algorithm::convolution_auto. |
5113 | /// @param src_desc Source memory descriptor. |
5114 | /// @param diff_weights_desc Diff weights memory descriptor. |
5115 | /// @param diff_bias_desc Diff bias memory descriptor. Passing zero |
5116 | /// memory descriptor disables the bias term. |
5117 | /// @param diff_dst_desc Diff destination memory descriptor. |
5118 | /// @param strides Strides for each spatial dimension. |
5119 | /// @param dilates Dilations for each spatial dimension. A zero value |
5120 | /// means no dilation in the corresponding dimension. |
5121 | /// @param padding_l Vector of padding values for low indices for each |
5122 | /// spatial dimension `([[front,] top,] left)`. |
5123 | /// @param padding_r Vector of padding values for high indices for |
5124 | /// each spatial dimension `([[back,] bottom,] right)`. |
5125 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
5126 | /// forward propagation primitive. It is used as a hint for |
5127 | /// deciding which memory format to use. |
5128 | /// @param attr Primitive attributes to use. Attributes are optional |
5129 | /// and default to empty attributes. |
5130 | /// @param allow_empty A flag signifying whether construction is |
5131 | /// allowed to fail without throwing an exception. In this case an |
5132 | /// empty object will be produced. This flag is optional and |
5133 | /// defaults to false. |
5134 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5135 | const memory::desc &src_desc, |
5136 | const memory::desc &diff_weights_desc, |
5137 | const memory::desc &diff_bias_desc, |
5138 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5139 | const memory::dims &dilates, const memory::dims &padding_l, |
5140 | const memory::dims &padding_r, |
5141 | const convolution_forward::primitive_desc &hint_fwd_pd, |
5142 | const primitive_attr &attr = default_attr(), |
5143 | bool allow_empty = false) |
5144 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5145 | &diff_bias_desc, diff_dst_desc, strides, &dilates, |
5146 | padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {} |
5147 | |
5148 | /// Constructs a primitive descriptor for a convolution weights |
5149 | /// gradient primitive without bias. |
5150 | /// |
5151 | /// @note |
5152 | /// All the memory descriptors may be initialized with the |
5153 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5154 | /// |
5155 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5156 | /// contain values for spatial dimensions only and hence must have the |
5157 | /// same number of elements as there are spatial dimensions. The order |
5158 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5159 | /// height (for 3D and 2D tensors), and width. |
5160 | /// |
5161 | /// @param aengine Engine to use. |
5162 | /// @param aalgorithm Convolution algorithm. Possible values are |
5163 | /// #dnnl::algorithm::convolution_direct, |
5164 | /// #dnnl::algorithm::convolution_winograd, and |
5165 | /// #dnnl::algorithm::convolution_auto. |
5166 | /// @param src_desc Source memory descriptor. |
5167 | /// @param diff_weights_desc Diff weights memory descriptor. |
5168 | /// @param diff_dst_desc Diff destination memory descriptor. |
5169 | /// @param strides Strides for each spatial dimension. |
5170 | /// @param dilates Dilations for each spatial dimension. A zero value |
5171 | /// means no dilation in the corresponding dimension. |
5172 | /// @param padding_l Vector of padding values for low indices for each |
5173 | /// spatial dimension `([[front,] top,] left)`. |
5174 | /// @param padding_r Vector of padding values for high indices for |
5175 | /// each spatial dimension `([[back,] bottom,] right)`. |
5176 | /// @param hint_fwd_pd Primitive descriptor for a convolution |
5177 | /// forward propagation primitive. It is used as a hint for |
5178 | /// deciding which memory format to use. |
5179 | /// @param attr Primitive attributes to use. Attributes are optional |
5180 | /// and default to empty attributes. |
5181 | /// @param allow_empty A flag signifying whether construction is |
5182 | /// allowed to fail without throwing an exception. In this case an |
5183 | /// empty object will be produced. This flag is optional and |
5184 | /// defaults to false. |
5185 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5186 | const memory::desc &src_desc, |
5187 | const memory::desc &diff_weights_desc, |
5188 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5189 | const memory::dims &dilates, const memory::dims &padding_l, |
5190 | const memory::dims &padding_r, |
5191 | const convolution_forward::primitive_desc &hint_fwd_pd, |
5192 | const primitive_attr &attr = default_attr(), |
5193 | bool allow_empty = false) |
5194 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5195 | nullptr, diff_dst_desc, strides, &dilates, padding_l, |
5196 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5197 | |
5198 | /// Constructs a primitive descriptor for a convolution weights gradient |
5199 | /// primitive from a C API primitive descriptor that must have a |
5200 | /// matching kind. |
5201 | /// |
5202 | /// @param pd C API primitive descriptor for a convolution weights |
5203 | /// gradient primitive. |
5204 | primitive_desc(dnnl_primitive_desc_t pd) |
5205 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, |
5206 | dnnl::prop_kind::backward_weights) {} |
5207 | |
5208 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
5209 | memory::desc src_desc() const { return base::src_desc(0); } |
5210 | |
5211 | /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const |
5212 | memory::desc diff_weights_desc() const { |
5213 | return base::diff_weights_desc(0); |
5214 | } |
5215 | |
5216 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
5217 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
5218 | |
5219 | /// Returns the diff bias memory descriptor. |
5220 | /// @returns The diff bias memory descriptor. |
5221 | /// @returns A zero memory descriptor of the primitive does not have a |
5222 | /// diff bias parameter. |
5223 | memory::desc diff_bias_desc() const { |
5224 | return base::diff_weights_desc(1); |
5225 | } |
5226 | |
5227 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
5228 | algorithm get_algorithm() const { return base::get_algorithm(); } |
5229 | |
5230 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
5231 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
5232 | |
5233 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
5234 | memory::dims get_strides() const { return base::get_strides(); } |
5235 | |
5236 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
5237 | memory::dims get_dilations() const { return base::get_dilations(); } |
5238 | |
5239 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
5240 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
5241 | |
5242 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
5243 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
5244 | |
5245 | private: |
5246 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5247 | const memory::desc &src_desc, |
5248 | const memory::desc &diff_weights_desc, |
5249 | const memory::desc *diff_bias_desc, |
5250 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5251 | const memory::dims *dilates, const memory::dims &padding_l, |
5252 | const memory::dims &padding_r, |
5253 | const convolution_forward::primitive_desc &hint_fwd_pd, |
5254 | const primitive_attr &attr, bool allow_empty) { |
5255 | |
5256 | memory::validate_dims(strides, src_desc.get_ndims() - 2); |
5257 | memory::validate_dims(padding_l, src_desc.get_ndims() - 2); |
5258 | memory::validate_dims(padding_r, src_desc.get_ndims() - 2); |
5259 | |
5260 | if (dilates) |
5261 | memory::validate_dims(*dilates, src_desc.get_ndims() - 2); |
5262 | |
5263 | dnnl_primitive_desc_t pd = nullptr; |
5264 | dnnl_status_t status |
5265 | = dnnl_convolution_backward_weights_primitive_desc_create( |
5266 | &pd, aengine.get(), convert_to_c(aalgorithm), |
5267 | src_desc.get(), diff_weights_desc.get(), |
5268 | optional_arg(diff_bias_desc), diff_dst_desc.get(), |
5269 | &strides[0], optional_arg(dilates), &padding_l[0], |
5270 | &padding_r[0], hint_fwd_pd.get(), attr.get()); |
5271 | if (!allow_empty) |
5272 | error::wrap_c_api(status, |
5273 | "could not create a primitive descriptor for a " |
5274 | "convolution weights update primitive" ); |
5275 | reset(pd); |
5276 | } |
5277 | }; |
5278 | |
5279 | /// Default constructor. Produces an empty object. |
5280 | convolution_backward_weights() = default; |
5281 | |
5282 | /// Constructs a convolution weights gradient primitive. |
5283 | /// @param pd Primitive descriptor for a convolution weights gradient |
5284 | /// primitive. |
5285 | convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {} |
5286 | |
5287 | /// Constructs a convolution weights gradient primitive from a cache blob. |
5288 | /// @param pd Primitive descriptor for a convolution weights gradient |
5289 | /// primitive. |
5290 | /// @param cache_blob Cache blob. |
5291 | convolution_backward_weights( |
5292 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
5293 | : primitive(pd, cache_blob) {} |
5294 | }; |
5295 | |
5296 | /// @} dnnl_api_convolution |
5297 | // |
5298 | /// @addtogroup dnnl_api_deconvolution Deconvolution |
5299 | /// |
5300 | /// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are |
5301 | /// forward propagation, backward propagation, and weights gradient with or |
5302 | /// without bias. |
5303 | /// |
5304 | /// @{ |
5305 | |
5306 | /// Deconvolution forward propagation primitive. |
5307 | struct deconvolution_forward : public primitive { |
5308 | /// Primitive descriptor for a deconvolution forward propagation primitive. |
5309 | struct primitive_desc : public dnnl::primitive_desc { |
5310 | /// Default constructor. Produces an empty object. |
5311 | primitive_desc() = default; |
5312 | |
5313 | /// Constructs a primitive descriptor for a deconvolution forward |
5314 | /// propagation primitive with bias. |
5315 | /// |
5316 | /// @note |
5317 | /// All the memory descriptors may be initialized with the |
5318 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5319 | /// |
5320 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5321 | /// for spatial dimensions only and hence must have the same number of |
5322 | /// elements as there are spatial dimensions. The order of values is |
5323 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5324 | /// and 2D tensors), and width. |
5325 | /// |
5326 | /// @param aengine Engine to use. |
5327 | /// @param aprop_kind Propagation kind. Possible values are |
5328 | /// #dnnl::prop_kind::forward_training, and |
5329 | /// #dnnl::prop_kind::forward_inference. |
5330 | /// @param aalgorithm Deconvolution algorithm: |
5331 | /// #dnnl::algorithm::deconvolution_direct, and |
5332 | /// #dnnl::algorithm::deconvolution_winograd. |
5333 | /// @param src_desc Source memory descriptor. |
5334 | /// @param weights_desc Weights memory descriptor. |
5335 | /// @param bias_desc Bias memory descriptor. Passing zero memory |
5336 | /// descriptor disables the bias term. |
5337 | /// @param dst_desc Destination memory descriptor. |
5338 | /// @param strides Vector of strides for spatial dimension. |
5339 | /// @param padding_l Vector of padding values for low indices for each |
5340 | /// spatial dimension `([[front,] top,] left)`. |
5341 | /// @param padding_r Vector of padding values for high indices for |
5342 | /// each spatial dimension `([[back,] bottom,] right)`. |
5343 | /// @param attr Primitive attributes to use. Attributes are optional |
5344 | /// and default to empty attributes. |
5345 | /// @param allow_empty A flag signifying whether construction is |
5346 | /// allowed to fail without throwing an exception. In this case an |
5347 | /// empty object will be produced. This flag is optional and |
5348 | /// defaults to false. |
5349 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
5350 | algorithm aalgorithm, const memory::desc &src_desc, |
5351 | const memory::desc &weights_desc, const memory::desc &bias_desc, |
5352 | const memory::desc &dst_desc, const memory::dims &strides, |
5353 | const memory::dims &padding_l, const memory::dims &padding_r, |
5354 | const primitive_attr &attr = default_attr(), |
5355 | bool allow_empty = false) |
5356 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
5357 | weights_desc, &bias_desc, dst_desc, strides, nullptr, |
5358 | padding_l, padding_r, attr, allow_empty) {} |
5359 | |
5360 | /// Constructs a primitive descriptor for a deconvolution forward |
5361 | /// propagation primitive without bias. |
5362 | /// |
5363 | /// @note |
5364 | /// All the memory descriptors may be initialized with the |
5365 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5366 | /// |
5367 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5368 | /// for spatial dimensions only and hence must have the same number of |
5369 | /// elements as there are spatial dimensions. The order of values is |
5370 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5371 | /// and 2D tensors), and width. |
5372 | /// |
5373 | /// @param aengine Engine to use. |
5374 | /// @param aprop_kind Propagation kind. Possible values are |
5375 | /// #dnnl::prop_kind::forward_training, and |
5376 | /// #dnnl::prop_kind::forward_inference. |
5377 | /// @param aalgorithm Deconvolution algorithm: |
5378 | /// #dnnl::algorithm::deconvolution_direct, and |
5379 | /// #dnnl::algorithm::deconvolution_winograd. |
5380 | /// @param src_desc Source memory descriptor. |
5381 | /// @param weights_desc Weights memory descriptor. |
5382 | /// @param dst_desc Destination memory descriptor. |
5383 | /// @param strides Vector of strides for spatial dimension. |
5384 | /// @param padding_l Vector of padding values for low indices for each |
5385 | /// spatial dimension `([[front,] top,] left)`. |
5386 | /// @param padding_r Vector of padding values for high indices for |
5387 | /// each spatial dimension `([[back,] bottom,] right)`. |
5388 | /// @param attr Primitive attributes to use. Attributes are optional |
5389 | /// and default to empty attributes. |
5390 | /// @param allow_empty A flag signifying whether construction is |
5391 | /// allowed to fail without throwing an exception. In this case an |
5392 | /// empty object will be produced. This flag is optional and |
5393 | /// defaults to false. |
5394 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
5395 | algorithm aalgorithm, const memory::desc &src_desc, |
5396 | const memory::desc &weights_desc, const memory::desc &dst_desc, |
5397 | const memory::dims &strides, const memory::dims &padding_l, |
5398 | const memory::dims &padding_r, |
5399 | const primitive_attr &attr = default_attr(), |
5400 | bool allow_empty = false) |
5401 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
5402 | weights_desc, nullptr, dst_desc, strides, nullptr, |
5403 | padding_l, padding_r, attr, allow_empty) {} |
5404 | |
5405 | /// Constructs a primitive descriptor for a deconvolution forward |
5406 | /// propagation primitive with bias. |
5407 | /// |
5408 | /// @note |
5409 | /// All the memory descriptors may be initialized with the |
5410 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5411 | /// |
5412 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5413 | /// contain values for spatial dimensions only and hence must have the |
5414 | /// same number of elements as there are spatial dimensions. The order |
5415 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5416 | /// height (for 3D and 2D tensors), and width. |
5417 | /// |
5418 | /// @param aengine Engine to use. |
5419 | /// @param aprop_kind Propagation kind. Possible values are |
5420 | /// #dnnl::prop_kind::forward_training, and |
5421 | /// #dnnl::prop_kind::forward_inference. |
5422 | /// @param aalgorithm Deconvolution algorithm: |
5423 | /// #dnnl::algorithm::deconvolution_direct, and |
5424 | /// #dnnl::algorithm::deconvolution_winograd. |
5425 | /// @param src_desc Source memory descriptor. |
5426 | /// @param weights_desc Weights memory descriptor. |
5427 | /// @param bias_desc Bias memory descriptor. Passing zero memory |
5428 | /// descriptor disables the bias term. |
5429 | /// @param dst_desc Destination memory descriptor. |
5430 | /// @param strides Vector of strides for spatial dimension. |
5431 | /// @param dilates Dilations for each spatial dimension. A zero value |
5432 | /// means no dilation in the corresponding dimension. |
5433 | /// @param padding_l Vector of padding values for low indices for each |
5434 | /// spatial dimension `([[front,] top,] left)`. |
5435 | /// @param padding_r Vector of padding values for high indices for |
5436 | /// each spatial dimension `([[back,] bottom,] right)`. |
5437 | /// @param attr Primitive attributes to use. Attributes are optional |
5438 | /// and default to empty attributes. |
5439 | /// @param allow_empty A flag signifying whether construction is |
5440 | /// allowed to fail without throwing an exception. In this case an |
5441 | /// empty object will be produced. This flag is optional and |
5442 | /// defaults to false. |
5443 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
5444 | algorithm aalgorithm, const memory::desc &src_desc, |
5445 | const memory::desc &weights_desc, const memory::desc &bias_desc, |
5446 | const memory::desc &dst_desc, const memory::dims &strides, |
5447 | const memory::dims &dilates, const memory::dims &padding_l, |
5448 | const memory::dims &padding_r, |
5449 | const primitive_attr &attr = default_attr(), |
5450 | bool allow_empty = false) |
5451 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
5452 | weights_desc, &bias_desc, dst_desc, strides, &dilates, |
5453 | padding_l, padding_r, attr, allow_empty) {} |
5454 | |
5455 | /// Constructs a primitive descriptor for a deconvolution forward |
5456 | /// propagation primitive without bias. |
5457 | /// |
5458 | /// @note |
5459 | /// All the memory descriptors may be initialized with the |
5460 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5461 | /// |
5462 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5463 | /// contain values for spatial dimensions only and hence must have the |
5464 | /// same number of elements as there are spatial dimensions. The order |
5465 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5466 | /// height (for 3D and 2D tensors), and width. |
5467 | /// |
5468 | /// @param aengine Engine to use. |
5469 | /// @param aprop_kind Propagation kind. Possible values are |
5470 | /// #dnnl::prop_kind::forward_training, and |
5471 | /// #dnnl::prop_kind::forward_inference. |
5472 | /// @param aalgorithm Deconvolution algorithm: |
5473 | /// #dnnl::algorithm::deconvolution_direct, and |
5474 | /// #dnnl::algorithm::deconvolution_winograd. |
5475 | /// @param src_desc Source memory descriptor. |
5476 | /// @param weights_desc Weights memory descriptor. |
5477 | /// @param dst_desc Destination memory descriptor. |
5478 | /// @param strides Vector of strides for spatial dimension. |
5479 | /// @param dilates Dilations for each spatial dimension. A zero value |
5480 | /// means no dilation in the corresponding dimension. |
5481 | /// @param padding_l Vector of padding values for low indices for each |
5482 | /// spatial dimension `([[front,] top,] left)`. |
5483 | /// @param padding_r Vector of padding values for high indices for |
5484 | /// each spatial dimension `([[back,] bottom,] right)`. |
5485 | /// @param attr Primitive attributes to use. Attributes are optional |
5486 | /// and default to empty attributes. |
5487 | /// @param allow_empty A flag signifying whether construction is |
5488 | /// allowed to fail without throwing an exception. In this case an |
5489 | /// empty object will be produced. This flag is optional and |
5490 | /// defaults to false. |
5491 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
5492 | algorithm aalgorithm, const memory::desc &src_desc, |
5493 | const memory::desc &weights_desc, const memory::desc &dst_desc, |
5494 | const memory::dims &strides, const memory::dims &dilates, |
5495 | const memory::dims &padding_l, const memory::dims &padding_r, |
5496 | const primitive_attr &attr = default_attr(), |
5497 | bool allow_empty = false) |
5498 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
5499 | weights_desc, nullptr, dst_desc, strides, &dilates, |
5500 | padding_l, padding_r, attr, allow_empty) {} |
5501 | |
5502 | /// Constructs a primitive descriptor for a deconvolution forward |
5503 | /// propagation primitive from a C API primitive descriptor that must |
5504 | /// have a matching kind. |
5505 | /// |
5506 | /// @param pd C API primitive descriptor for a deconvolution forward |
5507 | /// propagation primitive. |
5508 | primitive_desc(dnnl_primitive_desc_t pd) |
5509 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution, |
5510 | dnnl::prop_kind::forward_training, |
5511 | dnnl::prop_kind::forward_inference) {} |
5512 | |
5513 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
5514 | memory::desc src_desc() const { return base::src_desc(0); } |
5515 | |
5516 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
5517 | memory::desc weights_desc() const { return base::weights_desc(0); } |
5518 | |
5519 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
5520 | memory::desc dst_desc() const { return base::dst_desc(0); } |
5521 | |
5522 | /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const |
5523 | memory::desc bias_desc() const { return base::weights_desc(1); } |
5524 | |
5525 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
5526 | algorithm get_algorithm() const { return base::get_algorithm(); } |
5527 | |
5528 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
5529 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
5530 | |
5531 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
5532 | memory::dims get_strides() const { return base::get_strides(); } |
5533 | |
5534 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
5535 | memory::dims get_dilations() const { return base::get_dilations(); } |
5536 | |
5537 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
5538 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
5539 | |
5540 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
5541 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
5542 | |
5543 | private: |
5544 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
5545 | algorithm aalgorithm, const memory::desc &src_desc, |
5546 | const memory::desc &weights_desc, const memory::desc *bias_desc, |
5547 | const memory::desc &dst_desc, const memory::dims &strides, |
5548 | const memory::dims *dilates, const memory::dims &padding_l, |
5549 | const memory::dims &padding_r, const primitive_attr &attr, |
5550 | bool allow_empty) { |
5551 | |
5552 | memory::validate_dims(strides, src_desc.get_ndims() - 2); |
5553 | memory::validate_dims(padding_l, src_desc.get_ndims() - 2); |
5554 | memory::validate_dims(padding_r, src_desc.get_ndims() - 2); |
5555 | |
5556 | if (dilates) |
5557 | memory::validate_dims(*dilates, src_desc.get_ndims() - 2); |
5558 | |
5559 | dnnl_primitive_desc_t pd = nullptr; |
5560 | dnnl_status_t status |
5561 | = dnnl_deconvolution_forward_primitive_desc_create(&pd, |
5562 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
5563 | convert_to_c(aalgorithm), src_desc.get(), |
5564 | weights_desc.get(), optional_arg(bias_desc), |
5565 | dst_desc.get(), &strides[0], optional_arg(dilates), |
5566 | &padding_l[0], &padding_r[0], attr.get()); |
5567 | if (!allow_empty) |
5568 | error::wrap_c_api(status, |
5569 | "could not create a primitive descriptor for a " |
5570 | "deconvolution forward propagation primitive" ); |
5571 | reset(pd); |
5572 | } |
5573 | }; |
5574 | |
5575 | /// Default constructor. Produces an empty object. |
5576 | deconvolution_forward() = default; |
5577 | |
5578 | /// Constructs a deconvolution forward propagation primitive. |
5579 | /// @param pd Primitive descriptor for a deconvolution forward propagation |
5580 | /// primitive. |
5581 | deconvolution_forward(const primitive_desc &pd) : primitive(pd) {} |
5582 | |
5583 | /// Constructs a deconvolution forward propagation primitive from a cache |
5584 | /// blob. |
5585 | /// @param pd Primitive descriptor for a deconvolution forward propagation |
5586 | /// primitive. |
5587 | /// @param cache_blob Cache blob. |
5588 | deconvolution_forward( |
5589 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
5590 | : primitive(pd, cache_blob) {} |
5591 | }; |
5592 | |
5593 | /// Deconvolution backward propagation primitive. |
5594 | struct deconvolution_backward_data : public primitive { |
5595 | /// Primitive descriptor for a deconvolution backward propagation primitive. |
5596 | struct primitive_desc : public dnnl::primitive_desc { |
        /// Default constructor. Produces an empty primitive descriptor
        /// object; use one of the other constructors to describe an actual
        /// deconvolution backward propagation primitive.
        primitive_desc() = default;
5599 | |
5600 | /// Constructs a primitive descriptor for a deconvolution backward |
5601 | /// propagation primitive. |
5602 | /// |
5603 | /// @note |
5604 | /// All the memory descriptors may be initialized with the |
5605 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5606 | /// |
5607 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5608 | /// for spatial dimensions only and hence must have the same number of |
5609 | /// elements as there are spatial dimensions. The order of values is |
5610 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5611 | /// and 2D tensors), and width. |
5612 | /// |
5613 | /// @param aengine Engine to use. |
5614 | /// @param aalgorithm Deconvolution algorithm |
5615 | /// (#dnnl::algorithm::convolution_direct, |
5616 | /// #dnnl::algorithm::convolution_winograd). |
5617 | /// @param diff_src_desc Diff source memory descriptor. |
5618 | /// @param weights_desc Weights memory descriptor. |
5619 | /// @param diff_dst_desc Diff destination memory descriptor. |
5620 | /// @param strides Strides for each spatial dimension. |
5621 | /// @param padding_l Vector of padding values for low indices for each |
5622 | /// spatial dimension `([[front,] top,] left)`. |
5623 | /// @param padding_r Vector of padding values for high indices for |
5624 | /// each spatial dimension `([[back,] bottom,] right)`. |
5625 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5626 | /// forward propagation primitive. It is used as a hint for |
5627 | /// deciding which memory format to use. |
5628 | /// @param attr Primitive attributes to use. Attributes are optional |
5629 | /// and default to empty attributes. |
5630 | /// @param allow_empty A flag signifying whether construction is |
5631 | /// allowed to fail without throwing an exception. In this case an |
5632 | /// empty object will be produced. This flag is optional and |
5633 | /// defaults to false. |
5634 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5635 | const memory::desc &diff_src_desc, |
5636 | const memory::desc &weights_desc, |
5637 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5638 | const memory::dims &padding_l, const memory::dims &padding_r, |
5639 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5640 | const primitive_attr &attr = default_attr(), |
5641 | bool allow_empty = false) |
5642 | : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, |
5643 | diff_dst_desc, strides, nullptr, padding_l, padding_r, |
5644 | hint_fwd_pd, attr, allow_empty) {} |
5645 | |
5646 | /// Constructs a primitive descriptor for a deconvolution backward |
5647 | /// propagation primitive. |
5648 | /// |
5649 | /// @note |
5650 | /// All the memory descriptors may be initialized with the |
5651 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5652 | /// |
5653 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5654 | /// contain values for spatial dimensions only and hence must have the |
5655 | /// same number of elements as there are spatial dimensions. The order |
5656 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5657 | /// height (for 3D and 2D tensors), and width. |
5658 | /// |
5659 | /// @param aengine Engine to use. |
5660 | /// @param aalgorithm Deconvolution algorithm |
5661 | /// (#dnnl::algorithm::convolution_direct, |
5662 | /// #dnnl::algorithm::convolution_winograd). |
5663 | /// @param diff_src_desc Diff source memory descriptor. |
5664 | /// @param weights_desc Weights memory descriptor. |
5665 | /// @param diff_dst_desc Diff destination memory descriptor. |
5666 | /// @param strides Strides for each spatial dimension. |
5667 | /// @param dilates Dilations for each spatial dimension. A zero value |
5668 | /// means no dilation in the corresponding dimension. |
5669 | /// @param padding_l Vector of padding values for low indices for each |
5670 | /// spatial dimension `([[front,] top,] left)`. |
5671 | /// @param padding_r Vector of padding values for high indices for |
5672 | /// each spatial dimension `([[back,] bottom,] right)`. |
5673 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5674 | /// forward propagation primitive. It is used as a hint for |
5675 | /// deciding which memory format to use. |
5676 | /// @param attr Primitive attributes to use. Attributes are optional |
5677 | /// and default to empty attributes. |
5678 | /// @param allow_empty A flag signifying whether construction is |
5679 | /// allowed to fail without throwing an exception. In this case an |
5680 | /// empty object will be produced. This flag is optional and |
5681 | /// defaults to false. |
5682 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5683 | const memory::desc &diff_src_desc, |
5684 | const memory::desc &weights_desc, |
5685 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5686 | const memory::dims &dilates, const memory::dims &padding_l, |
5687 | const memory::dims &padding_r, |
5688 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5689 | const primitive_attr &attr = default_attr(), |
5690 | bool allow_empty = false) |
5691 | : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, |
5692 | diff_dst_desc, strides, &dilates, padding_l, padding_r, |
5693 | hint_fwd_pd, attr, allow_empty) {} |
5694 | |
5695 | /// Constructs a primitive descriptor for a deconvolution backward |
5696 | /// propagation primitive from a C API primitive descriptor that must |
5697 | /// have a matching kind. |
5698 | /// |
5699 | /// @param pd C API primitive descriptor for a deconvolution backward |
5700 | /// propagation primitive. |
5701 | primitive_desc(dnnl_primitive_desc_t pd) |
5702 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution, |
5703 | dnnl::prop_kind::backward_data) {} |
5704 | |
5705 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
5706 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
5707 | |
5708 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
5709 | memory::desc weights_desc() const { return base::weights_desc(0); } |
5710 | |
5711 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
5712 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
5713 | |
5714 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
5715 | algorithm get_algorithm() const { return base::get_algorithm(); } |
5716 | |
5717 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
5718 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
5719 | |
5720 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
5721 | memory::dims get_strides() const { return base::get_strides(); } |
5722 | |
5723 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
5724 | memory::dims get_dilations() const { return base::get_dilations(); } |
5725 | |
5726 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
5727 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
5728 | |
5729 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
5730 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
5731 | |
5732 | private: |
5733 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5734 | const memory::desc &diff_src_desc, |
5735 | const memory::desc &weights_desc, |
5736 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5737 | const memory::dims *dilates, const memory::dims &padding_l, |
5738 | const memory::dims &padding_r, |
5739 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5740 | const primitive_attr &attr, bool allow_empty) { |
5741 | |
5742 | memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); |
5743 | memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); |
5744 | memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); |
5745 | |
5746 | if (dilates) |
5747 | memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2); |
5748 | |
5749 | dnnl_primitive_desc_t pd = nullptr; |
5750 | dnnl_status_t status |
5751 | = dnnl_deconvolution_backward_data_primitive_desc_create( |
5752 | &pd, aengine.get(), convert_to_c(aalgorithm), |
5753 | diff_src_desc.get(), weights_desc.get(), |
5754 | diff_dst_desc.get(), &strides[0], |
5755 | optional_arg(dilates), &padding_l[0], &padding_r[0], |
5756 | hint_fwd_pd.get(), attr.get()); |
5757 | if (!allow_empty) |
5758 | error::wrap_c_api(status, |
5759 | "could not create a primitive descriptor for a " |
5760 | "deconvolution backward propagation primitive" ); |
5761 | reset(pd); |
5762 | } |
5763 | }; |
5764 | |
5765 | /// Default constructor. Produces an empty object. |
5766 | deconvolution_backward_data() = default; |
5767 | |
5768 | /// Constructs a deconvolution backward propagation primitive. |
5769 | /// @param pd Primitive descriptor for a deconvolution backward propagation |
5770 | /// primitive. |
5771 | deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {} |
5772 | |
5773 | /// Constructs a deconvolution backward propagation primitive from a cache |
5774 | /// blob. |
5775 | /// @param pd Primitive descriptor for a deconvolution backward propagation |
5776 | /// primitive. |
5777 | /// @param cache_blob Cache blob. |
5778 | deconvolution_backward_data( |
5779 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
5780 | : primitive(pd, cache_blob) {} |
5781 | }; |
5782 | |
5783 | /// Deconvolution weights gradient primitive. |
5784 | struct deconvolution_backward_weights : public primitive { |
5785 | /// Primitive descriptor for a deconvolution weights gradient primitive. |
5786 | struct primitive_desc : public dnnl::primitive_desc { |
5787 | /// Default constructor. Produces an empty object. |
5788 | primitive_desc() = default; |
5789 | |
5790 | /// Constructs a primitive descriptor for a deconvolution weights |
5791 | /// gradient primitive with bias. |
5792 | /// |
5793 | /// @note |
5794 | /// All the memory descriptors may be initialized with the |
5795 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5796 | /// |
5797 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5798 | /// for spatial dimensions only and hence must have the same number of |
5799 | /// elements as there are spatial dimensions. The order of values is |
5800 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5801 | /// and 2D tensors), and width. |
5802 | /// |
5803 | /// @param aengine Engine to use. |
5804 | /// @param aalgorithm Deconvolution algorithm. Possible values are |
5805 | /// #dnnl::algorithm::deconvolution_direct, and |
5806 | /// #dnnl::algorithm::deconvolution_winograd. |
5807 | /// @param src_desc Source memory descriptor. |
5808 | /// @param diff_weights_desc Diff weights memory descriptor. |
5809 | /// @param diff_bias_desc Diff bias memory descriptor. Passing zero |
5810 | /// memory descriptor disables the bias term. |
5811 | /// @param diff_dst_desc Diff destination memory descriptor. |
5812 | /// @param strides Strides for each spatial dimension. |
5813 | /// @param padding_l Vector of padding values for low indices for each |
5814 | /// spatial dimension `([[front,] top,] left)`. |
5815 | /// @param padding_r Vector of padding values for high indices for |
5816 | /// each spatial dimension `([[back,] bottom,] right)`. |
5817 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5818 | /// forward propagation primitive. It is used as a hint for |
5819 | /// deciding which memory format to use. |
5820 | /// @param attr Primitive attributes to use. Attributes are optional |
5821 | /// and default to empty attributes. |
5822 | /// @param allow_empty A flag signifying whether construction is |
5823 | /// allowed to fail without throwing an exception. In this case an |
5824 | /// empty object will be produced. This flag is optional and |
5825 | /// defaults to false. |
5826 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5827 | const memory::desc &src_desc, |
5828 | const memory::desc &diff_weights_desc, |
5829 | const memory::desc &diff_bias_desc, |
5830 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5831 | const memory::dims &padding_l, const memory::dims &padding_r, |
5832 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5833 | const primitive_attr &attr = default_attr(), |
5834 | bool allow_empty = false) |
5835 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5836 | &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l, |
5837 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5838 | |
5839 | /// Constructs a primitive descriptor for a deconvolution weights |
5840 | /// gradient primitive without bias. |
5841 | /// |
5842 | /// @note |
5843 | /// All the memory descriptors may be initialized with the |
5844 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5845 | /// |
5846 | /// Arrays @p strides, @p padding_l, and @p padding_r contain values |
5847 | /// for spatial dimensions only and hence must have the same number of |
5848 | /// elements as there are spatial dimensions. The order of values is |
5849 | /// the same as in the tensor: depth (for 3D tensors), height (for 3D |
5850 | /// and 2D tensors), and width. |
5851 | /// |
5852 | /// @param aengine Engine to use. |
5853 | /// @param aalgorithm Deconvolution algorithm. Possible values are |
5854 | /// #dnnl::algorithm::deconvolution_direct, and |
5855 | /// #dnnl::algorithm::deconvolution_winograd. |
5856 | /// @param src_desc Source memory descriptor. |
5857 | /// @param diff_weights_desc Diff weights memory descriptor. |
5858 | /// @param diff_dst_desc Diff destination memory descriptor. |
5859 | /// @param strides Strides for each spatial dimension. |
5860 | /// @param padding_l Vector of padding values for low indices for each |
5861 | /// spatial dimension `([[front,] top,] left)`. |
5862 | /// @param padding_r Vector of padding values for high indices for |
5863 | /// each spatial dimension `([[back,] bottom,] right)`. |
5864 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5865 | /// forward propagation primitive. It is used as a hint for |
5866 | /// deciding which memory format to use. |
5867 | /// @param attr Primitive attributes to use. Attributes are optional |
5868 | /// and default to empty attributes. |
5869 | /// @param allow_empty A flag signifying whether construction is |
5870 | /// allowed to fail without throwing an exception. In this case an |
5871 | /// empty object will be produced. This flag is optional and |
5872 | /// defaults to false. |
5873 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5874 | const memory::desc &src_desc, |
5875 | const memory::desc &diff_weights_desc, |
5876 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5877 | const memory::dims &padding_l, const memory::dims &padding_r, |
5878 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5879 | const primitive_attr &attr = default_attr(), |
5880 | bool allow_empty = false) |
5881 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5882 | nullptr, diff_dst_desc, strides, nullptr, padding_l, |
5883 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5884 | |
5885 | /// Constructs a primitive descriptor for a deconvolution weights |
5886 | /// gradient primitive with bias. |
5887 | /// |
5888 | /// @note |
5889 | /// All the memory descriptors may be initialized with the |
5890 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5891 | /// |
5892 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5893 | /// contain values for spatial dimensions only and hence must have the |
5894 | /// same number of elements as there are spatial dimensions. The order |
5895 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5896 | /// height (for 3D and 2D tensors), and width. |
5897 | /// |
5898 | /// @param aengine Engine to use. |
5899 | /// @param aalgorithm Deconvolution algorithm. Possible values are |
5900 | /// #dnnl::algorithm::deconvolution_direct, and |
5901 | /// #dnnl::algorithm::deconvolution_winograd. |
5902 | /// @param src_desc Source memory descriptor. |
5903 | /// @param diff_weights_desc Diff weights memory descriptor. |
5904 | /// @param diff_bias_desc Diff bias memory descriptor. Passing zero |
5905 | /// memory descriptor disables the bias term. |
5906 | /// @param diff_dst_desc Diff destination memory descriptor. |
5907 | /// @param strides Strides for each spatial dimension. |
5908 | /// @param dilates Dilations for each spatial dimension. A zero value |
5909 | /// means no dilation in the corresponding dimension. |
5910 | /// @param padding_l Vector of padding values for low indices for each |
5911 | /// spatial dimension `([[front,] top,] left)`. |
5912 | /// @param padding_r Vector of padding values for high indices for |
5913 | /// each spatial dimension `([[back,] bottom,] right)`. |
5914 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5915 | /// forward propagation primitive. It is used as a hint for |
5916 | /// deciding which memory format to use. |
5917 | /// @param attr Primitive attributes to use. Attributes are optional |
5918 | /// and default to empty attributes. |
5919 | /// @param allow_empty A flag signifying whether construction is |
5920 | /// allowed to fail without throwing an exception. In this case an |
5921 | /// empty object will be produced. This flag is optional and |
5922 | /// defaults to false. |
5923 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5924 | const memory::desc &src_desc, |
5925 | const memory::desc &diff_weights_desc, |
5926 | const memory::desc &diff_bias_desc, |
5927 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5928 | const memory::dims &dilates, const memory::dims &padding_l, |
5929 | const memory::dims &padding_r, |
5930 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5931 | const primitive_attr &attr = default_attr(), |
5932 | bool allow_empty = false) |
5933 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5934 | &diff_bias_desc, diff_dst_desc, strides, &dilates, |
5935 | padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {} |
5936 | |
5937 | /// Constructs a primitive descriptor for a deconvolution weights |
5938 | /// gradient primitive without bias. |
5939 | /// |
5940 | /// @note |
5941 | /// All the memory descriptors may be initialized with the |
5942 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
5943 | /// |
5944 | /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r |
5945 | /// contain values for spatial dimensions only and hence must have the |
5946 | /// same number of elements as there are spatial dimensions. The order |
5947 | /// of values is the same as in the tensor: depth (for 3D tensors), |
5948 | /// height (for 3D and 2D tensors), and width. |
5949 | /// |
5950 | /// @param aengine Engine to use. |
5951 | /// @param aalgorithm Deconvolution algorithm. Possible values are |
5952 | /// #dnnl::algorithm::deconvolution_direct, and |
5953 | /// #dnnl::algorithm::deconvolution_winograd. |
5954 | /// @param src_desc Source memory descriptor. |
5955 | /// @param diff_weights_desc Diff weights memory descriptor. |
5956 | /// @param diff_dst_desc Diff destination memory descriptor. |
5957 | /// @param strides Strides for each spatial dimension. |
5958 | /// @param dilates Dilations for each spatial dimension. A zero value |
5959 | /// means no dilation in the corresponding dimension. |
5960 | /// @param padding_l Vector of padding values for low indices for each |
5961 | /// spatial dimension `([[front,] top,] left)`. |
5962 | /// @param padding_r Vector of padding values for high indices for |
5963 | /// each spatial dimension `([[back,] bottom,] right)`. |
5964 | /// @param hint_fwd_pd Primitive descriptor for a deconvolution |
5965 | /// forward propagation primitive. It is used as a hint for |
5966 | /// deciding which memory format to use. |
5967 | /// @param attr Primitive attributes to use. Attributes are optional |
5968 | /// and default to empty attributes. |
5969 | /// @param allow_empty A flag signifying whether construction is |
5970 | /// allowed to fail without throwing an exception. In this case an |
5971 | /// empty object will be produced. This flag is optional and |
5972 | /// defaults to false. |
5973 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
5974 | const memory::desc &src_desc, |
5975 | const memory::desc &diff_weights_desc, |
5976 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
5977 | const memory::dims &dilates, const memory::dims &padding_l, |
5978 | const memory::dims &padding_r, |
5979 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
5980 | const primitive_attr &attr = default_attr(), |
5981 | bool allow_empty = false) |
5982 | : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, |
5983 | nullptr, diff_dst_desc, strides, &dilates, padding_l, |
5984 | padding_r, hint_fwd_pd, attr, allow_empty) {} |
5985 | |
5986 | /// Constructs a primitive descriptor for a deconvolution weights |
5987 | /// gradient primitive from a C API primitive descriptor that must |
5988 | /// have a matching kind. |
5989 | /// |
5990 | /// @param pd C API primitive descriptor for a deconvolution weights |
5991 | /// gradient primitive. |
5992 | primitive_desc(dnnl_primitive_desc_t pd) |
5993 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution, |
5994 | dnnl::prop_kind::backward_weights) {} |
5995 | |
5996 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
5997 | memory::desc src_desc() const { return base::src_desc(0); } |
5998 | |
5999 | /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const |
6000 | memory::desc diff_weights_desc() const { |
6001 | return base::diff_weights_desc(0); |
6002 | } |
6003 | |
6004 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
6005 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
6006 | |
6007 | /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const |
6008 | memory::desc diff_bias_desc() const { |
6009 | return base::diff_weights_desc(1); |
6010 | } |
6011 | |
6012 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6013 | algorithm get_algorithm() const { return base::get_algorithm(); } |
6014 | |
6015 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6016 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6017 | |
6018 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
6019 | memory::dims get_strides() const { return base::get_strides(); } |
6020 | |
6021 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
6022 | memory::dims get_dilations() const { return base::get_dilations(); } |
6023 | |
6024 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
6025 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
6026 | |
6027 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
6028 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
6029 | |
6030 | private: |
6031 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6032 | const memory::desc &src_desc, |
6033 | const memory::desc &diff_weights_desc, |
6034 | const memory::desc *diff_bias_desc, |
6035 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
6036 | const memory::dims *dilates, const memory::dims &padding_l, |
6037 | const memory::dims &padding_r, |
6038 | const deconvolution_forward::primitive_desc &hint_fwd_pd, |
6039 | const primitive_attr &attr, bool allow_empty) { |
6040 | |
6041 | memory::validate_dims(strides, src_desc.get_ndims() - 2); |
6042 | memory::validate_dims(padding_l, src_desc.get_ndims() - 2); |
6043 | memory::validate_dims(padding_r, src_desc.get_ndims() - 2); |
6044 | |
6045 | if (dilates) |
6046 | memory::validate_dims(*dilates, src_desc.get_ndims() - 2); |
6047 | |
6048 | dnnl_primitive_desc_t pd = nullptr; |
6049 | dnnl_status_t status |
6050 | = dnnl_deconvolution_backward_weights_primitive_desc_create( |
6051 | &pd, aengine.get(), convert_to_c(aalgorithm), |
6052 | src_desc.get(), diff_weights_desc.get(), |
6053 | optional_arg(diff_bias_desc), diff_dst_desc.get(), |
6054 | &strides[0], optional_arg(dilates), &padding_l[0], |
6055 | &padding_r[0], hint_fwd_pd.get(), attr.get()); |
6056 | if (!allow_empty) |
6057 | error::wrap_c_api(status, |
6058 | "could not create a primitive descriptor for a " |
6059 | "deconvolution weights update primitive" ); |
6060 | reset(pd); |
6061 | } |
6062 | }; |
6063 | |
6064 | /// Default constructor. Produces an empty object. |
6065 | deconvolution_backward_weights() = default; |
6066 | |
6067 | /// Constructs a deconvolution weights gradient primitive. |
6068 | /// @param pd Primitive descriptor for a deconvolution weights gradient |
6069 | /// primitive. |
6070 | deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {} |
6071 | |
6072 | /// Constructs a deconvolution weights gradient primitive from a cache |
6073 | /// blob. |
6074 | /// @param pd Primitive descriptor for a deconvolution weights gradient |
6075 | /// primitive. |
6076 | /// @param cache_blob Cache blob. |
6077 | deconvolution_backward_weights( |
6078 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6079 | : primitive(pd, cache_blob) {} |
6080 | }; |
6081 | |
6082 | /// @} dnnl_api_deconvolution |
6083 | |
6084 | /// @addtogroup dnnl_api_lrn LRN |
6085 | /// |
6086 | /// A primitive to perform local response normalization (LRN) across or within |
6087 | /// channels. |
6088 | /// |
6089 | /// @sa @ref dev_guide_lrn in developer guide |
6090 | /// |
6091 | /// @{ |
6092 | |
6093 | /// Local response normalization (LRN) forward propagation primitive. |
6094 | struct lrn_forward : public primitive { |
6095 | /// Primitive descriptor for an LRN forward propagation primitive. |
6096 | struct primitive_desc : public dnnl::primitive_desc { |
6097 | /// Default constructor. Produces an empty object. |
6098 | primitive_desc() = default; |
6099 | |
6100 | /// Constructs a primitive descriptor for an LRN forward propagation |
6101 | /// primitive. |
6102 | /// |
6103 | /// @param aengine Engine to use. |
6104 | /// @param aprop_kind Propagation kind. Possible values are |
6105 | /// #dnnl::prop_kind::forward_training, and |
6106 | /// #dnnl::prop_kind::forward_inference. |
6107 | /// @param aalgorithm LRN algorithm kind: either |
6108 | /// #dnnl::algorithm::lrn_across_channels, or |
6109 | /// #dnnl::algorithm::lrn_within_channel. |
6110 | /// @param src_desc Source memory descriptor. |
6111 | /// @param dst_desc Destination memory descriptor. |
6112 | /// @param local_size Regularization local size. |
6113 | /// @param alpha The alpha regularization parameter. |
6114 | /// @param beta The beta regularization parameter. |
6115 | /// @param k The k regularization parameter. |
6116 | /// @param attr Primitive attributes to use. Attributes are optional |
6117 | /// and default to empty attributes. |
6118 | /// @param allow_empty A flag signifying whether construction is |
6119 | /// allowed to fail without throwing an exception. In this case an |
6120 | /// empty object will be produced. This flag is optional and |
6121 | /// defaults to false. |
6122 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6123 | algorithm aalgorithm, const memory::desc &src_desc, |
6124 | const memory::desc &dst_desc, memory::dim local_size, |
6125 | float alpha, float beta, float k, |
6126 | const primitive_attr &attr = default_attr(), |
6127 | bool allow_empty = false) { |
6128 | |
6129 | dnnl_primitive_desc_t pd = nullptr; |
6130 | dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd, |
6131 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
6132 | convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(), |
6133 | local_size, alpha, beta, k, attr.get()); |
6134 | |
6135 | if (!allow_empty) |
6136 | error::wrap_c_api(status, |
6137 | "could not create a primitive descriptor for a lrn " |
6138 | "forward propagation primitive" ); |
6139 | reset(pd); |
6140 | } |
6141 | |
6142 | /// Constructs a primitive descriptor for an LRN forward propagation |
6143 | /// primitive from a C API primitive descriptor that must have a |
6144 | /// matching kind. |
6145 | /// |
6146 | /// @param pd C API primitive descriptor for an LRN forward |
6147 | /// propagation primitive. |
6148 | primitive_desc(dnnl_primitive_desc_t pd) |
6149 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn, |
6150 | dnnl::prop_kind::forward_training, |
6151 | dnnl::prop_kind::forward_inference) {} |
6152 | |
6153 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6154 | memory::desc src_desc() const { return base::src_desc(0); } |
6155 | |
6156 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6157 | memory::desc dst_desc() const { return base::dst_desc(0); } |
6158 | |
6159 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
6160 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
6161 | |
6162 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6163 | algorithm get_algorithm() const { return base::get_algorithm(); } |
6164 | |
6165 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6166 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6167 | |
6168 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
6169 | float get_alpha() const { return base::get_alpha(); } |
6170 | |
6171 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
6172 | float get_beta() const { return base::get_beta(); } |
6173 | |
6174 | /// @copydoc dnnl::primitive_desc_base::get_local_size()const |
6175 | memory::dim get_local_size() const { return base::get_local_size(); } |
6176 | |
6177 | /// @copydoc dnnl::primitive_desc_base::get_k()const |
6178 | float get_k() const { return base::get_k(); } |
6179 | }; |
6180 | |
6181 | /// Default constructor. Produces an empty object. |
6182 | lrn_forward() = default; |
6183 | |
6184 | /// Constructs an LRN forward propagation primitive. |
6185 | /// @param pd Primitive descriptor for an LRN forward propagation |
6186 | /// primitive. |
6187 | lrn_forward(const primitive_desc &pd) : primitive(pd) {} |
6188 | |
6189 | /// Constructs an LRN forward propagation primitive from a cache blob. |
6190 | /// @param pd Primitive descriptor for an LRN forward propagation |
6191 | /// primitive. |
6192 | /// @param cache_blob Cache blob. |
6193 | lrn_forward( |
6194 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6195 | : primitive(pd, cache_blob) {} |
6196 | }; |
6197 | |
6198 | /// Local response normalization (LRN) backward propagation primitive. |
6199 | struct lrn_backward : public primitive { |
6200 | /// Primitive descriptor for an LRN backward propagation primitive. |
6201 | struct primitive_desc : public dnnl::primitive_desc { |
6202 | /// Default constructor. Produces an empty object. |
6203 | primitive_desc() = default; |
6204 | |
6205 | /// Constructs a primitive descriptor for an LRN backward propagation |
6206 | /// primitive. |
6207 | /// |
6208 | /// @param aengine Engine to use. |
6209 | /// @param aalgorithm LRN algorithm kind: either |
6210 | /// #dnnl::algorithm::lrn_across_channels, or |
6211 | /// #dnnl::algorithm::lrn_within_channel. |
6212 | /// @param diff_src_desc Diff source memory descriptor. |
6213 | /// @param diff_dst_desc Diff destination memory descriptor. |
6214 | /// @param src_desc Source memory descriptor. |
6215 | /// @param local_size Regularization local size. |
6216 | /// @param alpha The alpha regularization parameter. |
6217 | /// @param beta The beta regularization parameter. |
6218 | /// @param k The k regularization parameter. |
6219 | /// @param hint_fwd_pd Primitive descriptor for an LRN forward |
6220 | /// propagation primitive. It is used as a hint for deciding which |
6221 | /// memory format to use. |
6222 | /// @param attr Primitive attributes to use. Attributes are optional |
6223 | /// and default to empty attributes. |
6224 | /// @param allow_empty A flag signifying whether construction is |
6225 | /// allowed to fail without throwing an exception. In this case an |
6226 | /// empty object will be produced. This flag is optional and |
6227 | /// defaults to false. |
6228 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6229 | const memory::desc &diff_src_desc, |
6230 | const memory::desc &diff_dst_desc, const memory::desc &src_desc, |
6231 | memory::dim local_size, float alpha, float beta, float k, |
6232 | const lrn_forward::primitive_desc &hint_fwd_pd, |
6233 | const primitive_attr &attr = default_attr(), |
6234 | bool allow_empty = false) { |
6235 | |
6236 | dnnl_primitive_desc_t pd = nullptr; |
6237 | dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd, |
6238 | aengine.get(), convert_to_c(aalgorithm), |
6239 | diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(), |
6240 | local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get()); |
6241 | |
6242 | if (!allow_empty) |
6243 | error::wrap_c_api(status, |
6244 | "could not create a primitive descriptor for a lrn " |
6245 | "backward propagation primitive" ); |
6246 | reset(pd); |
6247 | } |
6248 | |
6249 | /// Constructs a primitive descriptor for an LRN backward propagation |
6250 | /// primitive from a C API primitive descriptor that must have a |
6251 | /// matching kind. |
6252 | /// |
6253 | /// @param pd C API primitive descriptor for an LRN backward |
6254 | /// propagation primitive. |
6255 | primitive_desc(dnnl_primitive_desc_t pd) |
6256 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn, |
6257 | dnnl::prop_kind::backward_data) {} |
6258 | |
6259 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6260 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
6261 | |
6262 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
6263 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
6264 | |
6265 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
6266 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
6267 | |
6268 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6269 | algorithm get_algorithm() const { return base::get_algorithm(); } |
6270 | |
6271 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6272 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6273 | |
6274 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
6275 | float get_alpha() const { return base::get_alpha(); } |
6276 | |
6277 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
6278 | float get_beta() const { return base::get_beta(); } |
6279 | |
6280 | /// @copydoc dnnl::primitive_desc_base::get_local_size()const |
6281 | memory::dim get_local_size() const { return base::get_local_size(); } |
6282 | |
6283 | /// @copydoc dnnl::primitive_desc_base::get_k()const |
6284 | float get_k() const { return base::get_k(); } |
6285 | }; |
6286 | |
6287 | /// Default constructor. Produces an empty object. |
6288 | lrn_backward() = default; |
6289 | |
6290 | /// Constructs an LRN backward propagation primitive. |
6291 | /// @param pd Primitive descriptor for an LRN backward propagation |
6292 | /// primitive. |
6293 | lrn_backward(const primitive_desc &pd) : primitive(pd) {} |
6294 | |
6295 | /// Constructs an LRN backward propagation primitive from a cache blob. |
6296 | /// @param pd Primitive descriptor for an LRN backward propagation |
6297 | /// primitive. |
6298 | /// @param cache_blob Cache blob. |
6299 | lrn_backward( |
6300 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6301 | : primitive(pd, cache_blob) {} |
6302 | }; |
6303 | |
6304 | /// @} dnnl_api_lrn |
6305 | |
6306 | /// @addtogroup dnnl_api_eltwise Eltwise |
6307 | /// |
/// A primitive to perform elementwise operations such as the
/// rectified linear unit (ReLU).
6310 | /// |
6311 | /// Both forward and backward propagation primitives support in-place |
6312 | /// operation; that is, src and dst can refer to the same memory for forward |
6313 | /// propagation, and diff_dst and diff_src can refer to the same memory for |
6314 | /// backward propagation. |
6315 | /// |
6316 | /// @warning |
6317 | /// Because the original source data is required for backward propagation, |
6318 | /// in-place forward propagation is not generally supported in the |
6319 | /// training mode. However, for algorithms supporting destination as input |
6320 | /// memory, dst can be used for the backward propagation, which makes it |
6321 | /// possible to get performance benefit even in the training mode. |
6322 | /// |
6323 | /// @sa @ref dev_guide_eltwise in developer guide |
6324 | /// |
6325 | /// @{ |
6326 | |
6327 | /// Elementwise unary operation forward propagation primitive. |
6328 | struct eltwise_forward : public primitive { |
6329 | /// Primitive descriptor for an elementwise forward propagation primitive. |
6330 | struct primitive_desc : public dnnl::primitive_desc { |
6331 | /// Default constructor. Produces an empty object. |
6332 | primitive_desc() = default; |
6333 | |
6334 | /// Constructs a primitive descriptor for an elementwise forward |
6335 | /// propagation primitive. |
6336 | /// |
6337 | /// @param aengine Engine to use. |
6338 | /// @param aprop_kind Propagation kind. Possible values are |
6339 | /// #dnnl::prop_kind::forward_training, and |
6340 | /// #dnnl::prop_kind::forward_inference. |
6341 | /// @param aalgorithm Elementwise algorithm kind. |
6342 | /// @param src_desc Source memory descriptor. |
6343 | /// @param dst_desc Destination memory descriptor. |
6344 | /// @param attr Primitive attributes to use. Attributes are optional |
6345 | /// and default to empty attributes. |
6346 | /// @param allow_empty A flag signifying whether construction is |
6347 | /// allowed to fail without throwing an exception. In this case an |
6348 | /// empty object will be produced. This flag is optional and |
6349 | /// defaults to false. |
6350 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6351 | algorithm aalgorithm, const memory::desc &src_desc, |
6352 | const memory::desc &dst_desc, |
6353 | const primitive_attr &attr = default_attr(), |
6354 | bool allow_empty = false) |
6355 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
6356 | dst_desc, nullptr, nullptr, attr, allow_empty) {} |
6357 | |
6358 | /// Constructs a primitive descriptor for an elementwise forward |
6359 | /// propagation primitive with an alpha parameter. |
6360 | /// |
6361 | /// @param aengine Engine to use. |
6362 | /// @param aprop_kind Propagation kind. Possible values are |
6363 | /// #dnnl::prop_kind::forward_training, and |
6364 | /// #dnnl::prop_kind::forward_inference. |
6365 | /// @param aalgorithm Elementwise algorithm kind. |
6366 | /// @param src_desc Source memory descriptor. |
6367 | /// @param dst_desc Destination memory descriptor. |
6368 | /// @param alpha The alpha parameter for the elementwise operation. |
6369 | /// Specific meaning depends on the algorithm. |
6370 | /// @param attr Primitive attributes to use. Attributes are optional |
6371 | /// and default to empty attributes. |
6372 | /// @param allow_empty A flag signifying whether construction is |
6373 | /// allowed to fail without throwing an exception. In this case an |
6374 | /// empty object will be produced. This flag is optional and |
6375 | /// defaults to false. |
6376 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6377 | algorithm aalgorithm, const memory::desc &src_desc, |
6378 | const memory::desc &dst_desc, float alpha, |
6379 | const primitive_attr &attr = default_attr(), |
6380 | bool allow_empty = false) |
6381 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
6382 | dst_desc, &alpha, nullptr, attr, allow_empty) {} |
6383 | |
6384 | /// Constructs a primitive descriptor for an elementwise forward |
6385 | /// propagation primitive with an alpha and beta parameters. |
6386 | /// |
6387 | /// @param aengine Engine to use. |
6388 | /// @param aprop_kind Propagation kind. Possible values are |
6389 | /// #dnnl::prop_kind::forward_training, and |
6390 | /// #dnnl::prop_kind::forward_inference. |
6391 | /// @param aalgorithm Elementwise algorithm kind. |
6392 | /// @param src_desc Source memory descriptor. |
6393 | /// @param dst_desc Destination memory descriptor. |
6394 | /// @param alpha The alpha parameter for the elementwise operation. |
6395 | /// Specific meaning depends on the algorithm. |
6396 | /// @param beta The beta parameter for the elementwise operation. |
6397 | /// Specific meaning depends on the algorithm. |
6398 | /// @param attr Primitive attributes to use. Attributes are optional |
6399 | /// and default to empty attributes. |
6400 | /// @param allow_empty A flag signifying whether construction is |
6401 | /// allowed to fail without throwing an exception. In this case an |
6402 | /// empty object will be produced. This flag is optional and |
6403 | /// defaults to false. |
6404 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6405 | algorithm aalgorithm, const memory::desc &src_desc, |
6406 | const memory::desc &dst_desc, float alpha, float beta, |
6407 | const primitive_attr &attr = default_attr(), |
6408 | bool allow_empty = false) |
6409 | : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, |
6410 | dst_desc, &alpha, &beta, attr, allow_empty) {} |
6411 | |
6412 | /// Constructs a primitive descriptor for an eltwise forward |
6413 | /// propagation primitive from a C API primitive descriptor that must |
6414 | /// have a matching kind. |
6415 | /// |
6416 | /// @param pd C API primitive descriptor for an eltwise forward |
6417 | /// propagation primitive. |
6418 | primitive_desc(dnnl_primitive_desc_t pd) |
6419 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise, |
6420 | dnnl::prop_kind::forward_training, |
6421 | dnnl::prop_kind::forward_inference) {} |
6422 | |
6423 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6424 | memory::desc src_desc() const { return base::src_desc(0); } |
6425 | |
6426 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6427 | memory::desc dst_desc() const { return base::dst_desc(0); } |
6428 | |
6429 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6430 | dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } |
6431 | |
6432 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6433 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6434 | |
6435 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
6436 | float get_alpha() const { return base::get_alpha(); } |
6437 | |
6438 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
6439 | float get_beta() const { return base::get_beta(); } |
6440 | |
6441 | private: |
6442 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6443 | algorithm aalgorithm, const memory::desc &src_desc, |
6444 | const memory::desc &dst_desc, const float *alpha, |
6445 | const float *beta, const primitive_attr &attr, |
6446 | bool allow_empty) { |
6447 | |
6448 | dnnl_primitive_desc_t pd = nullptr; |
6449 | dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create( |
6450 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
6451 | dnnl::convert_to_c(aalgorithm), src_desc.get(), |
6452 | dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f, |
6453 | attr.get()); |
6454 | |
6455 | if (!allow_empty) |
6456 | error::wrap_c_api(status, |
6457 | "could not create a primitive descriptor for an " |
6458 | "eltwise forward propagation primitive" ); |
6459 | reset(pd); |
6460 | } |
6461 | }; |
6462 | |
6463 | /// Default constructor. Produces an empty object. |
6464 | eltwise_forward() = default; |
6465 | |
6466 | /// Constructs an eltwise forward propagation primitive. |
6467 | /// @param pd Primitive descriptor for an eltwise forward propagation |
6468 | /// primitive. |
6469 | eltwise_forward(const primitive_desc &pd) : primitive(pd) {} |
6470 | |
6471 | /// Constructs an eltwise forward propagation primitive from a cache blob. |
6472 | /// @param pd Primitive descriptor for an eltwise forward propagation |
6473 | /// primitive. |
6474 | /// @param cache_blob Cache blob. |
6475 | eltwise_forward( |
6476 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6477 | : primitive(pd, cache_blob) {} |
6478 | }; |
6479 | |
6480 | /// Elementwise unary operation backward propagation primitive. |
6481 | struct eltwise_backward : public primitive { |
6482 | /// Primitive descriptor for eltwise backward propagation. |
6483 | struct primitive_desc : public dnnl::primitive_desc { |
6484 | /// Default constructor. Produces an empty object. |
6485 | primitive_desc() = default; |
6486 | |
6487 | /// Constructs a primitive descriptor for an elementwise backward |
6488 | /// propagation primitive with an alpha parameter. |
6489 | /// |
6490 | /// @param aengine Engine to use. |
6491 | /// @param aalgorithm Elementwise algorithm kind. |
6492 | /// @param diff_src_desc Diff source memory descriptor. |
6493 | /// @param diff_dst_desc Diff destination memory descriptor. |
6494 | /// @param data_desc Destination memory descriptor if one of the |
6495 | /// "use_dst_for_bwd" algorithms are used (such as |
6496 | /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor |
6497 | /// otherwise. |
6498 | /// @param hint_fwd_pd Primitive descriptor for an elementwise |
6499 | /// forward propagation primitive. It is used as a hint for |
6500 | /// deciding which memory format to use. |
6501 | /// @param attr Primitive attributes to use. Attributes are optional |
6502 | /// and default to empty attributes. |
6503 | /// @param allow_empty A flag signifying whether construction is |
6504 | /// allowed to fail without throwing an exception. In this case an |
6505 | /// empty object will be produced. This flag is optional and |
6506 | /// defaults to false. |
6507 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6508 | const memory::desc &diff_src_desc, |
6509 | const memory::desc &diff_dst_desc, |
6510 | const memory::desc &data_desc, |
6511 | const eltwise_forward::primitive_desc &hint_fwd_pd, |
6512 | const primitive_attr &attr = default_attr(), |
6513 | bool allow_empty = false) |
6514 | : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, |
6515 | data_desc, nullptr, nullptr, hint_fwd_pd, attr, |
6516 | allow_empty) {} |
6517 | |
6518 | /// Constructs a primitive descriptor for an elementwise backward |
6519 | /// propagation primitive with an alpha parameter. |
6520 | /// |
6521 | /// @param aengine Engine to use. |
6522 | /// @param aalgorithm Elementwise algorithm kind. |
6523 | /// @param diff_src_desc Diff source memory descriptor. |
6524 | /// @param diff_dst_desc Diff destination memory descriptor. |
6525 | /// @param data_desc Destination memory descriptor if one of the |
6526 | /// "use_dst_for_bwd" algorithms are used (such as |
6527 | /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor |
6528 | /// otherwise. |
6529 | /// @param alpha The alpha parameter for the elementwise operation. |
6530 | /// Specific meaning depends on the algorithm. |
6531 | /// @param hint_fwd_pd Primitive descriptor for an elementwise |
6532 | /// forward propagation primitive. It is used as a hint for |
6533 | /// deciding which memory format to use. |
6534 | /// @param attr Primitive attributes to use. Attributes are optional |
6535 | /// and default to empty attributes. |
6536 | /// @param allow_empty A flag signifying whether construction is |
6537 | /// allowed to fail without throwing an exception. In this case an |
6538 | /// empty object will be produced. This flag is optional and |
6539 | /// defaults to false. |
6540 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6541 | const memory::desc &diff_src_desc, |
6542 | const memory::desc &diff_dst_desc, |
6543 | const memory::desc &data_desc, float alpha, |
6544 | const eltwise_forward::primitive_desc &hint_fwd_pd, |
6545 | const primitive_attr &attr = default_attr(), |
6546 | bool allow_empty = false) |
6547 | : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, |
6548 | data_desc, &alpha, nullptr, hint_fwd_pd, attr, |
6549 | allow_empty) {} |
6550 | |
6551 | /// Constructs a primitive descriptor for an elementwise backward |
6552 | /// propagation primitive with an alpha and beta parameters. |
6553 | /// |
6554 | /// @param aengine Engine to use. |
6555 | /// @param aalgorithm Elementwise algorithm kind. |
6556 | /// @param diff_src_desc Diff source memory descriptor. |
6557 | /// @param diff_dst_desc Diff destination memory descriptor. |
6558 | /// @param data_desc Destination memory descriptor if one of the |
6559 | /// "use_dst_for_bwd" algorithms are used (such as |
6560 | /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor |
6561 | /// otherwise. |
6562 | /// @param alpha The alpha parameter for the elementwise operation. |
6563 | /// Specific meaning depends on the algorithm. |
6564 | /// @param beta The beta parameter for the elementwise operation. |
6565 | /// Specific meaning depends on the algorithm. |
6566 | /// @param hint_fwd_pd Primitive descriptor for an elementwise |
6567 | /// forward propagation primitive. It is used as a hint for |
6568 | /// deciding which memory format to use. |
6569 | /// @param attr Primitive attributes to use. Attributes are optional |
6570 | /// and default to empty attributes. |
6571 | /// @param allow_empty A flag signifying whether construction is |
6572 | /// allowed to fail without throwing an exception. In this case an |
6573 | /// empty object will be produced. This flag is optional and |
6574 | /// defaults to false. |
6575 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6576 | const memory::desc &diff_src_desc, |
6577 | const memory::desc &diff_dst_desc, |
6578 | const memory::desc &data_desc, float alpha, float beta, |
6579 | const eltwise_forward::primitive_desc &hint_fwd_pd, |
6580 | const primitive_attr &attr = default_attr(), |
6581 | bool allow_empty = false) |
6582 | : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, |
6583 | data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {} |
6584 | |
6585 | /// Constructs a primitive descriptor for an eltwise backward |
6586 | /// propagation primitive from a C API primitive descriptor that must |
6587 | /// have a matching kind. |
6588 | /// |
6589 | /// @param pd C API primitive descriptor for an eltwise backward |
6590 | /// propagation primitive. |
6591 | primitive_desc(dnnl_primitive_desc_t pd) |
6592 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise, |
6593 | dnnl::prop_kind::backward_data) {} |
6594 | |
6595 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6596 | memory::desc src_desc() const { return base::src_desc(0); } |
6597 | |
6598 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
6599 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
6600 | |
6601 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
6602 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
6603 | |
6604 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6605 | dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } |
6606 | |
6607 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6608 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6609 | |
6610 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
6611 | float get_alpha() const { return base::get_alpha(); } |
6612 | |
6613 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
6614 | float get_beta() const { return base::get_beta(); } |
6615 | |
6616 | private: |
6617 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6618 | const memory::desc &diff_src_desc, |
6619 | const memory::desc &diff_dst_desc, |
6620 | const memory::desc &data_desc, const float *alpha, |
6621 | const float *beta, |
6622 | const eltwise_forward::primitive_desc &hint_fwd_pd, |
6623 | const primitive_attr &attr, bool allow_empty) { |
6624 | |
6625 | dnnl_primitive_desc_t pd = nullptr; |
6626 | dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create( |
6627 | &pd, aengine.get(), dnnl::convert_to_c(aalgorithm), |
6628 | diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(), |
6629 | alpha ? *alpha : 0.0f, beta ? *beta : 0.0f, |
6630 | hint_fwd_pd.get(), attr.get()); |
6631 | |
6632 | if (!allow_empty) |
6633 | error::wrap_c_api(status, |
6634 | "could not create a primitive descriptor for an " |
6635 | "eltwise backward propagation primitive" ); |
6636 | reset(pd); |
6637 | } |
6638 | }; |
6639 | |
6640 | /// Default constructor. Produces an empty object. |
6641 | eltwise_backward() = default; |
6642 | |
6643 | /// Constructs an eltwise backward propagation primitive. |
6644 | /// @param pd Primitive descriptor for an eltwise backward propagation |
6645 | /// primitive. |
6646 | eltwise_backward(const primitive_desc &pd) : primitive(pd) {} |
6647 | |
6648 | /// Constructs an eltwise backward propagation primitive from a cache blob. |
6649 | /// @param pd Primitive descriptor for an eltwise backward propagation |
6650 | /// primitive. |
6651 | /// @param cache_blob Cache blob. |
6652 | eltwise_backward( |
6653 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6654 | : primitive(pd, cache_blob) {} |
6655 | }; |
6656 | |
6657 | /// @} dnnl_api_eltwise |
6658 | |
6659 | /// @addtogroup dnnl_api_softmax Softmax |
6660 | /// |
6661 | /// A primitive to perform softmax. |
6662 | /// |
6663 | /// @sa @ref dev_guide_softmax in developer guide |
6664 | /// |
6665 | /// @{ |
6666 | |
6667 | /// Softmax forward propagation primitive. |
6668 | struct softmax_forward : public primitive { |
6669 | /// Primitive descriptor for a softmax forward propagation primitive. |
6670 | struct primitive_desc : public dnnl::primitive_desc { |
6671 | /// Default constructor. Produces an empty object. |
6672 | primitive_desc() = default; |
6673 | |
6674 | /// Constructs a primitive descriptor for a softmax forward propagation |
6675 | /// primitive. |
6676 | /// |
6677 | /// @param aengine Engine to use. |
6678 | /// @param aprop_kind Propagation kind. Possible values are |
6679 | /// #dnnl::prop_kind::forward_training, and |
6680 | /// #dnnl::prop_kind::forward_inference. |
6681 | /// @param aalgorithm Softmax algorithm kind: either |
6682 | /// #dnnl::algorithm::softmax_accurate, |
6683 | /// or #dnnl::algorithm::softmax_log. |
6684 | /// @param src_desc Source memory descriptor. |
6685 | /// @param dst_desc Destination memory descriptor. |
6686 | /// @param axis Axis over which softmax is computed. |
6687 | /// @param attr Primitive attributes to use. Attributes are optional |
6688 | /// and default to empty attributes. |
6689 | /// @param allow_empty A flag signifying whether construction is |
6690 | /// allowed to fail without throwing an exception. In this case an |
6691 | /// empty object will be produced. This flag is optional and |
6692 | /// defaults to false. |
6693 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6694 | algorithm aalgorithm, const memory::desc &src_desc, |
6695 | const memory::desc &dst_desc, int axis, |
6696 | const primitive_attr &attr = default_attr(), |
6697 | bool allow_empty = false) { |
6698 | |
6699 | dnnl_primitive_desc_t pd = nullptr; |
6700 | dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create( |
6701 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
6702 | dnnl::convert_to_c(aalgorithm), src_desc.get(), |
6703 | dst_desc.get(), axis, attr.get()); |
6704 | |
6705 | if (!allow_empty) |
6706 | error::wrap_c_api(status, |
6707 | "could not create a primitive descriptor for a softmax " |
6708 | "forward propagation primitive" ); |
6709 | reset(pd); |
6710 | } |
6711 | |
6712 | /// Constructs a primitive descriptor for a softmax forward |
6713 | /// propagation primitive from a C API primitive descriptor that must |
6714 | /// have a matching kind. |
6715 | /// |
6716 | /// @param pd C API primitive descriptor for a softmax forward |
6717 | /// propagation primitive. |
6718 | primitive_desc(dnnl_primitive_desc_t pd) |
6719 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax, |
6720 | dnnl::prop_kind::forward_training, |
6721 | dnnl::prop_kind::forward_inference) {} |
6722 | |
6723 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6724 | memory::desc src_desc() const { return base::src_desc(0); } |
6725 | |
6726 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6727 | memory::desc dst_desc() const { return base::dst_desc(0); } |
6728 | |
6729 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6730 | dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } |
6731 | |
6732 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6733 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6734 | |
6735 | /// @copydoc dnnl::primitive_desc_base::get_axis()const |
6736 | int get_axis() const { return base::get_axis(); } |
6737 | }; |
6738 | |
6739 | /// Default constructor. Produces an empty object. |
6740 | softmax_forward() = default; |
6741 | |
6742 | /// Constructs a softmax forward propagation primitive. |
6743 | /// @param pd Primitive descriptor for a softmax forward propagation |
6744 | /// primitive. |
6745 | softmax_forward(const primitive_desc &pd) : primitive(pd) {} |
6746 | |
6747 | /// Constructs a softmax forward propagation primitive from a cache blob. |
6748 | /// @param pd Primitive descriptor for a softmax forward propagation |
6749 | /// primitive. |
6750 | /// @param cache_blob Cache blob. |
6751 | softmax_forward( |
6752 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6753 | : primitive(pd, cache_blob) {} |
6754 | }; |
6755 | |
6756 | /// Softmax backward propagation primitive. |
6757 | struct softmax_backward : public primitive { |
6758 | /// Primitive descriptor for a softmax backward propagation primitive. |
6759 | struct primitive_desc : public dnnl::primitive_desc { |
6760 | /// Default constructor. Produces an empty object. |
6761 | primitive_desc() = default; |
6762 | |
6763 | /// Constructs a primitive descriptor for a softmax backward propagation |
6764 | /// primitive. |
6765 | /// |
6766 | /// @param aengine Engine to use. |
6767 | /// @param aalgorithm Softmax algorithm kind: either |
6768 | /// #dnnl::algorithm::softmax_accurate, |
6769 | /// or #dnnl::algorithm::softmax_log. |
6770 | /// @param diff_src_desc Diff source memory descriptor. |
6771 | /// @param diff_dst_desc Diff destination memory descriptor. |
6772 | /// @param dst_desc Destination memory descriptor. |
6773 | /// @param axis Axis over which softmax is computed. |
6774 | /// @param hint_fwd_pd Primitive descriptor for a softmax |
6775 | /// forward propagation primitive. It is used as a hint for |
6776 | /// deciding which memory format to use. |
6777 | /// @param attr Primitive attributes to use. Attributes are optional |
6778 | /// and default to empty attributes. |
6779 | /// @param allow_empty A flag signifying whether construction is |
6780 | /// allowed to fail without throwing an exception. In this case an |
6781 | /// empty object will be produced. This flag is optional and |
6782 | /// defaults to false. |
6783 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
6784 | const memory::desc &diff_src_desc, |
6785 | const memory::desc &diff_dst_desc, const memory::desc &dst_desc, |
6786 | int axis, const softmax_forward::primitive_desc &hint_fwd_pd, |
6787 | const primitive_attr &attr = default_attr(), |
6788 | bool allow_empty = false) { |
6789 | |
6790 | dnnl_primitive_desc_t pd = nullptr; |
6791 | dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create( |
6792 | &pd, aengine.get(), dnnl::convert_to_c(aalgorithm), |
6793 | diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(), |
6794 | axis, hint_fwd_pd.get(), attr.get()); |
6795 | |
6796 | if (!allow_empty) |
6797 | error::wrap_c_api(status, |
6798 | "could not create a primitive descriptor for a softmax " |
6799 | "backward propagation primitive" ); |
6800 | reset(pd); |
6801 | } |
6802 | |
6803 | /// Constructs a primitive descriptor for a softmax backward |
6804 | /// propagation primitive from a C API primitive descriptor that must |
6805 | /// have a matching kind. |
6806 | /// |
6807 | /// @param pd C API primitive descriptor for a softmax backward |
6808 | /// propagation primitive. |
6809 | primitive_desc(dnnl_primitive_desc_t pd) |
6810 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax, |
6811 | dnnl::prop_kind::backward_data) {} |
6812 | |
6813 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6814 | memory::desc dst_desc() const { return base::dst_desc(0); } |
6815 | |
6816 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
6817 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
6818 | |
6819 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6820 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
6821 | |
6822 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
6823 | dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } |
6824 | |
6825 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6826 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6827 | |
6828 | /// @copydoc dnnl::primitive_desc_base::get_axis()const |
6829 | int get_axis() const { return base::get_axis(); } |
6830 | }; |
6831 | |
6832 | /// Default constructor. Produces an empty object. |
6833 | softmax_backward() = default; |
6834 | |
6835 | /// Constructs a softmax backward propagation primitive. |
6836 | /// @param pd Primitive descriptor for a softmax backward propagation |
6837 | /// primitive. |
6838 | softmax_backward(const primitive_desc &pd) : primitive(pd) {} |
6839 | |
6840 | /// Constructs a softmax backward propagation primitive from a cache blob. |
6841 | /// @param pd Primitive descriptor for a softmax backward propagation |
6842 | /// primitive. |
6843 | /// @param cache_blob Cache blob. |
6844 | softmax_backward( |
6845 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6846 | : primitive(pd, cache_blob) {} |
6847 | }; |
6848 | |
6849 | /// @} dnnl_api_softmax |
6850 | |
6851 | /// @addtogroup dnnl_api_batch_normalization Batch Normalization |
6852 | /// |
6853 | /// A primitive to perform batch normalization. |
6854 | /// |
6855 | /// Both forward and backward propagation primitives support in-place |
6856 | /// operation; that is, src and dst can refer to the same memory for forward |
6857 | /// propagation, and diff_dst and diff_src can refer to the same memory for |
6858 | /// backward propagation. |
6859 | /// |
/// The batch normalization primitives' computations can be controlled by
6861 | /// specifying different @ref dnnl::normalization_flags values. For example, |
6862 | /// batch normalization forward propagation can be configured to either |
6863 | /// compute the mean and variance or take them as arguments. It can either |
6864 | /// perform scaling and shifting using gamma and beta parameters or not. |
6865 | /// Optionally, it can also perform a fused ReLU, which in case of training |
6866 | /// would also require a workspace. |
6867 | /// |
6868 | /// @sa @ref dev_guide_batch_normalization in developer guide |
6869 | /// |
6870 | /// @{ |
6871 | |
6872 | /// Batch normalization forward propagation primitive. |
6873 | struct batch_normalization_forward : public primitive { |
6874 | /// Primitive descriptor for a batch normalization forward propagation |
6875 | /// primitive. |
6876 | struct primitive_desc : public dnnl::primitive_desc { |
6877 | /// Default constructor. Produces an empty object. |
6878 | primitive_desc() = default; |
6879 | |
6880 | /// Constructs a primitive descriptor for a batch normalization forward |
6881 | /// propagation primitive. |
6882 | /// |
6883 | /// @note |
6884 | /// In-place operation is supported: the dst can refer to the same |
6885 | /// memory as the src. |
6886 | /// |
6887 | /// @param aengine Engine to use. |
6888 | /// @param aprop_kind Propagation kind. Possible values are |
6889 | /// #dnnl::prop_kind::forward_training and |
6890 | /// #dnnl::prop_kind::forward_inference. |
6891 | /// @param src_desc Source memory descriptor. |
6892 | /// @param dst_desc Destination memory descriptor. |
6893 | /// @param epsilon Batch normalization epsilon parameter. |
6894 | /// @param flags Batch normalization flags (@ref |
6895 | /// dnnl::normalization_flags). |
6896 | /// @param attr Primitive attributes to use. Attributes are optional |
6897 | /// and default to empty attributes. |
6898 | /// @param allow_empty A flag signifying whether construction is |
6899 | /// allowed to fail without throwing an exception. In this case an |
6900 | /// empty object will be produced. This flag is optional and |
6901 | /// defaults to false. |
6902 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
6903 | const memory::desc &src_desc, const memory::desc &dst_desc, |
6904 | float epsilon, normalization_flags flags, |
6905 | const primitive_attr &attr = default_attr(), |
6906 | bool allow_empty = false) { |
6907 | dnnl_primitive_desc_t pd = nullptr; |
6908 | dnnl_status_t status |
6909 | = dnnl_batch_normalization_forward_primitive_desc_create( |
6910 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
6911 | src_desc.get(), dst_desc.get(), epsilon, |
6912 | convert_to_c(flags), attr.get()); |
6913 | |
6914 | if (!allow_empty) |
6915 | error::wrap_c_api(status, |
6916 | "could not create a primitive descriptor for a batch " |
6917 | "normalization forward propagation primitive" ); |
6918 | reset(pd); |
6919 | } |
6920 | |
6921 | /// Constructs a primitive descriptor for a batch normalization |
6922 | /// forward propagation primitive from a C API primitive descriptor |
6923 | /// that must have a matching kind. |
6924 | /// |
6925 | /// @param pd C API primitive descriptor for a batch normalization |
6926 | /// forward propagation primitive. |
6927 | primitive_desc(dnnl_primitive_desc_t pd) |
6928 | : dnnl::primitive_desc(pd, |
6929 | dnnl::primitive::kind::batch_normalization, |
6930 | dnnl::prop_kind::forward_training, |
6931 | dnnl::prop_kind::forward_inference) {} |
6932 | |
6933 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
6934 | memory::desc src_desc() const { return base::src_desc(0); } |
6935 | |
6936 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
6937 | memory::desc dst_desc() const { return base::dst_desc(0); } |
6938 | |
6939 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
6940 | memory::desc weights_desc() const { return base::weights_desc(0); } |
6941 | |
6942 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
6943 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
6944 | |
6945 | /// Returns memory descriptor for mean. |
6946 | /// @returns Memory descriptor for mean. |
6947 | memory::desc mean_desc() const { return stat_desc(mean); } |
6948 | |
6949 | /// Returns memory descriptor for variance. |
6950 | /// @returns Memory descriptor for variance. |
6951 | memory::desc variance_desc() const { return stat_desc(var); } |
6952 | |
6953 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
6954 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
6955 | |
6956 | /// @copydoc dnnl::primitive_desc_base::get_epsilon()const |
6957 | float get_epsilon() const { return base::get_epsilon(); } |
6958 | |
6959 | /// Returns normalization flags. |
6960 | /// @return Normalization flags. |
6961 | normalization_flags get_flags() const { |
6962 | return base::get_flags<normalization_flags>(); |
6963 | } |
6964 | |
6965 | private: |
6966 | enum { |
6967 | mean = 1, |
6968 | var = 2, |
6969 | }; |
6970 | memory::desc stat_desc(int kind) const { |
6971 | const bool use_global_stats |
6972 | = (get_flags() & normalization_flags::use_global_stats) |
6973 | != normalization_flags::none; |
6974 | return query_md( |
6975 | use_global_stats ? query::src_md : query::dst_md, kind); |
6976 | } |
6977 | }; |
6978 | |
6979 | /// Default constructor. Produces an empty object. |
6980 | batch_normalization_forward() = default; |
6981 | |
6982 | /// Constructs a batch normalization forward propagation primitive. |
6983 | /// @param pd Primitive descriptor for a batch normalization forward |
6984 | /// propagation primitive. |
6985 | batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {} |
6986 | |
6987 | /// Constructs a batch normalization forward propagation primitive from |
6988 | /// a cache blob. |
6989 | /// @param pd Primitive descriptor for a batch normalization forward |
6990 | /// propagation primitive. |
6991 | /// @param cache_blob Cache blob. |
6992 | batch_normalization_forward( |
6993 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
6994 | : primitive(pd, cache_blob) {} |
6995 | }; |
6996 | |
6997 | /// Batch normalization backward propagation primitive. |
6998 | struct batch_normalization_backward : public primitive { |
6999 | /// Primitive descriptor for a batch normalization backward propagation |
7000 | /// primitive. |
7001 | struct primitive_desc : public dnnl::primitive_desc { |
7002 | /// Default constructor. Produces an empty object. |
7003 | primitive_desc() = default; |
7004 | |
7005 | /// Constructs a primitive descriptor for a batch normalization backward |
7006 | /// propagation primitive. |
7007 | /// |
7008 | /// @param aengine Engine to use. |
7009 | /// @param aprop_kind Propagation kind. Possible values are |
7010 | /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward |
7011 | /// (diffs for all parameters are computed in this case). |
7012 | /// @param diff_src_desc Diff source memory descriptor. |
7013 | /// @param diff_dst_desc Diff destination memory descriptor. |
7014 | /// @param src_desc Source memory descriptor. |
7015 | /// @param epsilon Batch normalization epsilon parameter. |
7016 | /// @param flags Batch normalization flags (@ref |
7017 | /// dnnl::normalization_flags). |
7018 | /// @param hint_fwd_pd Primitive descriptor for a batch normalization |
7019 | /// forward propagation primitive. It is used as a hint for |
7020 | /// deciding which memory format to use. |
7021 | /// @param attr Primitive attributes to use. Attributes are optional |
7022 | /// and default to empty attributes. |
7023 | /// @param allow_empty A flag signifying whether construction is |
7024 | /// allowed to fail without throwing an exception. In this case an |
7025 | /// empty object will be produced. This flag is optional and |
7026 | /// defaults to false. |
7027 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7028 | const memory::desc &diff_src_desc, |
7029 | const memory::desc &diff_dst_desc, const memory::desc &src_desc, |
7030 | float epsilon, normalization_flags flags, |
7031 | const batch_normalization_forward::primitive_desc &hint_fwd_pd, |
7032 | const primitive_attr &attr = default_attr(), |
7033 | bool allow_empty = false) { |
7034 | dnnl_primitive_desc_t pd = nullptr; |
7035 | dnnl_status_t status |
7036 | = dnnl_batch_normalization_backward_primitive_desc_create( |
7037 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
7038 | diff_src_desc.get(), diff_dst_desc.get(), |
7039 | src_desc.get(), epsilon, convert_to_c(flags), |
7040 | hint_fwd_pd.get(), attr.get()); |
7041 | |
7042 | if (!allow_empty) |
7043 | error::wrap_c_api(status, |
7044 | "could not create a primitive descriptor for a batch " |
7045 | "normalization backward propagation primitive" ); |
7046 | reset(pd); |
7047 | } |
7048 | |
7049 | /// Constructs a primitive descriptor for a batch normalization |
7050 | /// backward propagation primitive from a C API primitive descriptor |
7051 | /// that must have a matching kind. |
7052 | /// |
7053 | /// @param pd C API primitive descriptor for a batch normalization |
7054 | /// backward propagation primitive. |
7055 | primitive_desc(dnnl_primitive_desc_t pd) |
7056 | : dnnl::primitive_desc(pd, |
7057 | dnnl::primitive::kind::batch_normalization, |
7058 | dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) { |
7059 | } |
7060 | |
7061 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
7062 | memory::desc src_desc() const { return base::src_desc(0); } |
7063 | |
7064 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
7065 | memory::desc weights_desc() const { return base::weights_desc(0); } |
7066 | |
7067 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
7068 | memory::desc dst_desc() const { return base::dst_desc(0); } |
7069 | |
7070 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
7071 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
7072 | |
7073 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
7074 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
7075 | |
7076 | /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const |
7077 | memory::desc diff_weights_desc() const { |
7078 | return base::diff_weights_desc(0); |
7079 | } |
7080 | |
7081 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const |
7082 | memory::desc mean_desc() const { return query_md(query::src_md, 1); } |
7083 | |
7084 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const |
7085 | memory::desc variance_desc() const { |
7086 | return query_md(query::src_md, 2); |
7087 | } |
7088 | |
7089 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
7090 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
7091 | |
7092 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7093 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7094 | |
7095 | /// @copydoc dnnl::primitive_desc_base::get_epsilon()const |
7096 | float get_epsilon() const { return base::get_epsilon(); } |
7097 | |
7098 | /// Returns normalization flags. |
7099 | /// @return Normalization flags. |
7100 | normalization_flags get_flags() const { |
7101 | return base::get_flags<normalization_flags>(); |
7102 | } |
7103 | }; |
7104 | |
7105 | /// Default constructor. Produces an empty object. |
7106 | batch_normalization_backward() = default; |
7107 | |
7108 | /// Constructs a batch normalization backward propagation primitive. |
7109 | /// @param pd Primitive descriptor for a batch normalization backward |
7110 | /// propagation primitive. |
7111 | batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {} |
7112 | |
7113 | /// Constructs a batch normalization backward propagation primitive from |
7114 | /// a cache blob. |
7115 | /// @param pd Primitive descriptor for a batch normalization backward |
7116 | /// propagation primitive. |
7117 | /// @param cache_blob Cache blob. |
7118 | batch_normalization_backward( |
7119 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7120 | : primitive(pd, cache_blob) {} |
7121 | }; |
7122 | |
7123 | /// @} dnnl_api_batch_normalization |
7124 | |
7125 | /// @addtogroup dnnl_api_layer_normalization Layer Normalization |
7126 | /// |
7127 | /// A primitive to perform layer normalization. Normalization is performed |
/// within the last logical dimension of the data tensor.
7129 | /// |
7130 | /// Both forward and backward propagation primitives support in-place |
7131 | /// operation; that is, src and dst can refer to the same memory for forward |
7132 | /// propagation, and diff_dst and diff_src can refer to the same memory for |
7133 | /// backward propagation. |
7134 | /// |
/// The layer normalization primitives' computations can be controlled by
7136 | /// specifying different @ref dnnl::normalization_flags values. For example, |
7137 | /// layer normalization forward propagation can be configured to either |
7138 | /// compute the mean and variance or take them as arguments. It can either |
7139 | /// perform scaling and shifting using gamma and beta parameters or not. |
7140 | /// |
7141 | /// @sa @ref dev_guide_layer_normalization in developer guide |
7142 | /// |
7143 | /// @{ |
7144 | |
7145 | /// Layer normalization forward propagation primitive. |
7146 | struct layer_normalization_forward : public primitive { |
7147 | /// Primitive descriptor for a layer normalization forward propagation |
7148 | /// primitive. |
7149 | struct primitive_desc : public dnnl::primitive_desc { |
7150 | /// Default constructor. Produces an empty object. |
7151 | primitive_desc() = default; |
7152 | |
7153 | /// Constructs a primitive descriptor for a layer normalization forward |
7154 | /// propagation primitive. |
7155 | /// |
7156 | /// @param aengine Engine to use. |
7157 | /// @param aprop_kind Propagation kind. Possible values are |
7158 | /// #dnnl::prop_kind::forward_training, and |
7159 | /// #dnnl::prop_kind::forward_inference. |
7160 | /// @param src_desc Source memory descriptor. |
7161 | /// @param dst_desc Destination memory descriptor. |
7162 | /// @param stat_desc Statistics memory descriptors. |
7163 | /// @param epsilon Layer normalization epsilon parameter. |
7164 | /// @param flags Layer normalization flags (@ref |
7165 | /// dnnl::normalization_flags). |
7166 | /// @param attr Primitive attributes to use. Attributes are optional |
7167 | /// and default to empty attributes. |
7168 | /// @param allow_empty A flag signifying whether construction is |
7169 | /// allowed to fail without throwing an exception. In this case an |
7170 | /// empty object will be produced. This flag is optional and |
7171 | /// defaults to false. |
7172 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7173 | const memory::desc &src_desc, const memory::desc &dst_desc, |
7174 | const memory::desc &stat_desc, float epsilon, |
7175 | normalization_flags flags, |
7176 | const primitive_attr &attr = default_attr(), |
7177 | bool allow_empty = false) |
7178 | : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, |
7179 | &stat_desc, epsilon, flags, attr, allow_empty) {} |
7180 | |
7181 | /// Constructs a primitive descriptor for a layer normalization forward |
7182 | /// propagation primitive. |
7183 | /// |
7184 | /// @param aengine Engine to use. |
7185 | /// @param aprop_kind Propagation kind. Possible values are |
7186 | /// #dnnl::prop_kind::forward_training, and |
7187 | /// #dnnl::prop_kind::forward_inference. |
7188 | /// @param src_desc Source memory descriptor. |
7189 | /// @param dst_desc Destination memory descriptor. |
7190 | /// @param epsilon Layer normalization epsilon parameter. |
7191 | /// @param flags Layer normalization flags (@ref |
7192 | /// dnnl::normalization_flags). |
7193 | /// @param attr Primitive attributes to use. Attributes are optional |
7194 | /// and default to empty attributes. |
7195 | /// @param allow_empty A flag signifying whether construction is |
7196 | /// allowed to fail without throwing an exception. In this case an |
7197 | /// empty object will be produced. This flag is optional and |
7198 | /// defaults to false. |
7199 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7200 | const memory::desc &src_desc, const memory::desc &dst_desc, |
7201 | float epsilon, normalization_flags flags, |
7202 | const primitive_attr &attr = default_attr(), |
7203 | bool allow_empty = false) |
7204 | : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr, |
7205 | epsilon, flags, attr, allow_empty) {} |
7206 | |
7207 | /// Constructs a primitive descriptor for a layer normalization |
7208 | /// forward propagation primitive from a C API primitive descriptor |
7209 | /// that must have a matching kind. |
7210 | /// |
7211 | /// @param pd C API primitive descriptor for a layer normalization |
7212 | /// forward propagation primitive. |
7213 | primitive_desc(dnnl_primitive_desc_t pd) |
7214 | : dnnl::primitive_desc(pd, |
7215 | dnnl::primitive::kind::layer_normalization, |
7216 | dnnl::prop_kind::forward_training, |
7217 | dnnl::prop_kind::forward_inference) {} |
7218 | |
7219 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
7220 | memory::desc src_desc() const { return base::src_desc(0); } |
7221 | |
7222 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
7223 | memory::desc dst_desc() const { return base::dst_desc(0); } |
7224 | |
7225 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
7226 | memory::desc weights_desc() const { return base::weights_desc(0); } |
7227 | |
7228 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
7229 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
7230 | |
7231 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const |
7232 | memory::desc mean_desc() const { return stat_desc(mean); } |
7233 | |
7234 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const |
7235 | memory::desc variance_desc() const { return stat_desc(var); } |
7236 | |
7237 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7238 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7239 | |
7240 | /// @copydoc dnnl::primitive_desc_base::get_epsilon()const |
7241 | float get_epsilon() const { return base::get_epsilon(); } |
7242 | |
7243 | /// Returns normalization flags. |
7244 | /// @return Normalization flags. |
7245 | normalization_flags get_flags() const { |
7246 | return base::get_flags<normalization_flags>(); |
7247 | } |
7248 | |
7249 | private: |
7250 | enum { |
7251 | mean = 1, |
7252 | var = 2, |
7253 | }; |
7254 | memory::desc stat_desc(int kind) const { |
7255 | const bool use_global_stats |
7256 | = (get_flags() & normalization_flags::use_global_stats) |
7257 | != normalization_flags::none; |
7258 | return query_md( |
7259 | use_global_stats ? query::src_md : query::dst_md, kind); |
7260 | } |
7261 | |
7262 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7263 | const memory::desc &src_desc, const memory::desc &dst_desc, |
7264 | const memory::desc *stat_desc, float epsilon, |
7265 | normalization_flags flags, const primitive_attr &attr, |
7266 | bool allow_empty) { |
7267 | |
7268 | dnnl_primitive_desc_t pd = nullptr; |
7269 | dnnl_status_t status |
7270 | = dnnl_layer_normalization_forward_primitive_desc_create( |
7271 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
7272 | src_desc.get(), dst_desc.get(), |
7273 | optional_arg(stat_desc), epsilon, |
7274 | convert_to_c(flags), attr.get()); |
7275 | |
7276 | if (!allow_empty) |
7277 | error::wrap_c_api(status, |
7278 | "could not create a primitive descriptor for a layer " |
7279 | "normalization forward propagation primitive" ); |
7280 | reset(pd); |
7281 | } |
7282 | }; |
7283 | |
7284 | /// Default constructor. Produces an empty object. |
7285 | layer_normalization_forward() = default; |
7286 | |
7287 | /// Constructs a layer normalization forward propagation primitive. |
7288 | /// @param pd Primitive descriptor for a layer normalization forward |
7289 | /// propagation primitive. |
7290 | layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {} |
7291 | |
7292 | /// Constructs a layer normalization forward propagation primitive from |
7293 | /// a cache blob. |
7294 | /// @param pd Primitive descriptor for a layer normalization forward |
7295 | /// propagation primitive. |
7296 | /// @param cache_blob Cache blob. |
7297 | layer_normalization_forward( |
7298 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7299 | : primitive(pd, cache_blob) {} |
7300 | }; |
7301 | |
7302 | /// Layer normalization backward propagation primitive. |
7303 | struct layer_normalization_backward : public primitive { |
7304 | /// Primitive descriptor for a layer normalization backward propagation |
7305 | /// primitive. |
7306 | struct primitive_desc : public dnnl::primitive_desc { |
7307 | /// Default constructor. Produces an empty object. |
7308 | primitive_desc() = default; |
7309 | |
7310 | /// Constructs a primitive descriptor for a layer normalization backward |
7311 | /// propagation primitive. |
7312 | /// |
7313 | /// @param aengine Engine to use. |
7314 | /// @param aprop_kind Propagation kind. Possible values are |
7315 | /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward |
7316 | /// (diffs for all parameters are computed in this case). |
7317 | /// @param diff_src_desc Diff source memory descriptor. |
7318 | /// @param diff_dst_desc Diff destination memory descriptor. |
7319 | /// @param src_desc Source memory descriptor. |
7320 | /// @param stat_desc Statistics memory descriptors. |
7321 | /// @param epsilon Layer normalization epsilon parameter. |
7322 | /// @param flags Layer normalization flags (@ref |
7323 | /// dnnl::normalization_flags). |
7324 | /// @param attr Primitive attributes to use. Attributes are optional |
7325 | /// and default to empty attributes. |
7326 | /// @param hint_fwd_pd Primitive descriptor for a layer normalization |
7327 | /// forward propagation primitive. It is used as a hint for |
7328 | /// deciding which memory format to use. |
7329 | /// @param allow_empty A flag signifying whether construction is |
7330 | /// allowed to fail without throwing an exception. In this case an |
7331 | /// empty object will be produced. This flag is optional and |
7332 | /// defaults to false. |
7333 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7334 | const memory::desc &diff_src_desc, |
7335 | const memory::desc &diff_dst_desc, const memory::desc &src_desc, |
7336 | const memory::desc &stat_desc, float epsilon, |
7337 | normalization_flags flags, |
7338 | const layer_normalization_forward::primitive_desc &hint_fwd_pd, |
7339 | const primitive_attr &attr = default_attr(), |
7340 | bool allow_empty = false) |
7341 | : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, |
7342 | src_desc, &stat_desc, epsilon, flags, hint_fwd_pd, attr, |
7343 | allow_empty) {} |
7344 | |
7345 | /// Constructs a primitive descriptor for a layer normalization backward |
7346 | /// propagation primitive. |
7347 | /// |
7348 | /// @param aengine Engine to use. |
7349 | /// @param aprop_kind Propagation kind. Possible values are |
7350 | /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward |
7351 | /// (diffs for all parameters are computed in this case). |
7352 | /// @param diff_src_desc Diff source memory descriptor. |
7353 | /// @param diff_dst_desc Diff destination memory descriptor. |
7354 | /// @param src_desc Source memory descriptor. |
7355 | /// @param epsilon Layer normalization epsilon parameter. |
7356 | /// @param flags Layer normalization flags (@ref |
7357 | /// dnnl::normalization_flags). |
7358 | /// @param attr Primitive attributes to use. Attributes are optional |
7359 | /// and default to empty attributes. |
7360 | /// @param hint_fwd_pd Primitive descriptor for a layer normalization |
7361 | /// forward propagation primitive. It is used as a hint for |
7362 | /// deciding which memory format to use. |
7363 | /// @param allow_empty A flag signifying whether construction is |
7364 | /// allowed to fail without throwing an exception. In this case an |
7365 | /// empty object will be produced. This flag is optional and |
7366 | /// defaults to false. |
7367 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7368 | const memory::desc &diff_src_desc, |
7369 | const memory::desc &diff_dst_desc, const memory::desc &src_desc, |
7370 | float epsilon, normalization_flags flags, |
7371 | const layer_normalization_forward::primitive_desc &hint_fwd_pd, |
7372 | const primitive_attr &attr = default_attr(), |
7373 | bool allow_empty = false) |
7374 | : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, |
7375 | src_desc, nullptr, epsilon, flags, hint_fwd_pd, attr, |
7376 | allow_empty) {} |
7377 | |
7378 | /// Constructs a primitive descriptor for a layer normalization |
7379 | /// backward propagation primitive from a C API primitive descriptor |
7380 | /// that must have a matching kind. |
7381 | /// |
7382 | /// @param pd C API primitive descriptor for a layer normalization |
7383 | /// backward propagation primitive. |
7384 | primitive_desc(dnnl_primitive_desc_t pd) |
7385 | : dnnl::primitive_desc(pd, |
7386 | dnnl::primitive::kind::layer_normalization, |
7387 | dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) { |
7388 | } |
7389 | |
7390 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
7391 | memory::desc src_desc() const { return base::src_desc(0); } |
7392 | |
7393 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
7394 | memory::desc weights_desc() const { return base::weights_desc(0); } |
7395 | |
7396 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
7397 | memory::desc dst_desc() const { return base::dst_desc(0); } |
7398 | |
7399 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
7400 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
7401 | |
7402 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
7403 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
7404 | |
7405 | /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const |
7406 | memory::desc diff_weights_desc() const { |
7407 | return base::diff_weights_desc(0); |
7408 | } |
7409 | |
7410 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const |
7411 | memory::desc mean_desc() const { return query_md(query::src_md, 1); } |
7412 | |
7413 | /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const |
7414 | memory::desc variance_desc() const { |
7415 | return query_md(query::src_md, 2); |
7416 | } |
7417 | |
7418 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
7419 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
7420 | |
7421 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7422 | dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7423 | |
7424 | /// @copydoc dnnl::primitive_desc_base::get_epsilon()const |
7425 | float get_epsilon() const { return base::get_epsilon(); } |
7426 | |
7427 | /// Returns normalization flags. |
7428 | /// @return Normalization flags. |
7429 | normalization_flags get_flags() const { |
7430 | return base::get_flags<normalization_flags>(); |
7431 | } |
7432 | |
7433 | private: |
7434 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7435 | const memory::desc &diff_src_desc, |
7436 | const memory::desc &diff_dst_desc, const memory::desc &src_desc, |
7437 | const memory::desc *stat_desc, float epsilon, |
7438 | normalization_flags flags, |
7439 | const layer_normalization_forward::primitive_desc &hint_fwd_pd, |
7440 | const primitive_attr &attr, bool allow_empty) { |
7441 | |
7442 | dnnl_primitive_desc_t pd = nullptr; |
7443 | dnnl_status_t status |
7444 | = dnnl_layer_normalization_backward_primitive_desc_create( |
7445 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
7446 | diff_src_desc.get(), diff_dst_desc.get(), |
7447 | src_desc.get(), optional_arg(stat_desc), epsilon, |
7448 | convert_to_c(flags), hint_fwd_pd.get(), attr.get()); |
7449 | |
7450 | if (!allow_empty) |
7451 | error::wrap_c_api(status, |
7452 | "could not create a primitive descriptor for a layer " |
7453 | "normalization backward propagation primitive" ); |
7454 | reset(pd); |
7455 | } |
7456 | }; |
7457 | |
7458 | /// Default constructor. Produces an empty object. |
7459 | layer_normalization_backward() = default; |
7460 | |
7461 | /// Constructs a layer normalization backward propagation primitive. |
7462 | /// @param pd Primitive descriptor for a layer normalization backward |
7463 | /// propagation primitive. |
7464 | layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {} |
7465 | |
7466 | /// Constructs a layer normalization backward propagation primitive from |
7467 | /// a cache blob. |
7468 | /// @param pd Primitive descriptor for a layer normalization backward |
7469 | /// propagation primitive. |
7470 | /// @param cache_blob Cache blob. |
7471 | layer_normalization_backward( |
7472 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7473 | : primitive(pd, cache_blob) {} |
7474 | }; |
7475 | |
7476 | /// @} dnnl_api_layer_normalization |
7477 | |
7478 | /// @addtogroup dnnl_api_inner_product Inner Product |
7479 | /// |
7480 | /// A primitive to compute an inner product. |
7481 | /// |
7482 | /// @sa @ref dev_guide_inner_product in developer guide |
7483 | /// |
7484 | /// @{ |
7485 | |
7486 | /// Inner product forward propagation primitive. |
7487 | struct inner_product_forward : public primitive { |
7488 | /// Primitive descriptor for an inner product forward propagation primitive. |
7489 | struct primitive_desc : public dnnl::primitive_desc { |
7490 | /// Default constructor. Produces an empty object. |
7491 | primitive_desc() = default; |
7492 | |
7493 | /// Constructs a primitive descriptor for an inner product forward |
7494 | /// propagation primitive with bias. |
7495 | /// |
7496 | /// @note |
7497 | /// All the memory descriptors may be initialized with the |
7498 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
7499 | /// |
7500 | /// @param aengine Engine to use. |
7501 | /// @param aprop_kind Propagation kind. Possible values are |
7502 | /// #dnnl::prop_kind::forward_training, and |
7503 | /// #dnnl::prop_kind::forward_inference. |
7504 | /// @param src_desc Memory descriptor for src. |
7505 | /// @param weights_desc Memory descriptor for weights. |
7506 | /// @param bias_desc Memory descriptor for bias. |
7507 | /// @param dst_desc Memory descriptor for dst. |
7508 | /// @param attr Primitive attributes to use. Attributes are optional |
7509 | /// and default to empty attributes. |
7510 | /// @param allow_empty A flag signifying whether construction is |
7511 | /// allowed to fail without throwing an exception. In this case an |
7512 | /// empty object will be produced. This flag is optional and |
7513 | /// defaults to false. |
7514 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7515 | const memory::desc &src_desc, const memory::desc &weights_desc, |
7516 | const memory::desc &bias_desc, const memory::desc &dst_desc, |
7517 | const primitive_attr &attr = default_attr(), |
7518 | bool allow_empty = false) |
7519 | : primitive_desc(aengine, aprop_kind, src_desc, weights_desc, |
7520 | &bias_desc, dst_desc, attr, allow_empty) {} |
7521 | |
7522 | /// Constructs a primitive descriptor for an inner product forward |
7523 | /// propagation primitive. |
7524 | /// |
7525 | /// @note |
7526 | /// All the memory descriptors may be initialized with the |
7527 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
7528 | /// |
7529 | /// @param aengine Engine to use. |
7530 | /// @param aprop_kind Propagation kind. Possible values are |
7531 | /// #dnnl::prop_kind::forward_training, and |
7532 | /// #dnnl::prop_kind::forward_inference. |
7533 | /// @param src_desc Memory descriptor for src. |
7534 | /// @param weights_desc Memory descriptor for weights. |
7535 | /// @param dst_desc Memory descriptor for dst. |
7536 | /// @param attr Primitive attributes to use. Attributes are optional |
7537 | /// and default to empty attributes. |
7538 | /// @param allow_empty A flag signifying whether construction is |
7539 | /// allowed to fail without throwing an exception. In this case an |
7540 | /// empty object will be produced. This flag is optional and |
7541 | /// defaults to false. |
7542 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7543 | const memory::desc &src_desc, const memory::desc &weights_desc, |
7544 | const memory::desc &dst_desc, |
7545 | const primitive_attr &attr = default_attr(), |
7546 | bool allow_empty = false) |
7547 | : primitive_desc(aengine, aprop_kind, src_desc, weights_desc, |
7548 | nullptr, dst_desc, attr, allow_empty) {} |
7549 | |
7550 | /// Constructs a primitive descriptor for an inner product forward |
7551 | /// propagation primitive from a C API primitive descriptor that must |
7552 | /// have a matching kind. |
7553 | /// |
7554 | /// @param pd C API primitive descriptor for an inner product forward |
7555 | /// propagation primitive. |
7556 | primitive_desc(dnnl_primitive_desc_t pd) |
7557 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product, |
7558 | dnnl::prop_kind::forward_training, |
7559 | dnnl::prop_kind::forward_inference) {} |
7560 | |
7561 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
7562 | memory::desc src_desc() const { return base::src_desc(0); } |
7563 | |
7564 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
7565 | memory::desc weights_desc() const { return base::weights_desc(0); } |
7566 | |
7567 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
7568 | memory::desc dst_desc() const { return base::dst_desc(0); } |
7569 | |
7570 | /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const |
7571 | memory::desc bias_desc() const { return base::weights_desc(1); } |
7572 | |
7573 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7574 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7575 | |
7576 | private: |
7577 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
7578 | const memory::desc &src_desc, const memory::desc &weights_desc, |
7579 | const memory::desc *bias_desc, const memory::desc &dst_desc, |
7580 | const primitive_attr &attr, bool allow_empty) { |
7581 | |
7582 | dnnl_primitive_desc_t pd = nullptr; |
7583 | dnnl_status_t status |
7584 | = dnnl_inner_product_forward_primitive_desc_create(&pd, |
7585 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
7586 | src_desc.get(), weights_desc.get(), |
7587 | optional_arg(bias_desc), dst_desc.get(), |
7588 | attr.get()); |
7589 | |
7590 | if (!allow_empty) |
7591 | error::wrap_c_api(status, |
7592 | "could not create a primitive descriptor for an inner " |
7593 | "product forward propagation primitive" ); |
7594 | reset(pd); |
7595 | } |
7596 | }; |
7597 | |
7598 | /// Default constructor. Produces an empty object. |
7599 | inner_product_forward() = default; |
7600 | |
7601 | /// Constructs an inner product forward propagation primitive. |
7602 | /// @param pd Primitive descriptor for an inner product forward |
7603 | /// propagation primitive. |
7604 | inner_product_forward(const primitive_desc &pd) : primitive(pd) {} |
7605 | |
7606 | /// Constructs an inner product forward propagation primitive from |
7607 | /// a cache blob. |
7608 | /// @param pd Primitive descriptor for an inner product forward |
7609 | /// propagation primitive. |
7610 | /// @param cache_blob Cache blob. |
7611 | inner_product_forward( |
7612 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7613 | : primitive(pd, cache_blob) {} |
7614 | }; |
7615 | |
7616 | /// Inner product backward propagation primitive. |
7617 | struct inner_product_backward_data : public primitive { |
7618 | /// Primitive descriptor for an inner product backward propagation |
7619 | /// primitive. |
7620 | struct primitive_desc : public dnnl::primitive_desc { |
7621 | /// Default constructor. Produces an empty object. |
7622 | primitive_desc() = default; |
7623 | |
7624 | /// Constructs a primitive descriptor for an inner product backward |
7625 | /// propagation primitive. |
7626 | /// |
7627 | /// @note |
7628 | /// All the memory descriptors may be initialized with the |
7629 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
7630 | /// |
7631 | /// @param aengine Engine to use. |
7632 | /// @param diff_src_desc Memory descriptor for diff src. |
7633 | /// @param weights_desc Memory descriptor for weights. |
7634 | /// @param diff_dst_desc Memory descriptor for diff dst. |
7635 | /// @param hint_fwd_pd Primitive descriptor for an inner product |
7636 | /// forward propagation primitive. It is used as a hint for |
7637 | /// deciding which memory format to use. |
7638 | /// @param attr Primitive attributes to use. Attributes are optional |
7639 | /// and default to empty attributes. |
7640 | /// @param allow_empty A flag signifying whether construction is |
7641 | /// allowed to fail without throwing an exception. In this case an |
7642 | /// empty object will be produced. This flag is optional and |
7643 | /// defaults to false. |
7644 | primitive_desc(const engine &aengine, const memory::desc &diff_src_desc, |
7645 | const memory::desc &weights_desc, |
7646 | const memory::desc &diff_dst_desc, |
7647 | const inner_product_forward::primitive_desc &hint_fwd_pd, |
7648 | const primitive_attr &attr = default_attr(), |
7649 | bool allow_empty = false) { |
7650 | dnnl_primitive_desc_t pd = nullptr; |
7651 | dnnl_status_t status |
7652 | = dnnl_inner_product_backward_data_primitive_desc_create( |
7653 | &pd, aengine.get(), diff_src_desc.get(), |
7654 | weights_desc.get(), diff_dst_desc.get(), |
7655 | hint_fwd_pd.get(), attr.get()); |
7656 | |
7657 | if (!allow_empty) |
7658 | error::wrap_c_api(status, |
7659 | "could not create a primitive descriptor for an inner " |
7660 | "product backward propagation primitive" ); |
7661 | reset(pd); |
7662 | } |
7663 | |
7664 | /// Constructs a primitive descriptor for an inner product backward |
7665 | /// propagation primitive from a C API primitive descriptor that must |
7666 | /// have a matching kind. |
7667 | /// |
7668 | /// @param pd C API primitive descriptor for an inner product backward |
7669 | /// propagation primitive. |
7670 | primitive_desc(dnnl_primitive_desc_t pd) |
7671 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product, |
7672 | dnnl::prop_kind::backward_data) {} |
7673 | |
7674 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
7675 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
7676 | |
7677 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
7678 | memory::desc weights_desc() const { return base::weights_desc(0); } |
7679 | |
7680 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
7681 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
7682 | |
7683 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7684 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7685 | }; |
7686 | |
7687 | /// Default constructor. Produces an empty object. |
7688 | inner_product_backward_data() = default; |
7689 | |
7690 | /// Constructs an inner product backward propagation primitive. |
7691 | /// @param pd Primitive descriptor for an inner product backward |
7692 | /// propagation primitive. |
7693 | inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {} |
7694 | |
7695 | /// Constructs an inner product backward propagation primitive from |
7696 | /// a cache blob. |
7697 | /// @param pd Primitive descriptor for an inner product backward |
7698 | /// propagation primitive. |
7699 | /// @param cache_blob Cache blob. |
7700 | inner_product_backward_data( |
7701 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7702 | : primitive(pd, cache_blob) {} |
7703 | }; |
7704 | |
7705 | /// Inner product weights gradient primitive. |
7706 | struct inner_product_backward_weights : public primitive { |
7707 | /// Primitive descriptor for an inner product weights gradient primitive. |
7708 | struct primitive_desc : public dnnl::primitive_desc { |
7709 | /// Default constructor. Produces an empty object. |
7710 | primitive_desc() = default; |
7711 | |
7712 | /// Constructs a primitive descriptor for an inner product weights |
7713 | /// update primitive with bias. |
7714 | /// |
7715 | /// @note |
7716 | /// All the memory descriptors may be initialized with the |
7717 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
7718 | /// |
7719 | /// @param aengine Engine to use. |
7720 | /// @param src_desc Memory descriptor for src. |
7721 | /// @param diff_weights_desc Memory descriptor for diff weights. |
7722 | /// @param diff_bias_desc Memory descriptor for diff bias. |
7723 | /// @param diff_dst_desc Memory descriptor for diff dst. |
7724 | /// @param hint_fwd_pd Primitive descriptor for an inner product |
7725 | /// forward propagation primitive. It is used as a hint for |
7726 | /// deciding which memory format to use. |
7727 | /// @param attr Primitive attributes to use. Attributes are optional |
7728 | /// and default to empty attributes. |
7729 | /// @param allow_empty A flag signifying whether construction is |
7730 | /// allowed to fail without throwing an exception. In this case an |
7731 | /// empty object will be produced. This flag is optional and |
7732 | /// defaults to false. |
7733 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
7734 | const memory::desc &diff_weights_desc, |
7735 | const memory::desc &diff_bias_desc, |
7736 | const memory::desc &diff_dst_desc, |
7737 | const inner_product_forward::primitive_desc &hint_fwd_pd, |
7738 | const primitive_attr &attr = default_attr(), |
7739 | bool allow_empty = false) |
7740 | : primitive_desc(aengine, src_desc, diff_weights_desc, |
7741 | &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr, |
7742 | allow_empty) {} |
7743 | |
7744 | /// Constructs a primitive descriptor for an inner product weights |
7745 | /// update primitive. |
7746 | /// |
7747 | /// @note |
7748 | /// All the memory descriptors may be initialized with the |
7749 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
7750 | /// |
7751 | /// @param aengine Engine to use. |
7752 | /// @param src_desc Memory descriptor for src. |
7753 | /// @param diff_weights_desc Memory descriptor for diff weights. |
7754 | /// @param diff_dst_desc Memory descriptor for diff dst. |
7755 | /// @param attr Primitive attributes to use. Attributes are optional |
7756 | /// and default to empty attributes. |
7757 | /// @param hint_fwd_pd Primitive descriptor for an inner product |
7758 | /// forward propagation primitive. It is used as a hint for |
7759 | /// deciding which memory format to use. |
7760 | /// @param allow_empty A flag signifying whether construction is |
7761 | /// allowed to fail without throwing an exception. In this case an |
7762 | /// empty object will be produced. This flag is optional and |
7763 | /// defaults to false. |
7764 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
7765 | const memory::desc &diff_weights_desc, |
7766 | const memory::desc &diff_dst_desc, |
7767 | const inner_product_forward::primitive_desc &hint_fwd_pd, |
7768 | const primitive_attr &attr = default_attr(), |
7769 | bool allow_empty = false) |
7770 | : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr, |
7771 | diff_dst_desc, hint_fwd_pd, attr, allow_empty) {} |
7772 | |
7773 | /// Constructs a primitive descriptor for an inner product weights |
7774 | /// update primitive from a C API primitive descriptor that must |
7775 | /// have a matching kind. |
7776 | /// |
7777 | /// @param pd C API primitive descriptor for an inner product weights |
7778 | /// gradient primitive. |
7779 | primitive_desc(dnnl_primitive_desc_t pd) |
7780 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product, |
7781 | dnnl::prop_kind::backward_weights) {} |
7782 | |
7783 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
7784 | memory::desc src_desc() const { return base::src_desc(0); } |
7785 | |
7786 | /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const |
7787 | memory::desc diff_weights_desc() const { |
7788 | return base::diff_weights_desc(0); |
7789 | } |
7790 | |
7791 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
7792 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
7793 | |
7794 | /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const |
7795 | memory::desc diff_bias_desc() const { |
7796 | return base::diff_weights_desc(1); |
7797 | } |
7798 | |
7799 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
7800 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
7801 | |
7802 | private: |
7803 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
7804 | const memory::desc &diff_weights_desc, |
7805 | const memory::desc *diff_bias_desc, |
7806 | const memory::desc &diff_dst_desc, |
7807 | const inner_product_forward::primitive_desc &hint_fwd_pd, |
7808 | const primitive_attr &attr, bool allow_empty) { |
7809 | |
7810 | dnnl_primitive_desc_t pd = nullptr; |
7811 | dnnl_status_t status |
7812 | = dnnl_inner_product_backward_weights_primitive_desc_create( |
7813 | &pd, aengine.get(), src_desc.get(), |
7814 | diff_weights_desc.get(), |
7815 | optional_arg(diff_bias_desc), diff_dst_desc.get(), |
7816 | hint_fwd_pd.get(), attr.get()); |
7817 | |
7818 | if (!allow_empty) |
7819 | error::wrap_c_api(status, |
7820 | "could not create a primitive descriptor for an inner " |
7821 | "product weights gradient primitive" ); |
7822 | reset(pd); |
7823 | } |
7824 | }; |
7825 | |
7826 | /// Default constructor. Produces an empty object. |
7827 | inner_product_backward_weights() = default; |
7828 | |
7829 | /// Constructs an inner product weights gradient primitive. |
7830 | /// @param pd Primitive descriptor for an inner product weights gradient |
7831 | /// primitive. |
7832 | inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {} |
7833 | |
7834 | /// Constructs an inner product weights gradient primitive from a cache |
7835 | /// blob. |
7836 | /// @param pd Primitive descriptor for an inner product weights gradient |
7837 | /// primitive. |
7838 | /// @param cache_blob Cache blob. |
7839 | inner_product_backward_weights( |
7840 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
7841 | : primitive(pd, cache_blob) {} |
7842 | }; |
7843 | |
7844 | /// @} dnnl_api_inner_product |
7845 | |
7846 | /// @addtogroup dnnl_api_rnn RNN |
7847 | /// |
7848 | /// A primitive to compute recurrent neural network layers. |
7849 | /// |
7850 | /// @sa @ref dev_guide_rnn in developer guide |
7851 | /// |
7852 | /// @{ |
7853 | |
7854 | /// Base class for primitive descriptors for RNN primitives. |
7855 | struct rnn_primitive_desc_base : public primitive_desc { |
7856 | using primitive_desc::primitive_desc; |
7857 | |
7858 | /// Default constructor. Produces an empty object. |
7859 | rnn_primitive_desc_base() = default; |
7860 | |
7861 | /// Constructs an RNN primitive descriptor base from a C API primitive |
7862 | /// descriptor while checking that it actually describes the expected |
7863 | /// primitive by comparing propagation and primitive kinds. |
7864 | /// |
7865 | /// @param pd C API primitive descriptor. |
7866 | /// @param aprop_kind Expected propagation kind. |
7867 | /// @param cell_kind Expected cell kind. |
7868 | rnn_primitive_desc_base(dnnl_primitive_desc_t pd, |
7869 | dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind) |
7870 | : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {} |
7871 | |
7872 | /// Returns source layer memory descriptor. |
7873 | /// @returns Source layer memory descriptor. |
7874 | memory::desc src_layer_desc() const { |
7875 | return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER); |
7876 | } |
7877 | |
7878 | /// Returns AUGRU attention memory descriptor. |
7879 | /// @returns AUGRU attention memory descriptor. |
7880 | memory::desc augru_attention_desc() const { |
7881 | return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION); |
7882 | } |
7883 | |
7884 | /// Returns source iteration memory descriptor. |
7885 | /// @returns Source iteration memory descriptor. |
7886 | /// @returns A zero memory descriptor if the primitive does not have a |
7887 | /// source iteration parameter. |
7888 | memory::desc src_iter_desc() const { |
7889 | return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER); |
7890 | } |
7891 | |
7892 | /// Returns source recurrent cell state memory descriptor. |
7893 | /// @returns Source recurrent cell state memory descriptor. |
7894 | memory::desc src_iter_c_desc() const { |
7895 | return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C); |
7896 | } |
7897 | |
7898 | /// Returns weights layer memory descriptor. |
7899 | /// @returns Weights layer memory descriptor. |
7900 | memory::desc weights_layer_desc() const { |
7901 | return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER); |
7902 | } |
7903 | |
7904 | /// Returns weights iteration memory descriptor. |
7905 | /// @returns Weights iteration memory descriptor. |
7906 | memory::desc weights_iter_desc() const { |
7907 | return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER); |
7908 | } |
7909 | |
7910 | /// Returns weights peephole memory descriptor. |
7911 | /// @returns Weights peephole memory descriptor. |
7912 | memory::desc weights_peephole_desc() const { |
7913 | return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE); |
7914 | } |
7915 | |
7916 | /// Returns weights projection memory descriptor. |
7917 | /// @returns Weights projection memory descriptor. |
7918 | memory::desc weights_projection_desc() const { |
7919 | return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION); |
7920 | } |
7921 | |
7922 | /// Returns bias memory descriptor. |
7923 | /// @returns Bias memory descriptor. |
7924 | /// @returns A zero memory descriptor if the primitive does not have a |
7925 | /// bias parameter. |
7926 | memory::desc bias_desc() const { |
7927 | return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS); |
7928 | } |
7929 | |
7930 | /// Returns destination layer memory descriptor. |
7931 | /// @returns Destination layer memory descriptor. |
7932 | memory::desc dst_layer_desc() const { |
7933 | return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER); |
7934 | } |
7935 | |
7936 | /// Returns destination iteration memory descriptor. |
7937 | /// @returns Destination iteration memory descriptor. |
7938 | /// @returns A zero memory descriptor if the primitive does not have a |
7939 | /// destination iteration parameter. |
7940 | memory::desc dst_iter_desc() const { |
7941 | return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER); |
7942 | } |
7943 | |
7944 | /// Returns destination recurrent cell state memory descriptor. |
7945 | /// @returns Destination recurrent cell state memory descriptor. |
7946 | memory::desc dst_iter_c_desc() const { |
7947 | return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C); |
7948 | } |
7949 | |
7950 | /// Returns diff source layer memory descriptor. |
7951 | /// @returns Diff source layer memory descriptor. |
7952 | memory::desc diff_src_layer_desc() const { |
7953 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER); |
7954 | } |
7955 | |
7956 | /// Returns diff AUGRU attention memory descriptor. |
7957 | /// @returns Diff AUGRU attention memory descriptor. |
7958 | memory::desc diff_augru_attention_desc() const { |
7959 | return base::query_md( |
7960 | query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION); |
7961 | } |
7962 | |
7963 | /// Returns diff source iteration memory descriptor. |
7964 | /// @returns Diff source iteration memory descriptor. |
7965 | /// @returns A zero memory descriptor if the primitive does not have a |
7966 | /// diff source iteration parameter. |
7967 | memory::desc diff_src_iter_desc() const { |
7968 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER); |
7969 | } |
7970 | |
7971 | /// Returns diff source recurrent cell state memory descriptor. |
7972 | /// @returns Diff source recurrent cell state memory descriptor. |
7973 | memory::desc diff_src_iter_c_desc() const { |
7974 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C); |
7975 | } |
7976 | |
7977 | /// Returns diff weights layer memory descriptor. |
7978 | /// @returns Diff weights layer memory descriptor. |
7979 | memory::desc diff_weights_layer_desc() const { |
7980 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER); |
7981 | } |
7982 | |
7983 | /// Returns diff weights iteration memory descriptor. |
7984 | /// @returns Diff weights iteration memory descriptor. |
7985 | memory::desc diff_weights_iter_desc() const { |
7986 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER); |
7987 | } |
7988 | |
7989 | /// Returns diff weights peephole memory descriptor. |
7990 | /// @returns Diff weights peephole memory descriptor. |
7991 | memory::desc diff_weights_peephole_desc() const { |
7992 | return base::query_md( |
7993 | query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE); |
7994 | } |
7995 | |
7996 | /// Returns diff weights projection memory descriptor. |
7997 | /// @returns Diff weights projection memory descriptor. |
7998 | memory::desc diff_weights_projection_desc() const { |
7999 | return base::query_md( |
8000 | query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION); |
8001 | } |
8002 | |
8003 | /// Returns diff bias memory descriptor. |
8004 | /// @returns Diff bias memory descriptor. |
8005 | /// @returns A zero memory descriptor if the primitive does not have a |
8006 | /// diff bias parameter. |
8007 | memory::desc diff_bias_desc() const { |
8008 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS); |
8009 | } |
8010 | |
8011 | /// Returns diff destination layer memory descriptor. |
8012 | /// @returns Diff destination layer memory descriptor. |
8013 | memory::desc diff_dst_layer_desc() const { |
8014 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER); |
8015 | } |
8016 | |
8017 | /// Returns diff destination iteration memory descriptor. |
8018 | /// @returns Diff destination iteration memory descriptor. |
8019 | /// @returns A zero memory descriptor if the primitive does not have a |
8020 | /// diff destination iteration parameter. |
8021 | memory::desc diff_dst_iter_desc() const { |
8022 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER); |
8023 | } |
8024 | |
8025 | /// Returns diff destination recurrent cell state memory descriptor. |
8026 | /// @returns Diff destination recurrent cell state memory descriptor. |
8027 | memory::desc diff_dst_iter_c_desc() const { |
8028 | return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C); |
8029 | } |
8030 | |
8031 | protected: |
8032 | using rnn_base = rnn_primitive_desc_base; |
8033 | |
8034 | // (Deliberately not using doxygen comments) |
8035 | // |
8036 | // Constructs an RNN primitive descriptor base from a C API primitive |
8037 | // descriptor while checking that it actually describes the expected |
8038 | // primitive by comparing propagation and primitive kinds. Caller can |
8039 | // pass two options propagation kinds. This is typically used to check |
8040 | // that propagation kind is inference or training forward propagation. |
8041 | // |
8042 | // @param pd C API primitive descriptor. |
8043 | // @param prop_kind1 Expected propagation kind. |
8044 | // @param prop_kind2 Expected propagation kind. |
8045 | // @param cell_kind Expected cell kind. |
8046 | rnn_primitive_desc_base(dnnl_primitive_desc_t pd, |
8047 | dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2, |
8048 | dnnl::algorithm cell_kind) { |
8049 | |
8050 | dnnl_status_t rc; |
8051 | |
8052 | dnnl_primitive_kind_t q_primitive_kind; |
8053 | rc = dnnl_primitive_desc_query( |
8054 | pd, dnnl_query_primitive_kind, 0, &q_primitive_kind); |
8055 | error::wrap_c_api(rc, |
8056 | "could not retrieve a primitive kind from a primitive " |
8057 | "descriptor for an RNN primitive" ); |
8058 | |
8059 | dnnl_prop_kind_t q_prop_kind; |
8060 | rc = dnnl_primitive_desc_query( |
8061 | pd, dnnl_query_prop_kind, 0, &q_prop_kind); |
8062 | error::wrap_c_api(rc, |
8063 | "could not retrieve a propagation kind from a primitive " |
8064 | "descriptor for an RNN primitive" ); |
8065 | |
8066 | dnnl_alg_kind_t q_cell_kind; |
8067 | rc = dnnl_primitive_desc_query( |
8068 | pd, dnnl_query_cell_kind, 0, &q_cell_kind); |
8069 | error::wrap_c_api(rc, |
8070 | "could not retrieve a cell kind from a primitive descriptor " |
8071 | "for an RNN primitive" ); |
8072 | |
8073 | dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1); |
8074 | dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2); |
8075 | dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind); |
8076 | |
8077 | bool ok = q_primitive_kind == dnnl_rnn |
8078 | && (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2) |
8079 | && q_cell_kind == c_cell_kind; |
8080 | |
8081 | if (!ok) |
8082 | DNNL_THROW_ERROR(dnnl_invalid_arguments, |
8083 | "mismatch between expected and provided descriptors for an " |
8084 | "RNN primitive" ); |
8085 | |
8086 | reset_with_clone(pd); |
8087 | } |
8088 | |
8089 | // Constructs an RNN forward propagation primitive descriptor base for |
8090 | // any cell kind. |
8091 | rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind, |
8092 | prop_kind aprop_kind, algorithm activation, rnn_direction direction, |
8093 | const memory::desc &src_layer_desc, |
8094 | const memory::desc &src_iter_desc, |
8095 | const memory::desc *src_iter_c_desc, |
8096 | const memory::desc *attention_desc, |
8097 | const memory::desc &weights_layer_desc, |
8098 | const memory::desc &weights_iter_desc, |
8099 | const memory::desc *weights_peephole_desc, |
8100 | const memory::desc *weights_projection_desc, |
8101 | const memory::desc &bias_desc, const memory::desc &dst_layer_desc, |
8102 | const memory::desc &dst_iter_desc, |
8103 | const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha, |
8104 | float beta, const primitive_attr &attr, bool allow_empty) { |
8105 | |
8106 | dnnl_status_t status = dnnl_success; |
8107 | const char *msg |
8108 | = "could not create a primitive descriptor for a requested " |
8109 | "cell kind" ; |
8110 | |
8111 | dnnl_primitive_desc_t pd = nullptr; |
8112 | switch (cell_kind) { |
8113 | case algorithm::vanilla_rnn: |
8114 | status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd, |
8115 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8116 | dnnl::convert_to_c(activation), |
8117 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8118 | src_iter_desc.get(), weights_layer_desc.get(), |
8119 | weights_iter_desc.get(), bias_desc.get(), |
8120 | dst_layer_desc.get(), dst_iter_desc.get(), |
8121 | convert_to_c(flags), alpha, beta, attr.get()); |
8122 | msg = "could not create a primitive descriptor for a vanilla " |
8123 | "RNN forward propagation primitive" ; |
8124 | break; |
8125 | case algorithm::vanilla_lstm: |
8126 | status = dnnl_lstm_forward_primitive_desc_create(&pd, |
8127 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8128 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8129 | src_iter_desc.get(), optional_arg(src_iter_c_desc), |
8130 | weights_layer_desc.get(), weights_iter_desc.get(), |
8131 | optional_arg(weights_peephole_desc), |
8132 | optional_arg(weights_projection_desc), bias_desc.get(), |
8133 | dst_layer_desc.get(), dst_iter_desc.get(), |
8134 | optional_arg(dst_iter_c_desc), convert_to_c(flags), |
8135 | attr.get()); |
8136 | msg = "could not create a primitive descriptor for an LSTM " |
8137 | "forward propagation primitive" ; |
8138 | break; |
8139 | case algorithm::vanilla_gru: |
8140 | status = dnnl_gru_forward_primitive_desc_create(&pd, |
8141 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8142 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8143 | src_iter_desc.get(), weights_layer_desc.get(), |
8144 | weights_iter_desc.get(), bias_desc.get(), |
8145 | dst_layer_desc.get(), dst_iter_desc.get(), |
8146 | convert_to_c(flags), attr.get()); |
8147 | msg = "could not create a primitive descriptor for a GRU " |
8148 | "forward propagation primitive" ; |
8149 | break; |
8150 | case algorithm::lbr_gru: |
8151 | status = dnnl_lbr_gru_forward_primitive_desc_create(&pd, |
8152 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8153 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8154 | src_iter_desc.get(), weights_layer_desc.get(), |
8155 | weights_iter_desc.get(), bias_desc.get(), |
8156 | dst_layer_desc.get(), dst_iter_desc.get(), |
8157 | convert_to_c(flags), attr.get()); |
8158 | msg = "could not create a primitive descriptor for an LBR GRU " |
8159 | "forward propagation primitive" ; |
8160 | break; |
8161 | case algorithm::vanilla_augru: |
8162 | status = dnnl_augru_forward_primitive_desc_create(&pd, |
8163 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8164 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8165 | src_iter_desc.get(), optional_arg(attention_desc), |
8166 | weights_layer_desc.get(), weights_iter_desc.get(), |
8167 | bias_desc.get(), dst_layer_desc.get(), |
8168 | dst_iter_desc.get(), convert_to_c(flags), attr.get()); |
8169 | msg = "could not create a primitive descriptor for an AUGRU " |
8170 | "forward propagation primitive" ; |
8171 | break; |
8172 | case algorithm::lbr_augru: |
8173 | status = dnnl_lbr_augru_forward_primitive_desc_create(&pd, |
8174 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8175 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8176 | src_iter_desc.get(), optional_arg(attention_desc), |
8177 | weights_layer_desc.get(), weights_iter_desc.get(), |
8178 | bias_desc.get(), dst_layer_desc.get(), |
8179 | dst_iter_desc.get(), convert_to_c(flags), attr.get()); |
8180 | msg = "could not create a primitive descriptor for an LBR " |
8181 | "AUGRU forward propagation primitive" ; |
8182 | break; |
8183 | default: status = dnnl_unimplemented; |
8184 | } |
8185 | |
8186 | if (!allow_empty) error::wrap_c_api(status, msg); |
8187 | reset(pd); |
8188 | } |
8189 | |
8190 | // Constructs an RNN backward propagation primitive descriptor base for |
8191 | // any cell kind. |
8192 | rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind, |
8193 | prop_kind aprop_kind, algorithm activation, rnn_direction direction, |
8194 | const memory::desc &src_layer_desc, |
8195 | const memory::desc &src_iter_desc, |
8196 | const memory::desc *src_iter_c_desc, |
8197 | const memory::desc *attention_desc, |
8198 | const memory::desc &weights_layer_desc, |
8199 | const memory::desc &weights_iter_desc, |
8200 | const memory::desc *weights_peephole_desc, |
8201 | const memory::desc *weights_projection_desc, |
8202 | const memory::desc &bias_desc, const memory::desc &dst_layer_desc, |
8203 | const memory::desc &dst_iter_desc, |
8204 | const memory::desc *dst_iter_c_desc, |
8205 | const memory::desc &diff_src_layer_desc, |
8206 | const memory::desc &diff_src_iter_desc, |
8207 | const memory::desc *diff_src_iter_c_desc, |
8208 | const memory::desc *diff_attention_desc, |
8209 | const memory::desc &diff_weights_layer_desc, |
8210 | const memory::desc &diff_weights_iter_desc, |
8211 | const memory::desc *diff_weights_peephole_desc, |
8212 | const memory::desc *diff_weights_projection_desc, |
8213 | const memory::desc &diff_bias_desc, |
8214 | const memory::desc &diff_dst_layer_desc, |
8215 | const memory::desc &diff_dst_iter_desc, |
8216 | const memory::desc *diff_dst_iter_c_desc, rnn_flags flags, |
8217 | float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd, |
8218 | const primitive_attr &attr, bool allow_empty) { |
8219 | |
8220 | dnnl_status_t status = dnnl_success; |
8221 | const char *msg = "" ; |
8222 | |
8223 | dnnl_primitive_desc_t pd = nullptr; |
8224 | switch (cell_kind) { |
8225 | case algorithm::vanilla_rnn: |
8226 | status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd, |
8227 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8228 | dnnl::convert_to_c(activation), |
8229 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8230 | src_iter_desc.get(), weights_layer_desc.get(), |
8231 | weights_iter_desc.get(), bias_desc.get(), |
8232 | dst_layer_desc.get(), dst_iter_desc.get(), |
8233 | diff_src_layer_desc.get(), diff_src_iter_desc.get(), |
8234 | diff_weights_layer_desc.get(), |
8235 | diff_weights_iter_desc.get(), diff_bias_desc.get(), |
8236 | diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), |
8237 | convert_to_c(flags), alpha, beta, hint_fwd_pd.get(), |
8238 | attr.get()); |
8239 | msg = "could not create a primitive descriptor for a vanilla " |
8240 | "RNN backward propagation primitive" ; |
8241 | break; |
8242 | case algorithm::vanilla_lstm: |
8243 | status = dnnl_lstm_backward_primitive_desc_create(&pd, |
8244 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8245 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8246 | src_iter_desc.get(), optional_arg(src_iter_c_desc), |
8247 | weights_layer_desc.get(), weights_iter_desc.get(), |
8248 | optional_arg(weights_peephole_desc), |
8249 | optional_arg(weights_projection_desc), bias_desc.get(), |
8250 | dst_layer_desc.get(), dst_iter_desc.get(), |
8251 | optional_arg(dst_iter_c_desc), |
8252 | diff_src_layer_desc.get(), diff_src_iter_desc.get(), |
8253 | optional_arg(diff_src_iter_c_desc), |
8254 | diff_weights_layer_desc.get(), |
8255 | diff_weights_iter_desc.get(), |
8256 | optional_arg(diff_weights_peephole_desc), |
8257 | optional_arg(diff_weights_projection_desc), |
8258 | diff_bias_desc.get(), diff_dst_layer_desc.get(), |
8259 | diff_dst_iter_desc.get(), |
8260 | optional_arg(diff_dst_iter_c_desc), convert_to_c(flags), |
8261 | hint_fwd_pd.get(), attr.get()); |
8262 | msg = "could not create a primitive descriptor for an LSTM " |
8263 | "backward propagation primitive" ; |
8264 | break; |
8265 | case algorithm::vanilla_gru: |
8266 | status = dnnl_gru_backward_primitive_desc_create(&pd, |
8267 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8268 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8269 | src_iter_desc.get(), weights_layer_desc.get(), |
8270 | weights_iter_desc.get(), bias_desc.get(), |
8271 | dst_layer_desc.get(), dst_iter_desc.get(), |
8272 | diff_src_layer_desc.get(), diff_src_iter_desc.get(), |
8273 | diff_weights_layer_desc.get(), |
8274 | diff_weights_iter_desc.get(), diff_bias_desc.get(), |
8275 | diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), |
8276 | convert_to_c(flags), hint_fwd_pd.get(), attr.get()); |
8277 | msg = "could not create a primitive descriptor for a GRU " |
8278 | "backward propagation primitive" ; |
8279 | break; |
8280 | case algorithm::lbr_gru: |
8281 | status = dnnl_lbr_gru_backward_primitive_desc_create(&pd, |
8282 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8283 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8284 | src_iter_desc.get(), weights_layer_desc.get(), |
8285 | weights_iter_desc.get(), bias_desc.get(), |
8286 | dst_layer_desc.get(), dst_iter_desc.get(), |
8287 | diff_src_layer_desc.get(), diff_src_iter_desc.get(), |
8288 | diff_weights_layer_desc.get(), |
8289 | diff_weights_iter_desc.get(), diff_bias_desc.get(), |
8290 | diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), |
8291 | convert_to_c(flags), hint_fwd_pd.get(), attr.get()); |
8292 | msg = "could not create a primitive descriptor for an LBR GRU " |
8293 | "backward propagation primitive" ; |
8294 | break; |
8295 | case algorithm::vanilla_augru: |
8296 | status = dnnl_augru_backward_primitive_desc_create(&pd, |
8297 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8298 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8299 | src_iter_desc.get(), optional_arg(attention_desc), |
8300 | weights_layer_desc.get(), weights_iter_desc.get(), |
8301 | bias_desc.get(), dst_layer_desc.get(), |
8302 | dst_iter_desc.get(), diff_src_layer_desc.get(), |
8303 | diff_src_iter_desc.get(), |
8304 | optional_arg(diff_attention_desc), |
8305 | diff_weights_layer_desc.get(), |
8306 | diff_weights_iter_desc.get(), diff_bias_desc.get(), |
8307 | diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), |
8308 | convert_to_c(flags), hint_fwd_pd.get(), attr.get()); |
8309 | msg = "could not create a primitive descriptor for an AUGRU " |
8310 | "backward propagation primitive" ; |
8311 | break; |
8312 | case algorithm::lbr_augru: |
8313 | status = dnnl_lbr_augru_backward_primitive_desc_create(&pd, |
8314 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
8315 | dnnl::convert_to_c(direction), src_layer_desc.get(), |
8316 | src_iter_desc.get(), optional_arg(attention_desc), |
8317 | weights_layer_desc.get(), weights_iter_desc.get(), |
8318 | bias_desc.get(), dst_layer_desc.get(), |
8319 | dst_iter_desc.get(), diff_src_layer_desc.get(), |
8320 | diff_src_iter_desc.get(), |
8321 | optional_arg(diff_attention_desc), |
8322 | diff_weights_layer_desc.get(), |
8323 | diff_weights_iter_desc.get(), diff_bias_desc.get(), |
8324 | diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), |
8325 | convert_to_c(flags), hint_fwd_pd.get(), attr.get()); |
8326 | msg = "could not create a primitive descriptor for an LBR " |
8327 | "AUGRU backward propagation primitive" ; |
8328 | break; |
8329 | default: status = dnnl_unimplemented; |
8330 | } |
8331 | if (!allow_empty) error::wrap_c_api(status, msg); |
8332 | reset(pd); |
8333 | } |
8334 | }; |
8335 | |
8336 | /// Vanilla RNN forward propagation primitive. |
8337 | struct vanilla_rnn_forward : public primitive { |
8338 | /// Primitive descriptor for a vanilla RNN forward propagation primitive. |
8339 | struct primitive_desc : public rnn_primitive_desc_base { |
8340 | /// Default constructor. Produces an empty object. |
8341 | primitive_desc() = default; |
8342 | |
8343 | /// Constructs a primitive descriptor for a vanilla RNN forward |
8344 | /// propagation primitive. |
8345 | /// |
8346 | /// The following arguments may point to a zero memory descriptor: |
8347 | /// - @p src_iter_desc, |
8348 | /// - @p bias_desc, |
8349 | /// - @p dst_iter_desc. |
8350 | /// |
8351 | /// This would then indicate that the RNN forward propagation primitive |
8352 | /// should not use them and should default to zero values instead. |
8353 | /// |
8354 | /// @note |
8355 | /// All memory descriptors except @p src_iter_desc can be |
8356 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
8357 | /// format_tag. |
8358 | /// |
8359 | /// @param aengine Engine to use. |
8360 | /// @param aprop_kind Propagation kind. Possible values are |
8361 | /// #dnnl::prop_kind::forward_training, and |
8362 | /// #dnnl::prop_kind::forward_inference. |
8363 | /// @param activation Activation kind. Possible values are |
8364 | /// #dnnl::algorithm::eltwise_relu, |
8365 | /// #dnnl::algorithm::eltwise_tanh, or |
8366 | /// #dnnl::algorithm::eltwise_logistic. |
8367 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8368 | /// more info. |
8369 | /// @param src_layer_desc Memory descriptor for the input vector. |
8370 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8371 | /// hidden state vector. |
8372 | /// @param weights_layer_desc Memory descriptor for the weights |
8373 | /// applied to the layer input. |
8374 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8375 | /// to the recurrent input. |
8376 | /// @param bias_desc Bias memory descriptor. |
8377 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8378 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8379 | /// hidden state vector. |
8380 | /// @param attr Primitive attributes to use. Attributes are optional |
8381 | /// and default to empty attributes. |
8382 | /// @param allow_empty A flag signifying whether construction is |
8383 | /// allowed to fail without throwing an exception. In this case an |
8384 | /// empty object will be produced. This flag is optional and |
8385 | /// defaults to false. |
8386 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8387 | algorithm activation, rnn_direction direction, |
8388 | const memory::desc &src_layer_desc, |
8389 | const memory::desc &src_iter_desc, |
8390 | const memory::desc &weights_layer_desc, |
8391 | const memory::desc &weights_iter_desc, |
8392 | const memory::desc &bias_desc, |
8393 | const memory::desc &dst_layer_desc, |
8394 | const memory::desc &dst_iter_desc, |
8395 | const primitive_attr &attr = default_attr(), |
8396 | bool allow_empty = false) |
8397 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, |
8398 | aprop_kind, activation, direction, src_layer_desc, |
8399 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
8400 | weights_iter_desc, nullptr, nullptr, bias_desc, |
8401 | dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, |
8402 | 0.0f, 0.0f, attr, allow_empty) {} |
8403 | |
8404 | /// Constructs a primitive descriptor for a vanilla RNN forward |
8405 | /// propagation primitive with alpha parameter. |
8406 | /// |
8407 | /// The following arguments may point to a zero memory descriptor: |
8408 | /// - @p src_iter_desc, |
8409 | /// - @p bias_desc, |
8410 | /// - @p dst_iter_desc. |
8411 | /// |
8412 | /// This would then indicate that the RNN forward propagation primitive |
8413 | /// should not use them and should default to zero values instead. |
8414 | /// |
8415 | /// @note |
8416 | /// All memory descriptors except @p src_iter_desc can be |
8417 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
8418 | /// format_tag. |
8419 | /// |
8420 | /// @param aengine Engine to use. |
8421 | /// @param aprop_kind Propagation kind. Possible values are |
8422 | /// #dnnl::prop_kind::forward_training, and |
8423 | /// #dnnl::prop_kind::forward_inference. |
8424 | /// @param activation Activation kind. Possible values are |
8425 | /// #dnnl::algorithm::eltwise_relu, |
8426 | /// #dnnl::algorithm::eltwise_tanh, or |
8427 | /// #dnnl::algorithm::eltwise_logistic. |
8428 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8429 | /// more info. |
8430 | /// @param src_layer_desc Memory descriptor for the input vector. |
8431 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8432 | /// hidden state vector. |
8433 | /// @param weights_layer_desc Memory descriptor for the weights |
8434 | /// applied to the layer input. |
8435 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8436 | /// to the recurrent input. |
8437 | /// @param bias_desc Bias memory descriptor. |
8438 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8439 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8440 | /// hidden state vector. |
8441 | /// @param alpha Negative slope if activation is |
8442 | /// #dnnl::algorithm::eltwise_relu. |
8443 | /// @param attr Primitive attributes to use. Attributes are optional |
8444 | /// and default to empty attributes. |
8445 | /// @param allow_empty A flag signifying whether construction is |
8446 | /// allowed to fail without throwing an exception. In this case an |
8447 | /// empty object will be produced. This flag is optional and |
8448 | /// defaults to false. |
8449 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8450 | algorithm activation, rnn_direction direction, |
8451 | const memory::desc &src_layer_desc, |
8452 | const memory::desc &src_iter_desc, |
8453 | const memory::desc &weights_layer_desc, |
8454 | const memory::desc &weights_iter_desc, |
8455 | const memory::desc &bias_desc, |
8456 | const memory::desc &dst_layer_desc, |
8457 | const memory::desc &dst_iter_desc, float alpha, |
8458 | const primitive_attr &attr = default_attr(), |
8459 | bool allow_empty = false) |
8460 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, |
8461 | aprop_kind, activation, direction, src_layer_desc, |
8462 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
8463 | weights_iter_desc, nullptr, nullptr, bias_desc, |
8464 | dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, |
8465 | alpha, 0.0f, attr, allow_empty) {} |
8466 | |
8467 | /// Constructs a primitive descriptor for a vanilla RNN forward |
8468 | /// propagation primitive from a C API primitive descriptor that must |
8469 | /// have a matching kind. |
8470 | /// |
8471 | /// @param pd C API primitive descriptor for a vanilla RNN forward |
8472 | /// propagation primitive. |
8473 | primitive_desc(dnnl_primitive_desc_t pd) |
8474 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
8475 | dnnl::prop_kind::forward_inference, |
8476 | dnnl::algorithm::vanilla_rnn) {} |
8477 | |
8478 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
8479 | memory::desc src_layer_desc() const { |
8480 | return rnn_base::src_layer_desc(); |
8481 | } |
8482 | |
8483 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
8484 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
8485 | |
8486 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
8487 | memory::desc weights_layer_desc() const { |
8488 | return rnn_base::weights_layer_desc(); |
8489 | } |
8490 | |
8491 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
8492 | memory::desc weights_iter_desc() const { |
8493 | return rnn_base::weights_iter_desc(); |
8494 | } |
8495 | |
8496 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
8497 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
8498 | |
8499 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
8500 | memory::desc dst_layer_desc() const { |
8501 | return rnn_base::dst_layer_desc(); |
8502 | } |
8503 | |
8504 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
8505 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
8506 | |
8507 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
8508 | memory::desc workspace_desc() const { |
8509 | return rnn_base::workspace_desc(); |
8510 | } |
8511 | |
8512 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
8513 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
8514 | |
8515 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
8516 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
8517 | |
8518 | /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const |
8519 | algorithm get_activation_kind() const { |
8520 | return base::get_activation_kind(); |
8521 | } |
8522 | |
8523 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
8524 | rnn_direction get_direction() const { return base::get_direction(); } |
8525 | |
8526 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
8527 | float get_alpha() const { return base::get_alpha(); } |
8528 | |
8529 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
8530 | float get_beta() const { return base::get_beta(); } |
8531 | }; |
8532 | |
8533 | /// Default constructor. Produces an empty object. |
8534 | vanilla_rnn_forward() = default; |
8535 | |
8536 | /// Constructs a vanilla RNN forward propagation primitive. |
8537 | /// @param pd Primitive descriptor for a vanilla RNN forward |
8538 | /// propagation primitive. |
8539 | vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {} |
8540 | |
8541 | /// Constructs a vanilla RNN forward propagation primitive from |
8542 | /// a cache blob. |
8543 | /// @param pd Primitive descriptor for a vanilla RNN forward |
8544 | /// propagation primitive. |
8545 | /// @param cache_blob Cache blob. |
8546 | vanilla_rnn_forward( |
8547 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
8548 | : primitive(pd, cache_blob) {} |
8549 | }; |
8550 | |
8551 | /// Vanilla RNN backward propagation primitive. |
8552 | struct vanilla_rnn_backward : public primitive { |
8553 | /// Primitive descriptor for an RNN backward propagation primitive. |
8554 | struct primitive_desc : public rnn_primitive_desc_base { |
8555 | /// Default constructor. Produces an empty object. |
8556 | primitive_desc() = default; |
8557 | |
8558 | /// Constructs a primitive descriptor for a vanilla RNN backward |
8559 | /// propagation primitive. |
8560 | /// |
8561 | /// The following arguments may point to a zero memory descriptor: |
8562 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
8563 | /// - @p bias_desc together with @p diff_bias_desc, |
8564 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
8565 | /// |
8566 | /// This would then indicate that the RNN backward propagation |
8567 | /// primitive should not use the respective data and should use zero |
8568 | /// values instead. |
8569 | /// |
8570 | /// @note |
8571 | /// All the memory descriptors may be initialized with the |
8572 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
8573 | /// |
8574 | /// @param aengine Engine to use. |
8575 | /// @param aprop_kind Propagation kind. Must be |
8576 | /// #dnnl::prop_kind::backward. |
8577 | /// @param activation Activation kind. Possible values are |
8578 | /// #dnnl::algorithm::eltwise_relu, |
8579 | /// #dnnl::algorithm::eltwise_tanh, or |
8580 | /// #dnnl::algorithm::eltwise_logistic. |
8581 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8582 | /// more info. |
8583 | /// @param src_layer_desc Memory descriptor for the input vector. |
8584 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8585 | /// hidden state vector. |
8586 | /// @param weights_layer_desc Memory descriptor for the weights |
8587 | /// applied to the layer input. |
8588 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8589 | /// to the recurrent input. |
8590 | /// @param bias_desc Bias memory descriptor. |
8591 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8592 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8593 | /// hidden state vector. |
8594 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
8595 | /// vector. |
8596 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
8597 | /// recurrent hidden state vector. |
8598 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
8599 | /// weights applied to the layer input. |
8600 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
8601 | /// weights applied to the recurrent input. |
8602 | /// @param diff_bias_desc Diff bias memory descriptor. |
8603 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
8604 | /// output vector. |
8605 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
8606 | /// recurrent hidden state vector. |
8607 | /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN |
8608 | /// forward propagation primitive. It is used as a hint for |
8609 | /// deciding which memory format to use. |
8610 | /// @param attr Primitive attributes to use. Attributes are optional |
8611 | /// and default to empty attributes. |
8612 | /// @param allow_empty A flag signifying whether construction is |
8613 | /// allowed to fail without throwing an exception. In this case an |
8614 | /// empty object will be produced. This flag is optional and |
8615 | /// defaults to false. |
8616 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8617 | algorithm activation, rnn_direction direction, |
8618 | const memory::desc &src_layer_desc, |
8619 | const memory::desc &src_iter_desc, |
8620 | const memory::desc &weights_layer_desc, |
8621 | const memory::desc &weights_iter_desc, |
8622 | const memory::desc &bias_desc, |
8623 | const memory::desc &dst_layer_desc, |
8624 | const memory::desc &dst_iter_desc, |
8625 | const memory::desc &diff_src_layer_desc, |
8626 | const memory::desc &diff_src_iter_desc, |
8627 | const memory::desc &diff_weights_layer_desc, |
8628 | const memory::desc &diff_weights_iter_desc, |
8629 | const memory::desc &diff_bias_desc, |
8630 | const memory::desc &diff_dst_layer_desc, |
8631 | const memory::desc &diff_dst_iter_desc, |
8632 | const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, |
8633 | const primitive_attr &attr = default_attr(), |
8634 | bool allow_empty = false) |
8635 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, |
8636 | aprop_kind, activation, direction, src_layer_desc, |
8637 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
8638 | weights_iter_desc, nullptr, nullptr, bias_desc, |
8639 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
8640 | diff_src_iter_desc, nullptr, nullptr, |
8641 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, |
8642 | nullptr, diff_bias_desc, diff_dst_layer_desc, |
8643 | diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, |
8644 | hint_fwd_pd, attr, allow_empty) {} |
8645 | |
8646 | /// Constructs a primitive descriptor for a vanilla RNN backward |
8647 | /// propagation primitive with an alpha parameter. |
8648 | /// |
8649 | /// The following arguments may point to a zero memory descriptor: |
8650 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
8651 | /// - @p bias_desc together with @p diff_bias_desc, |
8652 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
8653 | /// |
8654 | /// This would then indicate that the RNN backward propagation |
8655 | /// primitive should not use the respective data and should use zero |
8656 | /// values instead. |
8657 | /// |
8658 | /// @note |
8659 | /// All the memory descriptors may be initialized with the |
8660 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
8661 | /// |
8662 | /// @param aengine Engine to use. |
8663 | /// @param aprop_kind Propagation kind. Must be |
8664 | /// #dnnl::prop_kind::backward. |
8665 | /// @param activation Activation kind. Possible values are |
8666 | /// #dnnl::algorithm::eltwise_relu, |
8667 | /// #dnnl::algorithm::eltwise_tanh, or |
8668 | /// #dnnl::algorithm::eltwise_logistic. |
8669 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8670 | /// more info. |
8671 | /// @param src_layer_desc Memory descriptor for the input vector. |
8672 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8673 | /// hidden state vector. |
8674 | /// @param weights_layer_desc Memory descriptor for the weights |
8675 | /// applied to the layer input. |
8676 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8677 | /// to the recurrent input. |
8678 | /// @param bias_desc Bias memory descriptor. |
8679 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8680 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8681 | /// hidden state vector. |
8682 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
8683 | /// vector. |
8684 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
8685 | /// recurrent hidden state vector. |
8686 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
8687 | /// weights applied to the layer input. |
8688 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
8689 | /// weights applied to the recurrent input. |
8690 | /// @param diff_bias_desc Diff bias memory descriptor. |
8691 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
8692 | /// output vector. |
8693 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
8694 | /// recurrent hidden state vector. |
8695 | /// @param alpha Negative slope if activation is |
8696 | /// #dnnl::algorithm::eltwise_relu. |
8697 | /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN |
8698 | /// forward propagation primitive. It is used as a hint for |
8699 | /// deciding which memory format to use. |
8700 | /// @param attr Primitive attributes to use. Attributes are optional |
8701 | /// and default to empty attributes. |
8702 | /// @param allow_empty A flag signifying whether construction is |
8703 | /// allowed to fail without throwing an exception. In this case an |
8704 | /// empty object will be produced. This flag is optional and |
8705 | /// defaults to false. |
8706 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8707 | algorithm activation, rnn_direction direction, |
8708 | const memory::desc &src_layer_desc, |
8709 | const memory::desc &src_iter_desc, |
8710 | const memory::desc &weights_layer_desc, |
8711 | const memory::desc &weights_iter_desc, |
8712 | const memory::desc &bias_desc, |
8713 | const memory::desc &dst_layer_desc, |
8714 | const memory::desc &dst_iter_desc, |
8715 | const memory::desc &diff_src_layer_desc, |
8716 | const memory::desc &diff_src_iter_desc, |
8717 | const memory::desc &diff_weights_layer_desc, |
8718 | const memory::desc &diff_weights_iter_desc, |
8719 | const memory::desc &diff_bias_desc, |
8720 | const memory::desc &diff_dst_layer_desc, |
8721 | const memory::desc &diff_dst_iter_desc, float alpha, |
8722 | const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, |
8723 | const primitive_attr &attr = default_attr(), |
8724 | bool allow_empty = false) |
8725 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, |
8726 | aprop_kind, activation, direction, src_layer_desc, |
8727 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
8728 | weights_iter_desc, nullptr, nullptr, bias_desc, |
8729 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
8730 | diff_src_iter_desc, nullptr, nullptr, |
8731 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, |
8732 | nullptr, diff_bias_desc, diff_dst_layer_desc, |
8733 | diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f, |
8734 | hint_fwd_pd, attr, allow_empty) {} |
8735 | |
8736 | /// Constructs a primitive descriptor for a vanilla RNN backward |
8737 | /// propagation primitive from a C API primitive descriptor that must |
8738 | /// have a matching kind. |
8739 | /// |
8740 | /// @param pd C API primitive descriptor for a vanilla RNN backward |
8741 | /// propagation primitive. |
8742 | primitive_desc(dnnl_primitive_desc_t pd) |
8743 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, |
8744 | dnnl::algorithm::vanilla_rnn) {} |
8745 | |
8746 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
8747 | memory::desc src_layer_desc() const { |
8748 | return rnn_base::src_layer_desc(); |
8749 | } |
8750 | |
8751 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
8752 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
8753 | |
8754 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
8755 | memory::desc weights_layer_desc() const { |
8756 | return rnn_base::weights_layer_desc(); |
8757 | } |
8758 | |
8759 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
8760 | memory::desc weights_iter_desc() const { |
8761 | return rnn_base::weights_iter_desc(); |
8762 | } |
8763 | |
8764 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
8765 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
8766 | |
8767 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
8768 | memory::desc dst_layer_desc() const { |
8769 | return rnn_base::dst_layer_desc(); |
8770 | } |
8771 | |
8772 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
8773 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
8774 | |
8775 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
8776 | memory::desc workspace_desc() const { |
8777 | return rnn_base::workspace_desc(); |
8778 | } |
8779 | |
8780 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
8781 | memory::desc diff_src_layer_desc() const { |
8782 | return rnn_base::diff_src_layer_desc(); |
8783 | } |
8784 | |
8785 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
8786 | memory::desc diff_src_iter_desc() const { |
8787 | return rnn_base::diff_src_iter_desc(); |
8788 | } |
8789 | |
8790 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
8791 | memory::desc diff_weights_layer_desc() const { |
8792 | return rnn_base::diff_weights_layer_desc(); |
8793 | } |
8794 | |
8795 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
8796 | memory::desc diff_weights_iter_desc() const { |
8797 | return rnn_base::diff_weights_iter_desc(); |
8798 | } |
8799 | |
8800 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
8801 | memory::desc diff_bias_desc() const { |
8802 | return rnn_base::diff_bias_desc(); |
8803 | } |
8804 | |
8805 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
8806 | memory::desc diff_dst_layer_desc() const { |
8807 | return rnn_base::diff_dst_layer_desc(); |
8808 | } |
8809 | |
8810 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
8811 | memory::desc diff_dst_iter_desc() const { |
8812 | return rnn_base::diff_dst_iter_desc(); |
8813 | } |
8814 | |
8815 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
8816 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
8817 | |
8818 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
8819 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
8820 | |
8821 | /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const |
8822 | algorithm get_activation_kind() const { |
8823 | return base::get_activation_kind(); |
8824 | } |
8825 | |
8826 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
8827 | rnn_direction get_direction() const { return base::get_direction(); } |
8828 | |
8829 | /// @copydoc dnnl::primitive_desc_base::get_alpha()const |
8830 | float get_alpha() const { return base::get_alpha(); } |
8831 | |
8832 | /// @copydoc dnnl::primitive_desc_base::get_beta()const |
8833 | float get_beta() const { return base::get_beta(); } |
8834 | }; |
8835 | |
8836 | /// Default constructor. Produces an empty object. |
8837 | vanilla_rnn_backward() = default; |
8838 | |
8839 | /// Constructs a vanilla RNN backward propagation primitive. |
8840 | /// @param pd Primitive descriptor for a vanilla RNN backward |
8841 | /// propagation primitive. |
8842 | vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {} |
8843 | |
8844 | /// Constructs a vanilla RNN backward propagation primitive from |
8845 | /// a cache blob. |
8846 | /// @param pd Primitive descriptor for a vanilla RNN backward |
8847 | /// propagation primitive. |
8848 | /// @param cache_blob Cache blob. |
8849 | vanilla_rnn_backward( |
8850 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
8851 | : primitive(pd, cache_blob) {} |
8852 | }; |
8853 | |
8854 | /// LSTM forward propagation primitive. |
8855 | struct lstm_forward : public primitive { |
8856 | /// Primitive descriptor for an LSTM forward propagation primitive. |
8857 | struct primitive_desc : public rnn_primitive_desc_base { |
8858 | /// Default constructor. Produces an empty object. |
8859 | primitive_desc() = default; |
8860 | |
8861 | /// Constructs a primitive descriptor for an LSTM (with or without |
8862 | /// peephole and with or without projection) forward propagation |
8863 | /// primitive. |
8864 | /// |
8865 | /// The following arguments may point to a zero memory descriptor: |
8866 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
8867 | /// - @p weights_peephole_desc, |
8868 | /// - @p bias_desc, |
8869 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
8870 | /// |
8871 | /// This would then indicate that the LSTM forward propagation |
8872 | /// primitive should not use them and should default to zero values |
8873 | /// instead. |
8874 | /// |
8875 | /// The @p weights_projection_desc may point to a zero memory |
8876 | /// descriptor. This would then indicate that the LSTM doesn't have |
8877 | /// recurrent projection layer. |
8878 | /// |
8879 | /// @note |
8880 | /// All memory descriptors can be initialized with an |
8881 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
8882 | /// |
8883 | /// @param aengine Engine to use. |
8884 | /// @param aprop_kind Propagation kind. Possible values are |
8885 | /// #dnnl::prop_kind::forward_training, and |
8886 | /// #dnnl::prop_kind::forward_inference. |
8887 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8888 | /// more info. |
8889 | /// @param src_layer_desc Memory descriptor for the input vector. |
8890 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8891 | /// hidden state vector. |
8892 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
8893 | /// cell state vector. |
8894 | /// @param weights_layer_desc Memory descriptor for the weights |
8895 | /// applied to the layer input. |
8896 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8897 | /// to the recurrent input. |
8898 | /// @param weights_peephole_desc Memory descriptor for the weights |
8899 | /// applied to the cell states (according to the Peephole LSTM |
8900 | /// formula). |
8901 | /// @param weights_projection_desc Memory descriptor for the weights |
8902 | /// applied to the hidden states to get the recurrent projection |
8903 | /// (according to the Projection LSTM formula). |
8904 | /// @param bias_desc Bias memory descriptor. |
8905 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8906 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8907 | /// hidden state vector. |
8908 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
8909 | /// cell state vector. |
8910 | /// @param attr Primitive attributes to use. Attributes are optional |
8911 | /// and default to empty attributes. |
8912 | /// @param allow_empty A flag signifying whether construction is |
8913 | /// allowed to fail without throwing an exception. In this case an |
8914 | /// empty object will be produced. This flag is optional and |
8915 | /// defaults to false. |
8916 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8917 | rnn_direction direction, const memory::desc &src_layer_desc, |
8918 | const memory::desc &src_iter_desc, |
8919 | const memory::desc &src_iter_c_desc, |
8920 | const memory::desc &weights_layer_desc, |
8921 | const memory::desc &weights_iter_desc, |
8922 | const memory::desc &weights_peephole_desc, |
8923 | const memory::desc &weights_projection_desc, |
8924 | const memory::desc &bias_desc, |
8925 | const memory::desc &dst_layer_desc, |
8926 | const memory::desc &dst_iter_desc, |
8927 | const memory::desc &dst_iter_c_desc, |
8928 | const primitive_attr &attr = default_attr(), |
8929 | bool allow_empty = false) |
8930 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
8931 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
8932 | src_iter_desc, &src_iter_c_desc, nullptr, |
8933 | weights_layer_desc, weights_iter_desc, |
8934 | &weights_peephole_desc, &weights_projection_desc, bias_desc, |
8935 | dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, |
8936 | rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} |
8937 | |
8938 | /// Constructs a primitive descriptor for an LSTM (with or without |
8939 | /// peephole) forward propagation primitive. |
8940 | /// |
8941 | /// The following arguments may point to a zero memory descriptor: |
8942 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
8943 | /// - @p weights_peephole_desc, |
8944 | /// - @p bias_desc, |
8945 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
8946 | /// |
8947 | /// This would then indicate that the LSTM forward propagation |
8948 | /// primitive should not use them and should default to zero values |
8949 | /// instead. |
8950 | /// |
8951 | /// @note |
8952 | /// All memory descriptors can be initialized with an |
8953 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
8954 | /// |
8955 | /// @param aengine Engine to use. |
8956 | /// @param aprop_kind Propagation kind. Possible values are |
8957 | /// #dnnl::prop_kind::forward_training, and |
8958 | /// #dnnl::prop_kind::forward_inference. |
8959 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
8960 | /// more info. |
8961 | /// @param src_layer_desc Memory descriptor for the input vector. |
8962 | /// @param src_iter_desc Memory descriptor for the input recurrent |
8963 | /// hidden state vector. |
8964 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
8965 | /// cell state vector. |
8966 | /// @param weights_layer_desc Memory descriptor for the weights |
8967 | /// applied to the layer input. |
8968 | /// @param weights_iter_desc Memory descriptor for the weights applied |
8969 | /// to the recurrent input. |
8970 | /// @param weights_peephole_desc Memory descriptor for the weights |
8971 | /// applied to the cell states (according to the Peephole LSTM |
8972 | /// formula). |
8973 | /// @param bias_desc Bias memory descriptor. |
8974 | /// @param dst_layer_desc Memory descriptor for the output vector. |
8975 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
8976 | /// hidden state vector. |
8977 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
8978 | /// cell state vector. |
8979 | /// @param attr Primitive attributes to use. Attributes are optional |
8980 | /// and default to empty attributes. |
8981 | /// @param allow_empty A flag signifying whether construction is |
8982 | /// allowed to fail without throwing an exception. In this case an |
8983 | /// empty object will be produced. This flag is optional and |
8984 | /// defaults to false. |
8985 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
8986 | rnn_direction direction, const memory::desc &src_layer_desc, |
8987 | const memory::desc &src_iter_desc, |
8988 | const memory::desc &src_iter_c_desc, |
8989 | const memory::desc &weights_layer_desc, |
8990 | const memory::desc &weights_iter_desc, |
8991 | const memory::desc &weights_peephole_desc, |
8992 | const memory::desc &bias_desc, |
8993 | const memory::desc &dst_layer_desc, |
8994 | const memory::desc &dst_iter_desc, |
8995 | const memory::desc &dst_iter_c_desc, |
8996 | const primitive_attr &attr = default_attr(), |
8997 | bool allow_empty = false) |
8998 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
8999 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9000 | src_iter_desc, &src_iter_c_desc, nullptr, |
9001 | weights_layer_desc, weights_iter_desc, |
9002 | &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc, |
9003 | dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f, |
9004 | 0.0f, attr, allow_empty) {} |
9005 | |
9006 | /// Constructs a primitive descriptor for an LSTM forward propagation |
9007 | /// primitive. |
9008 | /// |
9009 | /// The following arguments may point to a zero memory descriptor: |
9010 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
9011 | /// - @p bias_desc, |
9012 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
9013 | /// |
9014 | /// This would then indicate that the LSTM forward propagation |
9015 | /// primitive should not use them and should default to zero values |
9016 | /// instead. |
9017 | /// |
9018 | /// @note |
9019 | /// All memory descriptors can be initialized with an |
9020 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
9021 | /// |
9022 | /// @param aengine Engine to use. |
9023 | /// @param aprop_kind Propagation kind. Possible values are |
9024 | /// #dnnl::prop_kind::forward_training, and |
9025 | /// #dnnl::prop_kind::forward_inference. |
9026 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9027 | /// more info. |
9028 | /// @param src_layer_desc Memory descriptor for the input vector. |
9029 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9030 | /// hidden state vector. |
9031 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
9032 | /// cell state vector. |
9033 | /// @param weights_layer_desc Memory descriptor for the weights |
9034 | /// applied to the layer input. |
9035 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9036 | /// to the recurrent input. |
9037 | /// @param bias_desc Bias memory descriptor. |
9038 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9039 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9040 | /// hidden state vector. |
9041 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
9042 | /// cell state vector. |
9043 | /// @param attr Primitive attributes to use. Attributes are optional |
9044 | /// and default to empty attributes. |
9045 | /// @param allow_empty A flag signifying whether construction is |
9046 | /// allowed to fail without throwing an exception. In this case an |
9047 | /// empty object will be produced. This flag is optional and |
9048 | /// defaults to false. |
9049 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9050 | rnn_direction direction, const memory::desc &src_layer_desc, |
9051 | const memory::desc &src_iter_desc, |
9052 | const memory::desc &src_iter_c_desc, |
9053 | const memory::desc &weights_layer_desc, |
9054 | const memory::desc &weights_iter_desc, |
9055 | const memory::desc &bias_desc, |
9056 | const memory::desc &dst_layer_desc, |
9057 | const memory::desc &dst_iter_desc, |
9058 | const memory::desc &dst_iter_c_desc, |
9059 | const primitive_attr &attr = default_attr(), |
9060 | bool allow_empty = false) |
9061 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
9062 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9063 | src_iter_desc, &src_iter_c_desc, nullptr, |
9064 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, |
9065 | bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, |
9066 | rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} |
9067 | |
9068 | /// Constructs a primitive descriptor for an LSTM forward propagation |
9069 | /// primitive from a C API primitive descriptor that must have a |
9070 | /// matching kind. |
9071 | /// |
9072 | /// @param pd C API primitive descriptor for an LSTM forward |
9073 | /// propagation primitive. |
9074 | primitive_desc(dnnl_primitive_desc_t pd) |
9075 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
9076 | dnnl::prop_kind::forward_inference, |
9077 | dnnl::algorithm::vanilla_lstm) {} |
9078 | |
9079 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
9080 | memory::desc src_layer_desc() const { |
9081 | return rnn_base::src_layer_desc(); |
9082 | } |
9083 | |
9084 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
9085 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
9086 | |
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
9088 | memory::desc src_iter_c_desc() const { |
9089 | return rnn_base::src_iter_c_desc(); |
9090 | } |
9091 | |
9092 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
9093 | memory::desc weights_layer_desc() const { |
9094 | return rnn_base::weights_layer_desc(); |
9095 | } |
9096 | |
9097 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
9098 | memory::desc weights_iter_desc() const { |
9099 | return rnn_base::weights_iter_desc(); |
9100 | } |
9101 | |
9102 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const |
9103 | memory::desc weights_peephole_desc() const { |
9104 | return rnn_base::weights_peephole_desc(); |
9105 | } |
9106 | |
9107 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const |
9108 | memory::desc weights_projection_desc() const { |
9109 | return rnn_base::weights_projection_desc(); |
9110 | } |
9111 | |
9112 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
9113 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
9114 | |
9115 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
9116 | memory::desc dst_layer_desc() const { |
9117 | return rnn_base::dst_layer_desc(); |
9118 | } |
9119 | |
9120 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
9121 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
9122 | |
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
9124 | memory::desc dst_iter_c_desc() const { |
9125 | return rnn_base::dst_iter_c_desc(); |
9126 | } |
9127 | |
9128 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
9129 | memory::desc workspace_desc() const { |
9130 | return rnn_base::workspace_desc(); |
9131 | } |
9132 | |
9133 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
9134 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
9135 | |
9136 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
9137 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
9138 | |
9139 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
9140 | rnn_direction get_direction() const { return base::get_direction(); } |
9141 | }; |
9142 | |
9143 | /// Default constructor. Produces an empty object. |
9144 | lstm_forward() = default; |
9145 | |
9146 | /// Constructs an LSTM forward propagation primitive. |
9147 | /// @param pd Primitive descriptor for an LSTM forward propagation |
9148 | /// primitive. |
9149 | lstm_forward(const primitive_desc &pd) : primitive(pd) {} |
9150 | |
9151 | /// Constructs an LSTM forward propagation primitive from a cache blob. |
9152 | /// @param pd Primitive descriptor for an LSTM forward propagation |
9153 | /// primitive. |
9154 | /// @param cache_blob Cache blob. |
9155 | lstm_forward( |
9156 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
9157 | : primitive(pd, cache_blob) {} |
9158 | }; |
9159 | |
9160 | /// LSTM backward propagation primitive. |
9161 | struct lstm_backward : public primitive { |
9162 | /// Primitive descriptor for an LSTM backward propagation primitive. |
9163 | struct primitive_desc : public rnn_primitive_desc_base { |
9164 | /// Default constructor. Produces an empty object. |
9165 | primitive_desc() = default; |
9166 | |
9167 | /// Constructs an LSTM (with or without peephole and with or without |
9168 | /// projection) primitive descriptor for backward propagation |
/// using @p aprop_kind, @p direction, and memory descriptors.
9170 | /// |
9171 | /// The following arguments may point to a zero memory descriptor: |
9172 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
9173 | /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, |
///  - @p weights_peephole_desc together with
///    @p diff_weights_peephole_desc,
9176 | /// - @p bias_desc together with @p diff_bias_desc, |
9177 | /// - @p dst_iter_desc together with @p dst_iter_c_desc, |
9178 | /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. |
9179 | /// |
9180 | /// This would then indicate that the LSTM backward propagation |
9181 | /// primitive should not use them and should default to zero values |
9182 | /// instead. |
9183 | /// |
9184 | /// The @p weights_projection_desc together with @p |
9185 | /// diff_weights_projection_desc may point to a zero memory descriptor. |
/// This would then indicate that the LSTM does not have a recurrent
/// projection layer.
9188 | /// |
9189 | /// @note |
9190 | /// All memory descriptors can be initialized with |
9191 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
9192 | /// |
9193 | /// @param aengine Engine to use. |
9194 | /// @param aprop_kind Propagation kind. Must be |
9195 | /// #dnnl::prop_kind::backward. |
9196 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9197 | /// more info. |
9198 | /// @param src_layer_desc Memory descriptor for the input vector. |
9199 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9200 | /// hidden state vector. |
9201 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
9202 | /// cell state vector. |
9203 | /// @param weights_layer_desc Memory descriptor for the weights |
9204 | /// applied to the layer input. |
9205 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9206 | /// to the recurrent input. |
9207 | /// @param weights_peephole_desc Memory descriptor for the weights |
9208 | /// applied to the cell states (according to the Peephole LSTM |
9209 | /// formula). |
9210 | /// @param weights_projection_desc Memory descriptor for the weights |
9211 | /// applied to the hidden states to get the recurrent projection |
9212 | /// (according to the Projection LSTM formula). |
9213 | /// @param bias_desc Bias memory descriptor. |
9214 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9215 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9216 | /// hidden state vector. |
9217 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
9218 | /// cell state vector. |
9219 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
9220 | /// vector. |
9221 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
9222 | /// recurrent hidden state vector. |
9223 | /// @param diff_src_iter_c_desc Memory descriptor for the diff of |
9224 | /// input recurrent cell state vector. |
9225 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
9226 | /// weights applied to the layer input. |
9227 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
9228 | /// weights applied to the recurrent input. |
9229 | /// @param diff_weights_peephole_desc Memory descriptor for the diff of |
9230 | /// weights applied to the cell states (according to the Peephole |
9231 | /// LSTM formula). |
9232 | /// @param diff_weights_projection_desc Memory descriptor for the diff |
9233 | /// of weights applied to the hidden states to get the recurrent |
9234 | /// projection (according to the Projection LSTM formula). |
9235 | /// @param diff_bias_desc Diff bias memory descriptor. |
9236 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
9237 | /// output vector. |
9238 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
9239 | /// recurrent hidden state vector. |
9240 | /// @param diff_dst_iter_c_desc Memory descriptor for the diff of |
9241 | /// output recurrent cell state vector. |
9242 | /// @param hint_fwd_pd Primitive descriptor for an LSTM |
9243 | /// forward propagation primitive. It is used as a hint for |
9244 | /// deciding which memory format to use. |
9245 | /// @param attr Primitive attributes to use. Attributes are optional |
9246 | /// and default to empty attributes. |
9247 | /// @param allow_empty A flag signifying whether construction is |
9248 | /// allowed to fail without throwing an exception. In this case an |
9249 | /// empty object will be produced. This flag is optional and |
9250 | /// defaults to false. |
9251 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9252 | rnn_direction direction, const memory::desc &src_layer_desc, |
9253 | const memory::desc &src_iter_desc, |
9254 | const memory::desc &src_iter_c_desc, |
9255 | const memory::desc &weights_layer_desc, |
9256 | const memory::desc &weights_iter_desc, |
9257 | const memory::desc &weights_peephole_desc, |
9258 | const memory::desc &weights_projection_desc, |
9259 | const memory::desc &bias_desc, |
9260 | const memory::desc &dst_layer_desc, |
9261 | const memory::desc &dst_iter_desc, |
9262 | const memory::desc &dst_iter_c_desc, |
9263 | const memory::desc &diff_src_layer_desc, |
9264 | const memory::desc &diff_src_iter_desc, |
9265 | const memory::desc &diff_src_iter_c_desc, |
9266 | const memory::desc &diff_weights_layer_desc, |
9267 | const memory::desc &diff_weights_iter_desc, |
9268 | const memory::desc &diff_weights_peephole_desc, |
9269 | const memory::desc &diff_weights_projection_desc, |
9270 | const memory::desc &diff_bias_desc, |
9271 | const memory::desc &diff_dst_layer_desc, |
9272 | const memory::desc &diff_dst_iter_desc, |
9273 | const memory::desc &diff_dst_iter_c_desc, |
9274 | const lstm_forward::primitive_desc &hint_fwd_pd, |
9275 | const primitive_attr &attr = default_attr(), |
9276 | bool allow_empty = false) |
9277 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
9278 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9279 | src_iter_desc, &src_iter_c_desc, nullptr, |
9280 | weights_layer_desc, weights_iter_desc, |
9281 | &weights_peephole_desc, &weights_projection_desc, bias_desc, |
9282 | dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, |
9283 | diff_src_layer_desc, diff_src_iter_desc, |
9284 | &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc, |
9285 | diff_weights_iter_desc, &diff_weights_peephole_desc, |
9286 | &diff_weights_projection_desc, diff_bias_desc, |
9287 | diff_dst_layer_desc, diff_dst_iter_desc, |
9288 | &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, |
9289 | hint_fwd_pd, attr, allow_empty) {} |
9290 | |
9291 | /// Constructs an LSTM (with or without peephole) primitive descriptor |
/// for backward propagation using @p aprop_kind, @p direction,
9293 | /// and memory descriptors. |
9294 | /// |
9295 | /// The following arguments may point to a zero memory descriptor: |
9296 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
9297 | /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, |
///  - @p weights_peephole_desc together with
///    @p diff_weights_peephole_desc,
9300 | /// - @p bias_desc together with @p diff_bias_desc, |
9301 | /// - @p dst_iter_desc together with @p dst_iter_c_desc, |
9302 | /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. |
9303 | /// |
9304 | /// This would then indicate that the LSTM backward propagation |
9305 | /// primitive should not use them and should default to zero values |
9306 | /// instead. |
9307 | /// |
9308 | /// @note |
9309 | /// All memory descriptors may be initialized with |
9310 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
9311 | /// |
9312 | /// @param aengine Engine to use. |
9313 | /// @param aprop_kind Propagation kind. Must be |
9314 | /// #dnnl::prop_kind::backward. |
9315 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9316 | /// more info. |
9317 | /// @param src_layer_desc Memory descriptor for the input vector. |
9318 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9319 | /// hidden state vector. |
9320 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
9321 | /// cell state vector. |
9322 | /// @param weights_layer_desc Memory descriptor for the weights |
9323 | /// applied to the layer input. |
9324 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9325 | /// to the recurrent input. |
9326 | /// @param weights_peephole_desc Memory descriptor for the weights |
9327 | /// applied to the cell states (according to the Peephole LSTM |
9328 | /// formula). |
9329 | /// @param bias_desc Bias memory descriptor. |
9330 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9331 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9332 | /// hidden state vector. |
9333 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
9334 | /// cell state vector. |
9335 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
9336 | /// vector. |
9337 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
9338 | /// recurrent hidden state vector. |
9339 | /// @param diff_src_iter_c_desc Memory descriptor for the diff of |
9340 | /// input recurrent cell state vector. |
9341 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
9342 | /// weights applied to the layer input. |
9343 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
9344 | /// weights applied to the recurrent input. |
9345 | /// @param diff_weights_peephole_desc Memory descriptor for the diff of |
9346 | /// weights applied to the cell states (according to the Peephole |
9347 | /// LSTM formula). |
9348 | /// @param diff_bias_desc Diff bias memory descriptor. |
9349 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
9350 | /// output vector. |
9351 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
9352 | /// recurrent hidden state vector. |
9353 | /// @param diff_dst_iter_c_desc Memory descriptor for the diff of |
9354 | /// output recurrent cell state vector. |
9355 | /// @param hint_fwd_pd Primitive descriptor for an LSTM |
9356 | /// forward propagation primitive. It is used as a hint for |
9357 | /// deciding which memory format to use. |
9358 | /// @param attr Primitive attributes to use. Attributes are optional |
9359 | /// and default to empty attributes. |
9360 | /// @param allow_empty A flag signifying whether construction is |
9361 | /// allowed to fail without throwing an exception. In this case an |
9362 | /// empty object will be produced. This flag is optional and |
9363 | /// defaults to false. |
9364 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9365 | rnn_direction direction, const memory::desc &src_layer_desc, |
9366 | const memory::desc &src_iter_desc, |
9367 | const memory::desc &src_iter_c_desc, |
9368 | const memory::desc &weights_layer_desc, |
9369 | const memory::desc &weights_iter_desc, |
9370 | const memory::desc &weights_peephole_desc, |
9371 | const memory::desc &bias_desc, |
9372 | const memory::desc &dst_layer_desc, |
9373 | const memory::desc &dst_iter_desc, |
9374 | const memory::desc &dst_iter_c_desc, |
9375 | const memory::desc &diff_src_layer_desc, |
9376 | const memory::desc &diff_src_iter_desc, |
9377 | const memory::desc &diff_src_iter_c_desc, |
9378 | const memory::desc &diff_weights_layer_desc, |
9379 | const memory::desc &diff_weights_iter_desc, |
9380 | const memory::desc &diff_weights_peephole_desc, |
9381 | const memory::desc &diff_bias_desc, |
9382 | const memory::desc &diff_dst_layer_desc, |
9383 | const memory::desc &diff_dst_iter_desc, |
9384 | const memory::desc &diff_dst_iter_c_desc, |
9385 | const lstm_forward::primitive_desc &hint_fwd_pd, |
9386 | const primitive_attr &attr = default_attr(), |
9387 | bool allow_empty = false) |
9388 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
9389 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9390 | src_iter_desc, &src_iter_c_desc, nullptr, |
9391 | weights_layer_desc, weights_iter_desc, |
9392 | &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc, |
9393 | dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc, |
9394 | diff_src_iter_desc, &diff_src_iter_c_desc, nullptr, |
9395 | diff_weights_layer_desc, diff_weights_iter_desc, |
9396 | &diff_weights_peephole_desc, nullptr, diff_bias_desc, |
9397 | diff_dst_layer_desc, diff_dst_iter_desc, |
9398 | &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, |
9399 | hint_fwd_pd, attr, allow_empty) {} |
9400 | |
9401 | /// Constructs an LSTM primitive descriptor for backward propagation |
/// using @p aprop_kind, @p direction, and memory descriptors.
9403 | /// |
9404 | /// The following arguments may point to a zero memory descriptor: |
9405 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
9406 | /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, |
9407 | /// - @p bias_desc together with @p diff_bias_desc, |
9408 | /// - @p dst_iter_desc together with @p dst_iter_c_desc, |
9409 | /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. |
9410 | /// |
9411 | /// This would then indicate that the LSTM backward propagation |
9412 | /// primitive should not use them and should default to zero values |
9413 | /// instead. |
9414 | /// |
9415 | /// @note |
9416 | /// All memory descriptors may be initialized with |
9417 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
9418 | /// |
9419 | /// @param aengine Engine to use. |
9420 | /// @param aprop_kind Propagation kind. Must be |
9421 | /// #dnnl::prop_kind::backward. |
9422 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9423 | /// more info. |
9424 | /// @param src_layer_desc Memory descriptor for the input vector. |
9425 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9426 | /// hidden state vector. |
9427 | /// @param src_iter_c_desc Memory descriptor for the input recurrent |
9428 | /// cell state vector. |
9429 | /// @param weights_layer_desc Memory descriptor for the weights |
9430 | /// applied to the layer input. |
9431 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9432 | /// to the recurrent input. |
9433 | /// @param bias_desc Bias memory descriptor. |
9434 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9435 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9436 | /// hidden state vector. |
9437 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent |
9438 | /// cell state vector. |
9439 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
9440 | /// vector. |
9441 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
9442 | /// recurrent hidden state vector. |
9443 | /// @param diff_src_iter_c_desc Memory descriptor for the diff of |
9444 | /// input recurrent cell state vector. |
9445 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
9446 | /// weights applied to the layer input. |
9447 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
9448 | /// weights applied to the recurrent input. |
9449 | /// @param diff_bias_desc Diff bias memory descriptor. |
9450 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
9451 | /// output vector. |
9452 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
9453 | /// recurrent hidden state vector. |
9454 | /// @param diff_dst_iter_c_desc Memory descriptor for the diff of |
9455 | /// output recurrent cell state vector. |
/// @param hint_fwd_pd Primitive descriptor for an LSTM
9457 | /// forward propagation primitive. It is used as a hint for |
9458 | /// deciding which memory format to use. |
9459 | /// @param attr Primitive attributes to use. Attributes are optional |
9460 | /// and default to empty attributes. |
9461 | /// @param allow_empty A flag signifying whether construction is |
9462 | /// allowed to fail without throwing an exception. In this case an |
9463 | /// empty object will be produced. This flag is optional and |
9464 | /// defaults to false. |
9465 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9466 | rnn_direction direction, const memory::desc &src_layer_desc, |
9467 | const memory::desc &src_iter_desc, |
9468 | const memory::desc &src_iter_c_desc, |
9469 | const memory::desc &weights_layer_desc, |
9470 | const memory::desc &weights_iter_desc, |
9471 | const memory::desc &bias_desc, |
9472 | const memory::desc &dst_layer_desc, |
9473 | const memory::desc &dst_iter_desc, |
9474 | const memory::desc &dst_iter_c_desc, |
9475 | const memory::desc &diff_src_layer_desc, |
9476 | const memory::desc &diff_src_iter_desc, |
9477 | const memory::desc &diff_src_iter_c_desc, |
9478 | const memory::desc &diff_weights_layer_desc, |
9479 | const memory::desc &diff_weights_iter_desc, |
9480 | const memory::desc &diff_bias_desc, |
9481 | const memory::desc &diff_dst_layer_desc, |
9482 | const memory::desc &diff_dst_iter_desc, |
9483 | const memory::desc &diff_dst_iter_c_desc, |
9484 | const lstm_forward::primitive_desc &hint_fwd_pd, |
9485 | const primitive_attr &attr = default_attr(), |
9486 | bool allow_empty = false) |
9487 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, |
9488 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9489 | src_iter_desc, &src_iter_c_desc, nullptr, |
9490 | weights_layer_desc, weights_iter_desc, nullptr, nullptr, |
9491 | bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, |
9492 | diff_src_layer_desc, diff_src_iter_desc, |
9493 | &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc, |
9494 | diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc, |
9495 | diff_dst_layer_desc, diff_dst_iter_desc, |
9496 | &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, |
9497 | hint_fwd_pd, attr, allow_empty) {} |
9498 | |
9499 | /// Constructs a primitive descriptor for an LSTM backward propagation |
9500 | /// primitive from a C API primitive descriptor that must have a |
9501 | /// matching kind. |
9502 | /// |
9503 | /// @param pd C API primitive descriptor for an LSTM backward |
9504 | /// propagation primitive. |
9505 | primitive_desc(dnnl_primitive_desc_t pd) |
9506 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, |
9507 | dnnl::algorithm::vanilla_lstm) {} |
9508 | |
9509 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
9510 | memory::desc src_layer_desc() const { |
9511 | return rnn_base::src_layer_desc(); |
9512 | } |
9513 | |
9514 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
9515 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
9516 | |
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
9518 | memory::desc src_iter_c_desc() const { |
9519 | return rnn_base::src_iter_c_desc(); |
9520 | } |
9521 | |
9522 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
9523 | memory::desc weights_layer_desc() const { |
9524 | return rnn_base::weights_layer_desc(); |
9525 | } |
9526 | |
9527 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
9528 | memory::desc weights_iter_desc() const { |
9529 | return rnn_base::weights_iter_desc(); |
9530 | } |
9531 | |
9532 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const |
9533 | memory::desc weights_peephole_desc() const { |
9534 | return rnn_base::weights_peephole_desc(); |
9535 | } |
9536 | |
9537 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const |
9538 | memory::desc weights_projection_desc() const { |
9539 | return rnn_base::weights_projection_desc(); |
9540 | } |
9541 | |
9542 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
9543 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
9544 | |
9545 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
9546 | memory::desc dst_layer_desc() const { |
9547 | return rnn_base::dst_layer_desc(); |
9548 | } |
9549 | |
9550 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
9551 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
9552 | |
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
9554 | memory::desc dst_iter_c_desc() const { |
9555 | return rnn_base::dst_iter_c_desc(); |
9556 | } |
9557 | |
9558 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
9559 | memory::desc workspace_desc() const { |
9560 | return rnn_base::workspace_desc(); |
9561 | } |
9562 | |
9563 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
9564 | memory::desc diff_src_layer_desc() const { |
9565 | return rnn_base::diff_src_layer_desc(); |
9566 | } |
9567 | |
9568 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
9569 | memory::desc diff_src_iter_desc() const { |
9570 | return rnn_base::diff_src_iter_desc(); |
9571 | } |
9572 | |
9573 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const |
9574 | memory::desc diff_src_iter_c_desc() const { |
9575 | return rnn_base::diff_src_iter_c_desc(); |
9576 | } |
9577 | |
9578 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
9579 | memory::desc diff_weights_layer_desc() const { |
9580 | return rnn_base::diff_weights_layer_desc(); |
9581 | } |
9582 | |
9583 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
9584 | memory::desc diff_weights_iter_desc() const { |
9585 | return rnn_base::diff_weights_iter_desc(); |
9586 | } |
9587 | |
9588 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const |
9589 | memory::desc diff_weights_peephole_desc() const { |
9590 | return rnn_base::diff_weights_peephole_desc(); |
9591 | } |
9592 | |
9593 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const |
9594 | memory::desc diff_weights_projection_desc() const { |
9595 | return rnn_base::diff_weights_projection_desc(); |
9596 | } |
9597 | |
9598 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
9599 | memory::desc diff_bias_desc() const { |
9600 | return rnn_base::diff_bias_desc(); |
9601 | } |
9602 | |
9603 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
9604 | memory::desc diff_dst_layer_desc() const { |
9605 | return rnn_base::diff_dst_layer_desc(); |
9606 | } |
9607 | |
9608 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
9609 | memory::desc diff_dst_iter_desc() const { |
9610 | return rnn_base::diff_dst_iter_desc(); |
9611 | } |
9612 | |
9613 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const |
9614 | memory::desc diff_dst_iter_c_desc() const { |
9615 | return rnn_base::diff_dst_iter_c_desc(); |
9616 | } |
9617 | |
9618 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
9619 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
9620 | |
9621 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
9622 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
9623 | |
9624 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
9625 | rnn_direction get_direction() const { return base::get_direction(); } |
9626 | }; |
9627 | |
9628 | /// Default constructor. Produces an empty object. |
// Defaulted so that the object starts out empty.
lstm_backward() = default;
9630 | |
9631 | /// Constructs an LSTM backward propagation primitive. |
9632 | /// @param pd Primitive descriptor for an LSTM backward propagation |
9633 | /// primitive. |
9634 | lstm_backward(const primitive_desc &pd) : primitive(pd) {} |
9635 | |
9636 | /// Constructs an LSTM backward propagation primitive from a cache blob. |
9637 | /// @param pd Primitive descriptor for an LSTM backward propagation |
9638 | /// primitive. |
9639 | /// @param cache_blob Cache blob. |
9640 | lstm_backward( |
9641 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
9642 | : primitive(pd, cache_blob) {} |
9643 | }; |
9644 | |
9645 | /// GRU forward propagation primitive. |
9646 | struct gru_forward : public primitive { |
9647 | /// Primitive descriptor for a GRU forward propagation primitive. |
9648 | struct primitive_desc : public rnn_primitive_desc_base { |
9649 | /// Default constructor. Produces an empty object. |
9650 | primitive_desc() = default; |
9651 | |
9652 | /// Constructs a primitive descriptor for a GRU forward propagation |
9653 | /// primitive. |
9654 | /// |
9655 | /// The following arguments may point to a zero memory descriptor: |
9656 | /// - @p src_iter_desc, |
9657 | /// - @p bias_desc, |
9658 | /// - @p dst_iter_desc. |
9659 | /// |
9660 | /// This would then indicate that the GRU forward propagation primitive |
9661 | /// should not use them and should default to zero values instead. |
9662 | /// |
9663 | /// @note |
9664 | /// All memory descriptors except @p src_iter_desc may be |
9665 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
9666 | /// format_tag. |
9667 | /// |
9668 | /// @param aengine Engine to use. |
9669 | /// @param aprop_kind Propagation kind. Possible values are |
9670 | /// #dnnl::prop_kind::forward_training, and |
9671 | /// #dnnl::prop_kind::forward_inference. |
9672 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9673 | /// more info. |
9674 | /// @param src_layer_desc Memory descriptor for the input vector. |
9675 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9676 | /// hidden state vector. |
9677 | /// @param weights_layer_desc Memory descriptor for the weights |
9678 | /// applied to the layer input. |
9679 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9680 | /// to the recurrent input. |
9681 | /// @param bias_desc Bias memory descriptor. |
9682 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9683 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9684 | /// hidden state vector. |
9685 | /// @param attr Primitive attributes to use. Attributes are optional |
9686 | /// and default to empty attributes. |
9687 | /// @param allow_empty A flag signifying whether construction is |
9688 | /// allowed to fail without throwing an exception. In this case an |
9689 | /// empty object will be produced. This flag is optional and |
9690 | /// defaults to false. |
9691 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9692 | rnn_direction direction, const memory::desc &src_layer_desc, |
9693 | const memory::desc &src_iter_desc, |
9694 | const memory::desc &weights_layer_desc, |
9695 | const memory::desc &weights_iter_desc, |
9696 | const memory::desc &bias_desc, |
9697 | const memory::desc &dst_layer_desc, |
9698 | const memory::desc &dst_iter_desc, |
9699 | const primitive_attr &attr = default_attr(), |
9700 | bool allow_empty = false) |
9701 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru, |
9702 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9703 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
9704 | weights_iter_desc, nullptr, nullptr, bias_desc, |
9705 | dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, |
9706 | 0.0f, 0.0f, attr, allow_empty) {} |
9707 | |
9708 | /// Constructs a primitive descriptor for a GRU forward propagation |
9709 | /// primitive from a C API primitive descriptor that must have a |
9710 | /// matching kind. |
9711 | /// |
9712 | /// @param pd C API primitive descriptor for a GRU forward |
9713 | /// propagation primitive. |
9714 | primitive_desc(dnnl_primitive_desc_t pd) |
9715 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
9716 | dnnl::prop_kind::forward_inference, |
9717 | dnnl::algorithm::vanilla_gru) {} |
9718 | |
9719 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
9720 | memory::desc src_layer_desc() const { |
9721 | return rnn_base::src_layer_desc(); |
9722 | } |
9723 | |
9724 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
9725 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
9726 | |
9727 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
9728 | memory::desc weights_layer_desc() const { |
9729 | return rnn_base::weights_layer_desc(); |
9730 | } |
9731 | |
9732 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
9733 | memory::desc weights_iter_desc() const { |
9734 | return rnn_base::weights_iter_desc(); |
9735 | } |
9736 | |
9737 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
9738 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
9739 | |
9740 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
9741 | memory::desc dst_layer_desc() const { |
9742 | return rnn_base::dst_layer_desc(); |
9743 | } |
9744 | |
9745 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
9746 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
9747 | |
9748 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
9749 | memory::desc workspace_desc() const { |
9750 | return rnn_base::workspace_desc(); |
9751 | } |
9752 | |
9753 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
9754 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
9755 | |
9756 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
9757 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
9758 | |
9759 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
9760 | rnn_direction get_direction() const { return base::get_direction(); } |
9761 | }; |
9762 | |
9763 | /// Default constructor. Produces an empty object. |
9764 | gru_forward() = default; |
9765 | |
9766 | /// Constructs a GRU forward propagation primitive. |
9767 | /// @param pd Primitive descriptor for a GRU forward propagation |
9768 | /// primitive. |
9769 | gru_forward(const primitive_desc &pd) : primitive(pd) {} |
9770 | |
9771 | /// Constructs a GRU forward propagation primitive from a cache blob. |
9772 | /// @param pd Primitive descriptor for a GRU forward propagation |
9773 | /// primitive. |
9774 | /// @param cache_blob Cache blob. |
9775 | gru_forward( |
9776 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
9777 | : primitive(pd, cache_blob) {} |
9778 | }; |
9779 | |
9780 | /// GRU backward propagation primitive. |
9781 | struct gru_backward : public primitive { |
9782 | /// Primitive descriptor for a GRU backward propagation primitive. |
9783 | struct primitive_desc : public rnn_primitive_desc_base { |
9784 | /// Default constructor. Produces an empty object. |
9785 | primitive_desc() = default; |
9786 | |
9787 | /// Constructs a primitive descriptor for a GRU backward propagation |
9788 | /// primitive. |
9789 | /// |
9790 | /// The following arguments may point to a zero memory descriptor: |
9791 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
9792 | /// - @p bias_desc together with @p diff_bias_desc, |
9793 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
9794 | /// |
9795 | /// This would then indicate that the GRU backward propagation |
9796 | /// primitive should not use them and should default to zero values |
9797 | /// instead. |
9798 | /// |
9799 | /// @note |
9800 | /// All memory descriptors may be initialized with |
9801 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
9802 | /// |
9803 | /// @param aengine Engine to use. |
9804 | /// @param aprop_kind Propagation kind. Must be |
9805 | /// #dnnl::prop_kind::backward. |
9806 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
9807 | /// more info. |
9808 | /// @param src_layer_desc Memory descriptor for the input vector. |
9809 | /// @param src_iter_desc Memory descriptor for the input recurrent |
9810 | /// hidden state vector. |
9811 | /// @param weights_layer_desc Memory descriptor for the weights |
9812 | /// applied to the layer input. |
9813 | /// @param weights_iter_desc Memory descriptor for the weights applied |
9814 | /// to the recurrent input. |
9815 | /// @param bias_desc Bias memory descriptor. |
9816 | /// @param dst_layer_desc Memory descriptor for the output vector. |
9817 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
9818 | /// hidden state vector. |
9819 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
9820 | /// vector. |
9821 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
9822 | /// recurrent hidden state vector. |
9823 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
9824 | /// weights applied to the layer input. |
9825 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
9826 | /// weights applied to the recurrent input. |
9827 | /// @param diff_bias_desc Diff bias memory descriptor. |
9828 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
9829 | /// output vector. |
9830 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
9831 | /// recurrent hidden state vector. |
9832 | /// @param hint_fwd_pd Primitive descriptor for a GRU |
9833 | /// forward propagation primitive. It is used as a hint for |
9834 | /// deciding which memory format to use. |
9835 | /// @param attr Primitive attributes to use. Attributes are optional |
9836 | /// and default to empty attributes. |
9837 | /// @param allow_empty A flag signifying whether construction is |
9838 | /// allowed to fail without throwing an exception. In this case an |
9839 | /// empty object will be produced. This flag is optional and |
9840 | /// defaults to false. |
9841 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
9842 | rnn_direction direction, const memory::desc &src_layer_desc, |
9843 | const memory::desc &src_iter_desc, |
9844 | const memory::desc &weights_layer_desc, |
9845 | const memory::desc &weights_iter_desc, |
9846 | const memory::desc &bias_desc, |
9847 | const memory::desc &dst_layer_desc, |
9848 | const memory::desc &dst_iter_desc, |
9849 | const memory::desc &diff_src_layer_desc, |
9850 | const memory::desc &diff_src_iter_desc, |
9851 | const memory::desc &diff_weights_layer_desc, |
9852 | const memory::desc &diff_weights_iter_desc, |
9853 | const memory::desc &diff_bias_desc, |
9854 | const memory::desc &diff_dst_layer_desc, |
9855 | const memory::desc &diff_dst_iter_desc, |
9856 | const gru_forward::primitive_desc &hint_fwd_pd, |
9857 | const primitive_attr &attr = default_attr(), |
9858 | bool allow_empty = false) |
9859 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru, |
9860 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
9861 | src_iter_desc, nullptr, nullptr, weights_layer_desc, |
9862 | weights_iter_desc, nullptr, nullptr, bias_desc, |
9863 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
9864 | diff_src_iter_desc, nullptr, nullptr, |
9865 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, |
9866 | nullptr, diff_bias_desc, diff_dst_layer_desc, |
9867 | diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, |
9868 | hint_fwd_pd, attr, allow_empty) {} |
9869 | |
9870 | /// Constructs a primitive descriptor for a GRU backward propagation |
9871 | /// primitive from a C API primitive descriptor that must have a |
9872 | /// matching kind. |
9873 | /// |
9874 | /// @param pd C API primitive descriptor for a GRU backward |
9875 | /// propagation primitive. |
9876 | primitive_desc(dnnl_primitive_desc_t pd) |
9877 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, |
9878 | dnnl::algorithm::vanilla_gru) {} |
9879 | |
9880 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
9881 | memory::desc src_layer_desc() const { |
9882 | return rnn_base::src_layer_desc(); |
9883 | } |
9884 | |
9885 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
9886 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
9887 | |
9888 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
9889 | memory::desc weights_layer_desc() const { |
9890 | return rnn_base::weights_layer_desc(); |
9891 | } |
9892 | |
9893 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
9894 | memory::desc weights_iter_desc() const { |
9895 | return rnn_base::weights_iter_desc(); |
9896 | } |
9897 | |
9898 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
9899 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
9900 | |
9901 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
9902 | memory::desc dst_layer_desc() const { |
9903 | return rnn_base::dst_layer_desc(); |
9904 | } |
9905 | |
9906 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
9907 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
9908 | |
9909 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
9910 | memory::desc workspace_desc() const { |
9911 | return rnn_base::workspace_desc(); |
9912 | } |
9913 | |
9914 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
9915 | memory::desc diff_src_layer_desc() const { |
9916 | return rnn_base::diff_src_layer_desc(); |
9917 | } |
9918 | |
9919 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
9920 | memory::desc diff_src_iter_desc() const { |
9921 | return rnn_base::diff_src_iter_desc(); |
9922 | } |
9923 | |
9924 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
9925 | memory::desc diff_weights_layer_desc() const { |
9926 | return rnn_base::diff_weights_layer_desc(); |
9927 | } |
9928 | |
9929 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
9930 | memory::desc diff_weights_iter_desc() const { |
9931 | return rnn_base::diff_weights_iter_desc(); |
9932 | } |
9933 | |
9934 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
9935 | memory::desc diff_bias_desc() const { |
9936 | return rnn_base::diff_bias_desc(); |
9937 | } |
9938 | |
9939 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
9940 | memory::desc diff_dst_layer_desc() const { |
9941 | return rnn_base::diff_dst_layer_desc(); |
9942 | } |
9943 | |
9944 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
9945 | memory::desc diff_dst_iter_desc() const { |
9946 | return rnn_base::diff_dst_iter_desc(); |
9947 | } |
9948 | |
9949 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
9950 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
9951 | |
9952 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
9953 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
9954 | |
9955 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
9956 | rnn_direction get_direction() const { return base::get_direction(); } |
9957 | }; |
9958 | |
9959 | /// Default constructor. Produces an empty object. |
9960 | gru_backward() = default; |
9961 | |
9962 | /// Constructs a GRU backward propagation primitive. |
9963 | /// @param pd Primitive descriptor for a GRU backward propagation |
9964 | /// primitive. |
9965 | gru_backward(const primitive_desc &pd) : primitive(pd) {} |
9966 | |
9967 | /// Constructs a GRU backward propagation primitive from a cache blob. |
9968 | /// @param pd Primitive descriptor for a GRU backward propagation |
9969 | /// primitive. |
9970 | /// @param cache_blob Cache blob. |
9971 | gru_backward( |
9972 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
9973 | : primitive(pd, cache_blob) {} |
9974 | }; |
9975 | |
9976 | /// LBR GRU forward propagation primitive. |
9977 | struct lbr_gru_forward : public primitive { |
9978 | /// Primitive descriptor for an LBR GRU forward propagation primitive. |
9979 | struct primitive_desc : public rnn_primitive_desc_base { |
9980 | /// Default constructor. Produces an empty object. |
9981 | primitive_desc() = default; |
9982 | |
9983 | /// Constructs a primitive descriptor for LBR GRU forward propagation |
9984 | /// primitive. |
9985 | /// |
9986 | /// The following arguments may point to a zero memory descriptor: |
9987 | /// - @p src_iter_desc, |
9988 | /// - @p bias_desc, |
9989 | /// - @p dst_iter_desc. |
9990 | /// |
9991 | /// This would then indicate that the LBR GRU forward propagation |
9992 | /// primitive should not use them and should default to zero values |
9993 | /// instead. |
9994 | /// |
9995 | /// @note |
9996 | /// All memory descriptors except @p src_iter_desc may be |
9997 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
9998 | /// format_tag. |
9999 | /// |
10000 | /// @param aengine Engine to use. |
10001 | /// @param aprop_kind Propagation kind. Possible values are |
10002 | /// #dnnl::prop_kind::forward_training, and |
10003 | /// #dnnl::prop_kind::forward_inference. |
10004 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10005 | /// more info. |
10006 | /// @param src_layer_desc Memory descriptor for the input vector. |
10007 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10008 | /// hidden state vector. |
10009 | /// @param weights_layer_desc Memory descriptor for the weights |
10010 | /// applied to the layer input. |
10011 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10012 | /// to the recurrent input. |
10013 | /// @param bias_desc Bias memory descriptor. |
10014 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10015 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10016 | /// hidden state vector. |
10017 | /// @param attr Primitive attributes to use. Attributes are optional |
10018 | /// and default to empty attributes. |
10019 | /// @param allow_empty A flag signifying whether construction is |
10020 | /// allowed to fail without throwing an exception. In this case an |
10021 | /// empty object will be produced. This flag is optional and |
10022 | /// defaults to false. |
10023 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10024 | rnn_direction direction, const memory::desc &src_layer_desc, |
10025 | const memory::desc &src_iter_desc, |
10026 | const memory::desc &weights_layer_desc, |
10027 | const memory::desc &weights_iter_desc, |
10028 | const memory::desc &bias_desc, |
10029 | const memory::desc &dst_layer_desc, |
10030 | const memory::desc &dst_iter_desc, |
10031 | const primitive_attr &attr = default_attr(), |
10032 | bool allow_empty = false) |
10033 | : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind, |
10034 | algorithm::undef, direction, src_layer_desc, src_iter_desc, |
10035 | nullptr, nullptr, weights_layer_desc, weights_iter_desc, |
10036 | nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc, |
10037 | nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} |
10038 | |
10039 | /// Constructs a primitive descriptor for a LBR GRU forward propagation |
10040 | /// primitive from a C API primitive descriptor that must have a |
10041 | /// matching kind. |
10042 | /// |
10043 | /// @param pd C API primitive descriptor for a LBR GRU forward |
10044 | /// propagation primitive. |
10045 | primitive_desc(dnnl_primitive_desc_t pd) |
10046 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
10047 | dnnl::prop_kind::forward_inference, |
10048 | dnnl::algorithm::lbr_gru) {} |
10049 | |
10050 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10051 | memory::desc src_layer_desc() const { |
10052 | return rnn_base::src_layer_desc(); |
10053 | } |
10054 | |
10055 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10056 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10057 | |
10058 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10059 | memory::desc weights_layer_desc() const { |
10060 | return rnn_base::weights_layer_desc(); |
10061 | } |
10062 | |
10063 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10064 | memory::desc weights_iter_desc() const { |
10065 | return rnn_base::weights_iter_desc(); |
10066 | } |
10067 | |
10068 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10069 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10070 | |
10071 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10072 | memory::desc dst_layer_desc() const { |
10073 | return rnn_base::dst_layer_desc(); |
10074 | } |
10075 | |
10076 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10077 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10078 | |
10079 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10080 | memory::desc workspace_desc() const { |
10081 | return rnn_base::workspace_desc(); |
10082 | } |
10083 | |
10084 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10085 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10086 | |
10087 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10088 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10089 | |
10090 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10091 | rnn_direction get_direction() const { return base::get_direction(); } |
10092 | }; |
10093 | |
10094 | /// Default constructor. Produces an empty object. |
10095 | lbr_gru_forward() = default; |
10096 | |
10097 | /// Constructs an LBR GRU forward propagation primitive. |
10098 | /// @param pd Primitive descriptor for an LBR GRU forward propagation |
10099 | /// primitive. |
10100 | lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {} |
10101 | |
10102 | /// Constructs an LBR GRU forward propagation primitive from a cache blob. |
10103 | /// @param pd Primitive descriptor for an LBR GRU forward propagation |
10104 | /// primitive. |
10105 | /// @param cache_blob Cache blob. |
10106 | lbr_gru_forward( |
10107 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
10108 | : primitive(pd, cache_blob) {} |
10109 | }; |
10110 | |
10111 | /// LBR GRU backward propagation primitive. |
10112 | struct lbr_gru_backward : public primitive { |
10113 | /// Primitive descriptor for an LBR GRU backward propagation primitive. |
10114 | struct primitive_desc : public rnn_primitive_desc_base { |
10115 | /// Default constructor. Produces an empty object. |
10116 | primitive_desc() = default; |
10117 | |
10118 | /// Constructs a primitive descriptor for LBR GRU backward propagation |
10119 | /// primitive. |
10120 | /// |
10121 | /// The following arguments may point to a zero memory descriptor: |
10122 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
10123 | /// - @p bias_desc together with @p diff_bias_desc, |
10124 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
10125 | /// |
10126 | /// This would then indicate that the LBR GRU backward propagation |
10127 | /// primitive should not use them and should default to zero values |
10128 | /// instead. |
10129 | /// |
10130 | /// @note |
10131 | /// All memory descriptors may be initialized with |
10132 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
10133 | /// |
10134 | /// @param aengine Engine to use. |
10135 | /// @param aprop_kind Propagation kind. Must be |
10136 | /// #dnnl::prop_kind::backward. |
10137 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10138 | /// more info. |
10139 | /// @param src_layer_desc Memory descriptor for the input vector. |
10140 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10141 | /// hidden state vector. |
10142 | /// @param weights_layer_desc Memory descriptor for the weights |
10143 | /// applied to the layer input. |
10144 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10145 | /// to the recurrent input. |
10146 | /// @param bias_desc Bias memory descriptor. |
10147 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10148 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10149 | /// hidden state vector. |
10150 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
10151 | /// vector. |
10152 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
10153 | /// recurrent hidden state vector. |
10154 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
10155 | /// weights applied to the layer input. |
10156 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
10157 | /// weights applied to the recurrent input. |
10158 | /// @param diff_bias_desc Diff bias memory descriptor. |
10159 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
10160 | /// output vector. |
10161 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
10162 | /// recurrent hidden state vector. |
10163 | /// @param hint_fwd_pd Primitive descriptor for an LBR GRU |
10164 | /// forward propagation primitive. It is used as a hint for |
10165 | /// deciding which memory format to use. |
10166 | /// @param attr Primitive attributes to use. Attributes are optional |
10167 | /// and default to empty attributes. |
10168 | /// @param allow_empty A flag signifying whether construction is |
10169 | /// allowed to fail without throwing an exception. In this case an |
10170 | /// empty object will be produced. This flag is optional and |
10171 | /// defaults to false. |
10172 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10173 | rnn_direction direction, const memory::desc &src_layer_desc, |
10174 | const memory::desc &src_iter_desc, |
10175 | const memory::desc &weights_layer_desc, |
10176 | const memory::desc &weights_iter_desc, |
10177 | const memory::desc &bias_desc, |
10178 | const memory::desc &dst_layer_desc, |
10179 | const memory::desc &dst_iter_desc, |
10180 | const memory::desc &diff_src_layer_desc, |
10181 | const memory::desc &diff_src_iter_desc, |
10182 | const memory::desc &diff_weights_layer_desc, |
10183 | const memory::desc &diff_weights_iter_desc, |
10184 | const memory::desc &diff_bias_desc, |
10185 | const memory::desc &diff_dst_layer_desc, |
10186 | const memory::desc &diff_dst_iter_desc, |
10187 | const gru_forward::primitive_desc &hint_fwd_pd, |
10188 | const primitive_attr &attr = default_attr(), |
10189 | bool allow_empty = false) |
10190 | : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind, |
10191 | algorithm::undef, direction, src_layer_desc, src_iter_desc, |
10192 | nullptr, nullptr, weights_layer_desc, weights_iter_desc, |
10193 | nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc, |
10194 | nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr, |
10195 | nullptr, diff_weights_layer_desc, diff_weights_iter_desc, |
10196 | nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc, |
10197 | diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, |
10198 | hint_fwd_pd, attr, allow_empty) {} |
10199 | |
10200 | /// Constructs a primitive descriptor for a LBR GRU backward propagation |
10201 | /// primitive from a C API primitive descriptor that must have a |
10202 | /// matching kind. |
10203 | /// |
10204 | /// @param pd C API primitive descriptor for a LBR GRU backward |
10205 | /// propagation primitive. |
10206 | primitive_desc(dnnl_primitive_desc_t pd) |
10207 | : rnn_primitive_desc_base( |
10208 | pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {} |
10209 | |
10210 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10211 | memory::desc src_layer_desc() const { |
10212 | return rnn_base::src_layer_desc(); |
10213 | } |
10214 | |
10215 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10216 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10217 | |
10218 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10219 | memory::desc weights_layer_desc() const { |
10220 | return rnn_base::weights_layer_desc(); |
10221 | } |
10222 | |
10223 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10224 | memory::desc weights_iter_desc() const { |
10225 | return rnn_base::weights_iter_desc(); |
10226 | } |
10227 | |
10228 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10229 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10230 | |
10231 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10232 | memory::desc dst_layer_desc() const { |
10233 | return rnn_base::dst_layer_desc(); |
10234 | } |
10235 | |
10236 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10237 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10238 | |
10239 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10240 | memory::desc workspace_desc() const { |
10241 | return rnn_base::workspace_desc(); |
10242 | } |
10243 | |
10244 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
10245 | memory::desc diff_src_layer_desc() const { |
10246 | return rnn_base::diff_src_layer_desc(); |
10247 | } |
10248 | |
10249 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
10250 | memory::desc diff_src_iter_desc() const { |
10251 | return rnn_base::diff_src_iter_desc(); |
10252 | } |
10253 | |
10254 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
10255 | memory::desc diff_weights_layer_desc() const { |
10256 | return rnn_base::diff_weights_layer_desc(); |
10257 | } |
10258 | |
10259 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
10260 | memory::desc diff_weights_iter_desc() const { |
10261 | return rnn_base::diff_weights_iter_desc(); |
10262 | } |
10263 | |
10264 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
10265 | memory::desc diff_bias_desc() const { |
10266 | return rnn_base::diff_bias_desc(); |
10267 | } |
10268 | |
10269 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
10270 | memory::desc diff_dst_layer_desc() const { |
10271 | return rnn_base::diff_dst_layer_desc(); |
10272 | } |
10273 | |
10274 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
10275 | memory::desc diff_dst_iter_desc() const { |
10276 | return rnn_base::diff_dst_iter_desc(); |
10277 | } |
10278 | |
10279 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10280 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10281 | |
10282 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10283 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10284 | |
10285 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10286 | rnn_direction get_direction() const { return base::get_direction(); } |
10287 | }; |
10288 | |
10289 | /// Default constructor. Produces an empty object. |
10290 | lbr_gru_backward() = default; |
10291 | |
10292 | /// Constructs an LBR GRU backward propagation primitive. |
10293 | /// @param pd Primitive descriptor for an LBR GRU backward propagation |
10294 | /// primitive. |
10295 | lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {} |
10296 | |
10297 | /// Constructs an LBR GRU backward propagation primitive from a cache blob. |
10298 | /// @param pd Primitive descriptor for an LBR GRU backward propagation |
10299 | /// primitive. |
10300 | /// @param cache_blob Cache blob. |
10301 | lbr_gru_backward( |
10302 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
10303 | : primitive(pd, cache_blob) {} |
10304 | }; |
10305 | |
10306 | /// AUGRU forward propagation primitive. |
10307 | struct augru_forward : public primitive { |
10308 | /// Primitive descriptor for an AUGRU forward propagation primitive. |
10309 | struct primitive_desc : public rnn_primitive_desc_base { |
10310 | /// Default constructor. Produces an empty object. |
10311 | primitive_desc() = default; |
10312 | |
10313 | /// Constructs a primitive descriptor for an AUGRU forward propagation |
10314 | /// primitive. |
10315 | /// |
10316 | /// The following arguments may point to a zero memory descriptor: |
10317 | /// - @p src_iter_desc, |
10318 | /// - @p bias_desc, |
10319 | /// - @p dst_iter_desc. |
10320 | /// |
10321 | /// This would then indicate that the AUGRU forward propagation |
10322 | /// primitive should not use them and should default to zero values |
10323 | /// instead. |
10324 | /// |
10325 | /// @note |
10326 | /// All memory descriptors except @p src_iter_desc may be |
10327 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
10328 | /// format_tag. |
10329 | /// |
10330 | /// @param aengine Engine to use. |
10331 | /// @param aprop_kind Propagation kind. Possible values are |
10332 | /// #dnnl::prop_kind::forward_training, and |
10333 | /// #dnnl::prop_kind::forward_inference. |
10334 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10335 | /// more info. |
10336 | /// @param src_layer_desc Memory descriptor for the input vector. |
10337 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10338 | /// hidden state vector. |
10339 | /// @param attention_desc Memory descriptor for the attention vector. |
10340 | /// @param weights_layer_desc Memory descriptor for the weights |
10341 | /// applied to the layer input. |
10342 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10343 | /// to the recurrent input. |
10344 | /// @param bias_desc Bias memory descriptor. |
10345 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10346 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10347 | /// hidden state vector. |
10348 | /// @param attr Primitive attributes to use. Attributes are optional |
10349 | /// and default to empty attributes. |
10350 | /// @param allow_empty A flag signifying whether construction is |
10351 | /// allowed to fail without throwing an exception. In this case an |
10352 | /// empty object will be produced. This flag is optional and |
10353 | /// defaults to false. |
10354 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10355 | rnn_direction direction, const memory::desc &src_layer_desc, |
10356 | const memory::desc &src_iter_desc, |
10357 | const memory::desc &attention_desc, |
10358 | const memory::desc &weights_layer_desc, |
10359 | const memory::desc &weights_iter_desc, |
10360 | const memory::desc &bias_desc, |
10361 | const memory::desc &dst_layer_desc, |
10362 | const memory::desc &dst_iter_desc, |
10363 | const primitive_attr &attr = default_attr(), |
10364 | bool allow_empty = false) |
10365 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru, |
10366 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
10367 | src_iter_desc, nullptr, &attention_desc, weights_layer_desc, |
10368 | weights_iter_desc, nullptr, nullptr, bias_desc, |
10369 | dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, |
10370 | 0.0f, 0.0f, attr, allow_empty) {} |
10371 | |
10372 | /// Constructs a primitive descriptor for an AUGRU forward propagation |
10373 | /// primitive from a C API primitive descriptor that must have a |
10374 | /// matching kind. |
10375 | /// |
10376 | /// @param pd C API primitive descriptor for an AUGRU forward |
10377 | /// propagation primitive. |
10378 | primitive_desc(dnnl_primitive_desc_t pd) |
10379 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
10380 | dnnl::prop_kind::forward_inference, |
10381 | dnnl::algorithm::vanilla_augru) {} |
10382 | |
10383 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10384 | memory::desc src_layer_desc() const { |
10385 | return rnn_base::src_layer_desc(); |
10386 | } |
10387 | |
10388 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10389 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10390 | |
10391 | /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const |
10392 | memory::desc attention_desc() const { |
10393 | return rnn_base::augru_attention_desc(); |
10394 | } |
10395 | |
10396 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10397 | memory::desc weights_layer_desc() const { |
10398 | return rnn_base::weights_layer_desc(); |
10399 | } |
10400 | |
10401 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10402 | memory::desc weights_iter_desc() const { |
10403 | return rnn_base::weights_iter_desc(); |
10404 | } |
10405 | |
10406 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10407 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10408 | |
10409 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10410 | memory::desc dst_layer_desc() const { |
10411 | return rnn_base::dst_layer_desc(); |
10412 | } |
10413 | |
10414 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10415 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10416 | |
10417 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10418 | memory::desc workspace_desc() const { |
10419 | return rnn_base::workspace_desc(); |
10420 | } |
10421 | |
10422 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10423 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10424 | |
10425 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10426 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10427 | |
10428 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10429 | rnn_direction get_direction() const { return base::get_direction(); } |
10430 | }; |
10431 | |
10432 | /// Default constructor. Produces an empty object. |
10433 | augru_forward() = default; |
10434 | |
10435 | /// Constructs an AUGRU forward propagation primitive. |
10436 | /// @param pd Primitive descriptor for an AUGRU forward propagation |
10437 | /// primitive. |
10438 | augru_forward(const primitive_desc &pd) : primitive(pd) {} |
10439 | |
10440 | /// Constructs an AUGRU forward propagation primitive from a cache blob. |
10441 | /// @param pd Primitive descriptor for an AUGRU forward propagation |
10442 | /// primitive. |
10443 | /// @param cache_blob Cache blob. |
10444 | augru_forward( |
10445 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
10446 | : primitive(pd, cache_blob) {} |
10447 | }; |
10448 | |
10449 | /// AUGRU backward propagation primitive. |
10450 | struct augru_backward : public primitive { |
10451 | /// Descriptor for an AUGRU backward propagation primitive. |
10452 | /// Primitive descriptor for an AUGRU backward propagation primitive. |
10453 | struct primitive_desc : public rnn_primitive_desc_base { |
10454 | /// Default constructor. Produces an empty object. |
10455 | primitive_desc() = default; |
10456 | |
10457 | /// Constructs a primitive descriptor for an AUGRU backward propagation |
10458 | /// primitive. |
10459 | /// |
10460 | /// The following arguments may point to a zero memory descriptor: |
10461 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
10462 | /// - @p bias_desc together with @p diff_bias_desc, |
10463 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
10464 | /// |
10465 | /// This would then indicate that the AUGRU backward propagation |
10466 | /// primitive should not use them and should default to zero values |
10467 | /// instead. |
10468 | /// |
10469 | /// @note |
10470 | /// All memory descriptors may be initialized with |
10471 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
10472 | /// |
10473 | /// @param aengine Engine to use. |
10474 | /// @param aprop_kind Propagation kind. Must be |
10475 | /// #dnnl::prop_kind::backward. |
10476 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10477 | /// more info. |
10478 | /// @param src_layer_desc Memory descriptor for the input vector. |
10479 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10480 | /// hidden state vector. |
10481 | /// @param attention_desc Memory descriptor for the attention vector. |
10482 | /// @param weights_layer_desc Memory descriptor for the weights |
10483 | /// applied to the layer input. |
10484 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10485 | /// to the recurrent input. |
10486 | /// @param bias_desc Bias memory descriptor. |
10487 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10488 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10489 | /// hidden state vector. |
10490 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
10491 | /// vector. |
10492 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
10493 | /// recurrent hidden state vector. |
10494 | /// @param diff_attention_desc Memory descriptor for the diff of |
10495 | /// attention vector. |
10496 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
10497 | /// weights applied to the layer input. |
10498 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
10499 | /// weights applied to the recurrent input. |
10500 | /// @param diff_bias_desc Diff bias memory descriptor. |
10501 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
10502 | /// output vector. |
10503 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
10504 | /// recurrent hidden state vector. |
10505 | /// @param hint_fwd_pd Primitive descriptor for an AUGRU |
10506 | /// forward propagation primitive. It is used as a hint for |
10507 | /// deciding which memory format to use. |
10508 | /// @param attr Primitive attributes to use. Attributes are optional |
10509 | /// and default to empty attributes. |
10510 | /// @param allow_empty A flag signifying whether construction is |
10511 | /// allowed to fail without throwing an exception. In this case an |
10512 | /// empty object will be produced. This flag is optional and |
10513 | /// defaults to false. |
10514 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10515 | rnn_direction direction, const memory::desc &src_layer_desc, |
10516 | const memory::desc &src_iter_desc, |
10517 | const memory::desc &attention_desc, |
10518 | const memory::desc &weights_layer_desc, |
10519 | const memory::desc &weights_iter_desc, |
10520 | const memory::desc &bias_desc, |
10521 | const memory::desc &dst_layer_desc, |
10522 | const memory::desc &dst_iter_desc, |
10523 | const memory::desc &diff_src_layer_desc, |
10524 | const memory::desc &diff_src_iter_desc, |
10525 | const memory::desc &diff_attention_desc, |
10526 | const memory::desc &diff_weights_layer_desc, |
10527 | const memory::desc &diff_weights_iter_desc, |
10528 | const memory::desc &diff_bias_desc, |
10529 | const memory::desc &diff_dst_layer_desc, |
10530 | const memory::desc &diff_dst_iter_desc, |
10531 | const gru_forward::primitive_desc &hint_fwd_pd, |
10532 | const primitive_attr &attr = default_attr(), |
10533 | bool allow_empty = false) |
10534 | : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru, |
10535 | aprop_kind, algorithm::undef, direction, src_layer_desc, |
10536 | src_iter_desc, nullptr, &attention_desc, weights_layer_desc, |
10537 | weights_iter_desc, nullptr, nullptr, bias_desc, |
10538 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
10539 | diff_src_iter_desc, nullptr, &diff_attention_desc, |
10540 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, |
10541 | nullptr, diff_bias_desc, diff_dst_layer_desc, |
10542 | diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, |
10543 | hint_fwd_pd, attr, allow_empty) {} |
10544 | |
10545 | /// Constructs a primitive descriptor for an AUGRU backward propagation |
10546 | /// primitive from a C API primitive descriptor that must have a |
10547 | /// matching kind. |
10548 | /// |
10549 | /// @param pd C API primitive descriptor for an AUGRU backward |
10550 | /// propagation primitive. |
10551 | primitive_desc(dnnl_primitive_desc_t pd) |
10552 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, |
10553 | dnnl::algorithm::vanilla_augru) {} |
10554 | |
10555 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10556 | memory::desc src_layer_desc() const { |
10557 | return rnn_base::src_layer_desc(); |
10558 | } |
10559 | |
10560 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10561 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10562 | |
10563 | /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const |
10564 | memory::desc attention_desc() const { |
10565 | return rnn_base::augru_attention_desc(); |
10566 | } |
10567 | |
10568 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10569 | memory::desc weights_layer_desc() const { |
10570 | return rnn_base::weights_layer_desc(); |
10571 | } |
10572 | |
10573 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10574 | memory::desc weights_iter_desc() const { |
10575 | return rnn_base::weights_iter_desc(); |
10576 | } |
10577 | |
10578 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10579 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10580 | |
10581 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10582 | memory::desc dst_layer_desc() const { |
10583 | return rnn_base::dst_layer_desc(); |
10584 | } |
10585 | |
10586 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10587 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10588 | |
10589 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10590 | memory::desc workspace_desc() const { |
10591 | return rnn_base::workspace_desc(); |
10592 | } |
10593 | |
10594 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
10595 | memory::desc diff_src_layer_desc() const { |
10596 | return rnn_base::diff_src_layer_desc(); |
10597 | } |
10598 | |
10599 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
10600 | memory::desc diff_src_iter_desc() const { |
10601 | return rnn_base::diff_src_iter_desc(); |
10602 | } |
10603 | |
10604 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const |
10605 | memory::desc diff_attention_desc() const { |
10606 | return rnn_base::diff_augru_attention_desc(); |
10607 | } |
10608 | |
10609 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
10610 | memory::desc diff_weights_layer_desc() const { |
10611 | return rnn_base::diff_weights_layer_desc(); |
10612 | } |
10613 | |
10614 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
10615 | memory::desc diff_weights_iter_desc() const { |
10616 | return rnn_base::diff_weights_iter_desc(); |
10617 | } |
10618 | |
10619 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
10620 | memory::desc diff_bias_desc() const { |
10621 | return rnn_base::diff_bias_desc(); |
10622 | } |
10623 | |
10624 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
10625 | memory::desc diff_dst_layer_desc() const { |
10626 | return rnn_base::diff_dst_layer_desc(); |
10627 | } |
10628 | |
10629 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
10630 | memory::desc diff_dst_iter_desc() const { |
10631 | return rnn_base::diff_dst_iter_desc(); |
10632 | } |
10633 | |
10634 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10635 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10636 | |
10637 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10638 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10639 | |
10640 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10641 | rnn_direction get_direction() const { return base::get_direction(); } |
10642 | }; |
10643 | |
10644 | /// Default constructor. Produces an empty object. |
10645 | augru_backward() = default; |
10646 | |
10647 | /// Constructs an AUGRU backward propagation primitive. |
10648 | /// @param pd Primitive descriptor for an AUGRU backward propagation |
10649 | /// primitive. |
10650 | augru_backward(const primitive_desc &pd) : primitive(pd) {} |
10651 | |
10652 | /// Constructs an AUGRU backward propagation primitive from a cache blob. |
10653 | /// @param pd Primitive descriptor for an AUGRU backward propagation |
10654 | /// primitive. |
10655 | /// @param cache_blob Cache blob. |
10656 | augru_backward( |
10657 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
10658 | : primitive(pd, cache_blob) {} |
10659 | }; |
10660 | |
10661 | /// LBR AUGRU forward propagation primitive. |
10662 | struct lbr_augru_forward : public primitive { |
10663 | /// Descriptor for an LBR AUGRU forward propagation primitive. |
10664 | |
10665 | /// Primitive descriptor for an LBR AUGRU forward propagation primitive. |
10666 | struct primitive_desc : public rnn_primitive_desc_base { |
10667 | /// Default constructor. Produces an empty object. |
10668 | primitive_desc() = default; |
10669 | |
10670 | /// Constructs a primitive descriptor for LBR AUGRU forward propagation |
10671 | /// primitive. |
10672 | /// |
10673 | /// The following arguments may point to a zero memory descriptor: |
10674 | /// - @p src_iter_desc, |
10675 | /// - @p bias_desc, |
10676 | /// - @p dst_iter_desc. |
10677 | /// |
10678 | /// This would then indicate that the LBR AUGRU forward propagation |
10679 | /// primitive should not use them and should default to zero values |
10680 | /// instead. |
10681 | /// |
10682 | /// @note |
10683 | /// All memory descriptors except @p src_iter_desc may be |
10684 | /// initialized with an #dnnl::memory::format_tag::any value of @p |
10685 | /// format_tag. |
10686 | /// |
10687 | /// @param aengine Engine to use. |
10688 | /// @param aprop_kind Propagation kind. Possible values are |
10689 | /// #dnnl::prop_kind::forward_training, and |
10690 | /// #dnnl::prop_kind::forward_inference. |
10691 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10692 | /// more info. |
10693 | /// @param src_layer_desc Memory descriptor for the input vector. |
10694 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10695 | /// hidden state vector. |
10696 | /// @param attention_desc Memory descriptor for the attention vector. |
10697 | /// @param weights_layer_desc Memory descriptor for the weights |
10698 | /// applied to the layer input. |
10699 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10700 | /// to the recurrent input. |
10701 | /// @param bias_desc Bias memory descriptor. |
10702 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10703 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10704 | /// hidden state vector. |
10705 | /// @param attr Primitive attributes to use. Attributes are optional |
10706 | /// and default to empty attributes. |
10707 | /// @param allow_empty A flag signifying whether construction is |
10708 | /// allowed to fail without throwing an exception. In this case an |
10709 | /// empty object will be produced. This flag is optional and |
10710 | /// defaults to false. |
10711 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10712 | rnn_direction direction, const memory::desc &src_layer_desc, |
10713 | const memory::desc &src_iter_desc, |
10714 | const memory::desc &attention_desc, |
10715 | const memory::desc &weights_layer_desc, |
10716 | const memory::desc &weights_iter_desc, |
10717 | const memory::desc &bias_desc, |
10718 | const memory::desc &dst_layer_desc, |
10719 | const memory::desc &dst_iter_desc, |
10720 | const primitive_attr &attr = default_attr(), |
10721 | bool allow_empty = false) |
10722 | : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind, |
10723 | algorithm::undef, direction, src_layer_desc, src_iter_desc, |
10724 | nullptr, &attention_desc, weights_layer_desc, |
10725 | weights_iter_desc, nullptr, nullptr, bias_desc, |
10726 | dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, |
10727 | 0.0f, 0.0f, attr, allow_empty) {} |
10728 | |
10729 | /// Constructs a primitive descriptor for an LBR AUGRU forward propagation |
10730 | /// primitive from a C API primitive descriptor that must have a |
10731 | /// matching kind. |
10732 | /// |
10733 | /// @param pd C API primitive descriptor for an LBR AUGRU forward |
10734 | /// propagation primitive. |
10735 | primitive_desc(dnnl_primitive_desc_t pd) |
10736 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, |
10737 | dnnl::prop_kind::forward_inference, |
10738 | dnnl::algorithm::lbr_augru) {} |
10739 | |
10740 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10741 | memory::desc src_layer_desc() const { |
10742 | return rnn_base::src_layer_desc(); |
10743 | } |
10744 | |
10745 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10746 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10747 | |
10748 | /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const |
10749 | memory::desc attention_desc() const { |
10750 | return rnn_base::augru_attention_desc(); |
10751 | } |
10752 | |
10753 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10754 | memory::desc weights_layer_desc() const { |
10755 | return rnn_base::weights_layer_desc(); |
10756 | } |
10757 | |
10758 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10759 | memory::desc weights_iter_desc() const { |
10760 | return rnn_base::weights_iter_desc(); |
10761 | } |
10762 | |
10763 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10764 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10765 | |
10766 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10767 | memory::desc dst_layer_desc() const { |
10768 | return rnn_base::dst_layer_desc(); |
10769 | } |
10770 | |
10771 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10772 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10773 | |
10774 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10775 | memory::desc workspace_desc() const { |
10776 | return rnn_base::workspace_desc(); |
10777 | } |
10778 | |
10779 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10780 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10781 | |
10782 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10783 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10784 | |
10785 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10786 | rnn_direction get_direction() const { return base::get_direction(); } |
10787 | }; |
10788 | |
10789 | /// Default constructor. Produces an empty object. |
10790 | lbr_augru_forward() = default; |
10791 | |
10792 | /// Constructs an LBR AUGRU forward propagation primitive. |
10793 | /// @param pd Primitive descriptor for an LBR AUGRU forward propagation |
10794 | /// primitive. |
10795 | lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {} |
10796 | |
10797 | /// Constructs an LBR AUGRU forward propagation primitive from a cache blob. |
10798 | /// @param pd Primitive descriptor for an LBR AUGRU forward propagation |
10799 | /// primitive. |
10800 | /// @param cache_blob Cache blob. |
10801 | lbr_augru_forward( |
10802 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
10803 | : primitive(pd, cache_blob) {} |
10804 | }; |
10805 | |
10806 | /// LBR AUGRU backward propagation primitive. |
10807 | struct lbr_augru_backward : public primitive { |
10808 | /// Primitive descriptor for an LBR AUGRU backward propagation primitive. |
10809 | struct primitive_desc : public rnn_primitive_desc_base { |
10810 | /// Default constructor. Produces an empty object. |
10811 | primitive_desc() = default; |
10812 | |
10813 | /// Constructs a primitive descriptor for LBR AUGRU backward propagation |
10814 | /// primitive. |
10815 | /// |
10816 | /// The following arguments may point to a zero memory descriptor: |
10817 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
10818 | /// - @p bias_desc together with @p diff_bias_desc, |
10819 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
10820 | /// |
10821 | /// This would then indicate that the LBR AUGRU backward propagation |
10822 | /// primitive should not use them and should default to zero values |
10823 | /// instead. |
10824 | /// |
10825 | /// @note |
10826 | /// All memory descriptors may be initialized with |
10827 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
10828 | /// |
10829 | /// @param aengine Engine to use. |
10830 | /// @param aprop_kind Propagation kind. Must be |
10831 | /// #dnnl::prop_kind::backward. |
10832 | /// @param direction RNN direction. See @ref dnnl::rnn_direction for |
10833 | /// more info. |
10834 | /// @param src_layer_desc Memory descriptor for the input vector. |
10835 | /// @param src_iter_desc Memory descriptor for the input recurrent |
10836 | /// hidden state vector. |
10837 | /// @param attention_desc Memory descriptor for the attention vector. |
10838 | /// @param weights_layer_desc Memory descriptor for the weights |
10839 | /// applied to the layer input. |
10840 | /// @param weights_iter_desc Memory descriptor for the weights applied |
10841 | /// to the recurrent input. |
10842 | /// @param bias_desc Bias memory descriptor. |
10843 | /// @param dst_layer_desc Memory descriptor for the output vector. |
10844 | /// @param dst_iter_desc Memory descriptor for the output recurrent |
10845 | /// hidden state vector. |
10846 | /// @param diff_src_layer_desc Memory descriptor for the diff of input |
10847 | /// vector. |
10848 | /// @param diff_src_iter_desc Memory descriptor for the diff of input |
10849 | /// recurrent hidden state vector. |
10850 | /// @param diff_attention_desc Memory descriptor for the diff of |
10851 | /// attention vector. |
10852 | /// @param diff_weights_layer_desc Memory descriptor for the diff of |
10853 | /// weights applied to the layer input. |
10854 | /// @param diff_weights_iter_desc Memory descriptor for the diff of |
10855 | /// weights applied to the recurrent input. |
10856 | /// @param diff_bias_desc Diff bias memory descriptor. |
10857 | /// @param diff_dst_layer_desc Memory descriptor for the diff of |
10858 | /// output vector. |
10859 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
10860 | /// recurrent hidden state vector. |
10861 | /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU |
10862 | /// forward propagation primitive. It is used as a hint for |
10863 | /// deciding which memory format to use. |
10864 | /// @param attr Primitive attributes to use. Attributes are optional |
10865 | /// and default to empty attributes. |
10866 | /// @param allow_empty A flag signifying whether construction is |
10867 | /// allowed to fail without throwing an exception. In this case an |
10868 | /// empty object will be produced. This flag is optional and |
10869 | /// defaults to false. |
10870 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
10871 | rnn_direction direction, const memory::desc &src_layer_desc, |
10872 | const memory::desc &src_iter_desc, |
10873 | const memory::desc &attention_desc, |
10874 | const memory::desc &weights_layer_desc, |
10875 | const memory::desc &weights_iter_desc, |
10876 | const memory::desc &bias_desc, |
10877 | const memory::desc &dst_layer_desc, |
10878 | const memory::desc &dst_iter_desc, |
10879 | const memory::desc &diff_src_layer_desc, |
10880 | const memory::desc &diff_src_iter_desc, |
10881 | const memory::desc &diff_attention_desc, |
10882 | const memory::desc &diff_weights_layer_desc, |
10883 | const memory::desc &diff_weights_iter_desc, |
10884 | const memory::desc &diff_bias_desc, |
10885 | const memory::desc &diff_dst_layer_desc, |
10886 | const memory::desc &diff_dst_iter_desc, |
10887 | const gru_forward::primitive_desc &hint_fwd_pd, |
10888 | const primitive_attr &attr = default_attr(), |
10889 | bool allow_empty = false) |
10890 | : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind, |
10891 | algorithm::undef, direction, src_layer_desc, src_iter_desc, |
10892 | nullptr, &attention_desc, weights_layer_desc, |
10893 | weights_iter_desc, nullptr, nullptr, bias_desc, |
10894 | dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, |
10895 | diff_src_iter_desc, nullptr, &diff_attention_desc, |
10896 | diff_weights_layer_desc, diff_weights_iter_desc, nullptr, |
10897 | nullptr, diff_bias_desc, diff_dst_layer_desc, |
10898 | diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, |
10899 | hint_fwd_pd, attr, allow_empty) {} |
10900 | |
10901 | /// Constructs a primitive descriptor for an LBR AUGRU backward |
10902 | /// propagation primitive from a C API primitive descriptor that must |
10903 | /// have a matching kind. |
10904 | /// |
10905 | /// @param pd C API primitive descriptor for an LBR AUGRU backward |
10906 | /// propagation primitive. |
10907 | primitive_desc(dnnl_primitive_desc_t pd) |
10908 | : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, |
10909 | dnnl::algorithm::lbr_augru) {} |
10910 | |
10911 | /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const |
10912 | memory::desc src_layer_desc() const { |
10913 | return rnn_base::src_layer_desc(); |
10914 | } |
10915 | |
10916 | /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const |
10917 | memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } |
10918 | |
10919 | /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const |
10920 | memory::desc attention_desc() const { |
10921 | return rnn_base::augru_attention_desc(); |
10922 | } |
10923 | |
10924 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const |
10925 | memory::desc weights_layer_desc() const { |
10926 | return rnn_base::weights_layer_desc(); |
10927 | } |
10928 | |
10929 | /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const |
10930 | memory::desc weights_iter_desc() const { |
10931 | return rnn_base::weights_iter_desc(); |
10932 | } |
10933 | |
10934 | /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const |
10935 | memory::desc bias_desc() const { return rnn_base::bias_desc(); } |
10936 | |
10937 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const |
10938 | memory::desc dst_layer_desc() const { |
10939 | return rnn_base::dst_layer_desc(); |
10940 | } |
10941 | |
10942 | /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const |
10943 | memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } |
10944 | |
10945 | /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const |
10946 | memory::desc workspace_desc() const { |
10947 | return rnn_base::workspace_desc(); |
10948 | } |
10949 | |
10950 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const |
10951 | memory::desc diff_src_layer_desc() const { |
10952 | return rnn_base::diff_src_layer_desc(); |
10953 | } |
10954 | |
10955 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const |
10956 | memory::desc diff_src_iter_desc() const { |
10957 | return rnn_base::diff_src_iter_desc(); |
10958 | } |
10959 | |
10960 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const |
10961 | memory::desc diff_attention_desc() const { |
10962 | return rnn_base::diff_augru_attention_desc(); |
10963 | } |
10964 | |
10965 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const |
10966 | memory::desc diff_weights_layer_desc() const { |
10967 | return rnn_base::diff_weights_layer_desc(); |
10968 | } |
10969 | |
10970 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const |
10971 | memory::desc diff_weights_iter_desc() const { |
10972 | return rnn_base::diff_weights_iter_desc(); |
10973 | } |
10974 | |
10975 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const |
10976 | memory::desc diff_bias_desc() const { |
10977 | return rnn_base::diff_bias_desc(); |
10978 | } |
10979 | |
10980 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const |
10981 | memory::desc diff_dst_layer_desc() const { |
10982 | return rnn_base::diff_dst_layer_desc(); |
10983 | } |
10984 | |
10985 | /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const |
10986 | memory::desc diff_dst_iter_desc() const { |
10987 | return rnn_base::diff_dst_iter_desc(); |
10988 | } |
10989 | |
10990 | /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const |
10991 | algorithm get_cell_kind() const { return base::get_cell_kind(); } |
10992 | |
10993 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
10994 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
10995 | |
10996 | /// @copydoc dnnl::primitive_desc_base::get_direction()const |
10997 | rnn_direction get_direction() const { return base::get_direction(); } |
10998 | }; |
10999 | |
11000 | /// Default constructor. Produces an empty object. |
11001 | lbr_augru_backward() = default; |
11002 | |
11003 | /// Constructs an LBR AUGRU backward propagation primitive. |
11004 | /// @param pd Primitive descriptor for an LBR AUGRU backward propagation |
11005 | /// primitive. |
11006 | lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {} |
11007 | |
11008 | /// Constructs an LBR AUGRU backward propagation primitive from a cache blob. |
11009 | /// @param pd Primitive descriptor for an LBR AUGRU backward propagation |
11010 | /// primitive. |
11011 | /// @param cache_blob Cache blob. |
11012 | lbr_augru_backward( |
11013 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11014 | : primitive(pd, cache_blob) {} |
11015 | }; |
11016 | |
11017 | /// @} dnnl_api_rnn |
11018 | |
11019 | /// @addtogroup dnnl_api_shuffle Shuffle |
11020 | /// |
11021 | /// A primitive to shuffle tensor data along an axis. |
11022 | /// |
11023 | /// @sa @ref dev_guide_shuffle in developer guide |
11024 | /// |
11025 | /// @{ |
11026 | |
11027 | /// Shuffle forward propagation primitive. |
11028 | struct shuffle_forward : public primitive { |
11029 | /// Primitive descriptor for a shuffle forward propagation primitive. |
11030 | struct primitive_desc : public dnnl::primitive_desc { |
11031 | /// Default constructor. Produces an empty object. |
11032 | primitive_desc() = default; |
11033 | |
11034 | /// Constructs a primitive descriptor for a shuffle forward propagation |
11035 | /// primitive. |
11036 | /// |
11037 | /// @param aengine Engine to use. |
11038 | /// @param aprop_kind Propagation kind. Possible values are |
11039 | /// #dnnl::prop_kind::forward_training, and |
11040 | /// #dnnl::prop_kind::forward_inference. |
11041 | /// @param src_desc Source memory descriptor. |
11042 | /// @param dst_desc Destination memory descriptor. |
11043 | /// @param axis The axis along which the data is shuffled. |
11044 | /// @param group_size Shuffle group size. |
11045 | /// @param attr Primitive attributes to use. Attributes are optional |
11046 | /// and default to empty attributes. |
11047 | /// @param allow_empty A flag signifying whether construction is |
11048 | /// allowed to fail without throwing an exception. In this case an |
11049 | /// empty object will be produced. This flag is optional and |
11050 | /// defaults to false. |
11051 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11052 | const memory::desc &src_desc, const memory::desc &dst_desc, |
11053 | int axis, int group_size, |
11054 | const primitive_attr &attr = default_attr(), |
11055 | bool allow_empty = false) { |
11056 | |
11057 | dnnl_primitive_desc_t pd = nullptr; |
11058 | dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create( |
11059 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
11060 | src_desc.get(), dst_desc.get(), axis, group_size, |
11061 | attr.get()); |
11062 | |
11063 | if (!allow_empty) |
11064 | error::wrap_c_api(status, |
11065 | "could not create a primitive descriptor for a shuffle " |
11066 | "forward propagation primitive" ); |
11067 | reset(pd); |
11068 | } |
11069 | |
11070 | /// Constructs a primitive descriptor for a shuffle forward propagation |
11071 | /// primitive from a C API primitive descriptor that must have a |
11072 | /// matching kind. |
11073 | /// |
11074 | /// @param pd C API primitive descriptor for a shuffle forward |
11075 | /// propagation primitive. |
11076 | primitive_desc(dnnl_primitive_desc_t pd) |
11077 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle, |
11078 | dnnl::prop_kind::forward_training, |
11079 | dnnl::prop_kind::forward_inference) {} |
11080 | |
11081 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
11082 | memory::desc src_desc() const { return base::src_desc(0); } |
11083 | |
11084 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
11085 | memory::desc dst_desc() const { return base::dst_desc(0); } |
11086 | |
11087 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
11088 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
11089 | |
11090 | /// @copydoc dnnl::primitive_desc_base::get_axis()const |
11091 | int get_axis() const { return base::get_axis(); } |
11092 | |
11093 | /// @copydoc dnnl::primitive_desc_base::get_group_size()const |
11094 | memory::dim get_group_size() const { return base::get_group_size(); } |
11095 | }; |
11096 | |
11097 | /// Default constructor. Produces an empty object. |
11098 | shuffle_forward() = default; |
11099 | |
11100 | /// Constructs a shuffle forward propagation primitive. |
11101 | /// @param pd Primitive descriptor for a shuffle forward propagation |
11102 | /// primitive. |
11103 | shuffle_forward(const primitive_desc &pd) : primitive(pd) {} |
11104 | |
11105 | /// Constructs a shuffle forward propagation primitive from a cache blob. |
11106 | /// @param pd Primitive descriptor for a shuffle forward propagation |
11107 | /// primitive. |
11108 | /// @param cache_blob Cache blob. |
11109 | shuffle_forward( |
11110 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11111 | : primitive(pd, cache_blob) {} |
11112 | }; |
11113 | |
11114 | /// Shuffle backward propagation primitive. |
11115 | struct shuffle_backward : public primitive { |
11116 | /// Primitive descriptor for a shuffle backward propagation primitive. |
11117 | struct primitive_desc : public dnnl::primitive_desc { |
11118 | /// Default constructor. Produces an empty object. |
11119 | primitive_desc() = default; |
11120 | |
11121 | /// Constructs a primitive descriptor for a shuffle backward propagation |
11122 | /// primitive. |
11123 | /// |
11124 | /// @param aengine Engine to use. |
11125 | /// @param diff_src_desc Diff source memory descriptor. |
11126 | /// @param diff_dst_desc Diff destination memory descriptor. |
11127 | /// @param axis The axis along which the data is shuffled. |
11128 | /// @param group_size Shuffle group size. |
11129 | /// @param hint_fwd_pd Primitive descriptor for a shuffle forward |
11130 | /// propagation primitive. It is used as a hint for deciding which |
11131 | /// memory format to use. |
11132 | /// @param attr Primitive attributes to use. Attributes are optional |
11133 | /// and default to empty attributes. |
11134 | /// @param allow_empty A flag signifying whether construction is |
11135 | /// allowed to fail without throwing an exception. In this case an |
11136 | /// empty object will be produced. This flag is optional and |
11137 | /// defaults to false. |
11138 | primitive_desc(const engine &aengine, const memory::desc &diff_src_desc, |
11139 | const memory::desc &diff_dst_desc, int axis, int group_size, |
11140 | const shuffle_forward::primitive_desc &hint_fwd_pd, |
11141 | const primitive_attr &attr = default_attr(), |
11142 | bool allow_empty = false) { |
11143 | |
11144 | dnnl_primitive_desc_t pd = nullptr; |
11145 | dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create( |
11146 | &pd, aengine.get(), diff_src_desc.get(), |
11147 | diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(), |
11148 | attr.get()); |
11149 | |
11150 | if (!allow_empty) |
11151 | error::wrap_c_api(status, |
11152 | "could not create a primitive descriptor for a shuffle " |
11153 | "backward propagation primitive" ); |
11154 | reset(pd); |
11155 | } |
11156 | |
11157 | /// Constructs a primitive descriptor for a shuffle backward |
11158 | /// propagation primitive from a C API primitive descriptor that must |
11159 | /// have a matching kind. |
11160 | /// |
11161 | /// @param pd C API primitive descriptor for a shuffle backward |
11162 | /// propagation primitive. |
11163 | primitive_desc(dnnl_primitive_desc_t pd) |
11164 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle, |
11165 | dnnl::prop_kind::backward_data) {} |
11166 | |
11167 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
11168 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
11169 | |
11170 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
11171 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
11172 | |
11173 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
11174 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
11175 | |
11176 | /// @copydoc dnnl::primitive_desc_base::get_axis()const |
11177 | int get_axis() const { return base::get_axis(); } |
11178 | |
11179 | /// @copydoc dnnl::primitive_desc_base::get_group_size()const |
11180 | memory::dim get_group_size() const { return base::get_group_size(); } |
11181 | }; |
11182 | |
11183 | /// Default constructor. Produces an empty object. |
11184 | shuffle_backward() = default; |
11185 | |
11186 | /// Constructs a shuffle backward propagation primitive. |
11187 | /// @param pd Primitive descriptor for a shuffle backward propagation |
11188 | /// primitive. |
11189 | shuffle_backward(const primitive_desc &pd) : primitive(pd) {} |
11190 | |
11191 | /// Constructs a shuffle backward propagation primitive from a cache blob. |
11192 | /// @param pd Primitive descriptor for a shuffle backward propagation |
11193 | /// primitive. |
11194 | /// @param cache_blob Cache blob. |
11195 | shuffle_backward( |
11196 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11197 | : primitive(pd, cache_blob) {} |
11198 | }; |
11199 | |
11200 | /// @} dnnl_api_shuffle |
11201 | |
11202 | /// @addtogroup dnnl_api_binary Binary |
11203 | /// |
11204 | /// A primitive to perform tensor operations over two tensors. |
11205 | /// |
11206 | /// @sa @ref dev_guide_binary in developer guide |
11207 | /// |
11208 | /// @{ |
11209 | |
11210 | /// Elementwise binary operator primitive. |
11211 | struct binary : public primitive { |
11212 | /// Primitive descriptor for an elementwise binary operator primitive. |
11213 | struct primitive_desc : public dnnl::primitive_desc { |
11214 | /// Default constructor. Produces an empty object. |
11215 | primitive_desc() = default; |
11216 | |
11217 | /// Constructs a primitive descriptor for an elementwise binary operator |
11218 | /// primitive. |
11219 | /// |
11220 | /// @param aengine Engine to use. |
11221 | /// @param aalgorithm Elementwise binary algorithm. |
11222 | /// @param src0 Memory descriptor for source tensor #0. |
11223 | /// @param src1 Memory descriptor for source tensor #1. |
11224 | /// @param dst Memory descriptor for destination tensor. |
11225 | /// @param attr Primitive attributes to use. Attributes are optional |
11226 | /// and default to empty attributes. |
11227 | /// @param allow_empty A flag signifying whether construction is |
11228 | /// allowed to fail without throwing an exception. In this case an |
11229 | /// empty object will be produced. This flag is optional and |
11230 | /// defaults to false. |
11231 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
11232 | const memory::desc &src0, const memory::desc &src1, |
11233 | const memory::desc &dst, |
11234 | const primitive_attr &attr = default_attr(), |
11235 | bool allow_empty = false) { |
11236 | |
11237 | dnnl_primitive_desc_t pd = nullptr; |
11238 | dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd, |
11239 | aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(), |
11240 | src1.get(), dst.get(), attr.get()); |
11241 | |
11242 | if (!allow_empty) |
11243 | error::wrap_c_api(status, |
11244 | "could not create a primitive descriptor for a binary " |
11245 | "operation primitive" ); |
11246 | reset(pd); |
11247 | } |
11248 | |
11249 | /// Constructs a primitive descriptor for a binary primitive from a C |
11250 | /// API primitive descriptor that must have a matching kind. |
11251 | /// |
11252 | /// @param pd C API primitive descriptor for a binary primitive. |
11253 | primitive_desc(dnnl_primitive_desc_t pd) |
11254 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {} |
11255 | |
11256 | /// @copydoc dnnl::primitive_desc_base::src_desc(int)const |
11257 | memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } |
11258 | |
11259 | /// Returns the memory descriptor for source #0. |
11260 | memory::desc src0_desc() const { return base::src_desc(0); } |
11261 | |
11262 | /// Returns the memory descriptor for source #1. |
11263 | memory::desc src1_desc() const { return base::src_desc(1); } |
11264 | |
11265 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
11266 | memory::desc dst_desc() const { return base::dst_desc(0); } |
11267 | |
11268 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
11269 | algorithm get_algorithm() const { return base::get_algorithm(); } |
11270 | }; |
11271 | |
11272 | /// Default constructor. Produces an empty object. |
11273 | binary() = default; |
11274 | |
11275 | /// Constructs an elementwise binary operation primitive. |
11276 | /// @param pd Primitive descriptor for an elementwise binary operation |
11277 | /// primitive. |
11278 | binary(const primitive_desc &pd) : primitive(pd) {} |
11279 | |
11280 | /// Constructs an elementwise binary operation primitive from a cache blob. |
11281 | /// @param pd Primitive descriptor for an elementwise binary operation |
11282 | /// primitive. |
11283 | /// @param cache_blob Cache blob. |
11284 | binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11285 | : primitive(pd, cache_blob) {} |
11286 | }; |
11287 | |
11288 | /// @} dnnl_api_binary |
11289 | |
11290 | /// @addtogroup dnnl_api_matmul Matrix Multiplication |
11291 | /// |
11292 | /// A primitive to perform matrix-matrix multiplication. The batched mode |
11293 | /// is supported with 3D tensors. |
11294 | /// |
11295 | /// @sa @ref dev_guide_matmul in developer guide |
11296 | /// |
11297 | /// |
11298 | /// @{ |
11299 | |
11300 | /// Matrix multiplication (matmul) primitive. |
11301 | struct matmul : public primitive { |
11302 | /// Primitive descriptor for a matmul primitive. |
11303 | struct primitive_desc : public dnnl::primitive_desc { |
11304 | /// Default constructor. Produces an empty object. |
11305 | primitive_desc() = default; |
11306 | |
11307 | /// Constructs a primitive descriptor for a matmul primitive |
11308 | /// without bias. |
11309 | /// |
11310 | /// @param aengine Engine to use. |
11311 | /// @param src_desc Memory descriptor for source (matrix A). |
11312 | /// @param weights_desc Memory descriptor for weights (matrix B). |
11313 | /// @param dst_desc Memory descriptor for destination (matrix C). |
11314 | /// @param attr Primitive attributes to use. Attributes are optional |
11315 | /// and default to empty attributes. |
11316 | /// @param allow_empty A flag signifying whether construction is |
11317 | /// allowed to fail without throwing an exception. In this case an |
11318 | /// empty object will be produced. This flag is optional and |
11319 | /// defaults to false. |
11320 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
11321 | const memory::desc &weights_desc, const memory::desc &dst_desc, |
11322 | const primitive_attr &attr = default_attr(), |
11323 | bool allow_empty = false) |
11324 | : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc, |
11325 | attr, allow_empty) {} |
11326 | |
11327 | /// Constructs a primitive descriptor for a matmul primitive with bias. |
11328 | /// |
11329 | /// @param aengine Engine to use. |
11330 | /// @param src_desc Memory descriptor for source (matrix A). |
11331 | /// @param weights_desc Memory descriptor for weights (matrix B). |
11332 | /// @param dst_desc Memory descriptor for destination (matrix C). |
11333 | /// @param bias_desc Memory descriptor for bias. |
11334 | /// @param attr Primitive attributes to use. Attributes are optional |
11335 | /// and default to empty attributes. |
11336 | /// @param allow_empty A flag signifying whether construction is |
11337 | /// allowed to fail without throwing an exception. In this case an |
11338 | /// empty object will be produced. This flag is optional and |
11339 | /// defaults to false. |
11340 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
11341 | const memory::desc &weights_desc, const memory::desc &bias_desc, |
11342 | const memory::desc &dst_desc, |
11343 | const primitive_attr &attr = default_attr(), |
11344 | bool allow_empty = false) |
11345 | : primitive_desc(aengine, src_desc, weights_desc, &bias_desc, |
11346 | dst_desc, attr, allow_empty) {} |
11347 | |
11348 | /// Constructs a primitive descriptor for a matmul primitive from a C |
11349 | /// API primitive descriptor that must have a matching kind. |
11350 | /// |
11351 | /// @param pd C API primitive descriptor for a matmul primitive. |
11352 | primitive_desc(dnnl_primitive_desc_t pd) |
11353 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {} |
11354 | |
11355 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
11356 | memory::desc src_desc() const { return query_md(query::src_md, 0); } |
11357 | |
11358 | /// @copydoc dnnl::primitive_desc_base::weights_desc()const |
11359 | memory::desc weights_desc() const { |
11360 | return query_md(query::weights_md, 0); |
11361 | } |
11362 | |
11363 | /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const |
11364 | memory::desc bias_desc() const { |
11365 | return query_md(query::weights_md, 1); |
11366 | } |
11367 | |
11368 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
11369 | memory::desc dst_desc() const { return query_md(query::dst_md, 0); } |
11370 | |
11371 | private: |
11372 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
11373 | const memory::desc &weights_desc, const memory::desc *bias_desc, |
11374 | const memory::desc &dst_desc, const primitive_attr &attr, |
11375 | bool allow_empty) { |
11376 | |
11377 | dnnl_primitive_desc_t pd = nullptr; |
11378 | dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd, |
11379 | aengine.get(), src_desc.get(), weights_desc.get(), |
11380 | optional_arg(bias_desc), dst_desc.get(), attr.get()); |
11381 | |
11382 | if (!allow_empty) |
11383 | error::wrap_c_api(status, |
11384 | "could not create a primitive descriptor for a matmul " |
11385 | "primitive" ); |
11386 | reset(pd); |
11387 | } |
11388 | }; |
11389 | |
11390 | /// Default constructor. Produces an empty object. |
11391 | matmul() = default; |
11392 | |
11393 | /// Constructs a matmul primitive. |
11394 | /// @param pd Primitive descriptor for a matmul primitive. |
11395 | matmul(const primitive_desc &pd) : primitive(pd) {} |
11396 | |
11397 | /// Constructs a matmul primitive from a cache blob. |
11398 | /// @param pd Primitive descriptor for a matmul primitive. |
11399 | /// @param cache_blob Cache blob. |
11400 | matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11401 | : primitive(pd, cache_blob) {} |
11402 | }; |
11403 | |
11404 | /// @} dnnl_api_matmul |
11405 | |
11406 | /// @addtogroup dnnl_api_resampling Resampling |
11407 | /// |
/// A primitive to compute a resampling operation on 1D, 2D, or 3D data
/// tensors using the Nearest Neighbor or Linear (Bilinear, Trilinear)
/// interpolation method.
11411 | /// |
11412 | /// @sa @ref dev_guide_resampling in developer guide |
11413 | /// |
11414 | /// @{ |
11415 | |
11416 | /// Resampling forward propagation. |
11417 | struct resampling_forward : public primitive { |
11418 | /// Primitive descriptor for a resampling forward propagation primitive. |
11419 | struct primitive_desc : public dnnl::primitive_desc { |
11420 | /// Default constructor. Produces an empty object. |
11421 | primitive_desc() = default; |
11422 | |
11423 | /// Constructs a primitive descriptor for a resampling forward |
11424 | /// propagation primitive using source and destination memory |
11425 | /// descriptors. |
11426 | /// |
11427 | /// @note |
11428 | /// Destination memory descriptor may be initialized with |
11429 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
11430 | /// |
11431 | /// @param aengine Engine to use. |
11432 | /// @param aprop_kind Propagation kind. Possible values are |
11433 | /// #dnnl::prop_kind::forward_training, and |
11434 | /// #dnnl::prop_kind::forward_inference. |
11435 | /// @param aalgorithm resampling algorithm kind: either |
11436 | /// #dnnl::algorithm::resampling_nearest, or |
11437 | /// #dnnl::algorithm::resampling_linear |
11438 | /// @param src_desc Source memory descriptor. |
11439 | /// @param dst_desc Destination memory descriptor. |
11440 | /// @param attr Primitive attributes to use. Attributes are optional |
11441 | /// and default to empty attributes. |
11442 | /// @param allow_empty A flag signifying whether construction is |
11443 | /// allowed to fail without throwing an exception. In this case an |
11444 | /// empty object will be produced. This flag is optional and |
11445 | /// defaults to false. |
11446 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11447 | algorithm aalgorithm, const memory::desc &src_desc, |
11448 | const memory::desc &dst_desc, |
11449 | const primitive_attr &attr = default_attr(), |
11450 | bool allow_empty = false) |
11451 | : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc, |
11452 | &dst_desc, attr, allow_empty) {} |
11453 | |
11454 | /// Constructs a primitive descriptor for a resampling forward |
11455 | /// propagation primitive using source memory descriptor and |
11456 | /// factors. |
11457 | /// |
11458 | /// @param aengine Engine to use. |
11459 | /// @param aprop_kind Propagation kind. Possible values are |
11460 | /// #dnnl::prop_kind::forward_training, and |
11461 | /// #dnnl::prop_kind::forward_inference. |
11462 | /// @param aalgorithm resampling algorithm kind: either |
11463 | /// #dnnl::algorithm::resampling_nearest, or |
11464 | /// #dnnl::algorithm::resampling_linear |
11465 | /// @param factors Vector of scaling factors for spatial dimension. |
11466 | /// @param src_desc Source memory descriptor. |
11467 | /// @param attr Primitive attributes to use. Attributes are optional |
11468 | /// and default to empty attributes. |
11469 | /// @param allow_empty A flag signifying whether construction is |
11470 | /// allowed to fail without throwing an exception. In this case an |
11471 | /// empty object will be produced. This flag is optional and |
11472 | /// defaults to false. |
11473 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11474 | algorithm aalgorithm, const std::vector<float> &factors, |
11475 | const memory::desc &src_desc, |
11476 | const primitive_attr &attr = default_attr(), |
11477 | bool allow_empty = false) |
11478 | : primitive_desc(aengine, aprop_kind, aalgorithm, &factors, |
11479 | src_desc, nullptr, attr, allow_empty) {} |
11480 | |
11481 | /// Constructs a primitive descriptor for a resampling forward |
11482 | /// propagation primitive. |
11483 | /// |
11484 | /// @note |
11485 | /// The destination memory descriptor may be initialized with |
11486 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
11487 | /// |
11488 | /// @param aengine Engine to use. |
11489 | /// @param aprop_kind Propagation kind. Possible values are |
11490 | /// #dnnl::prop_kind::forward_training, and |
11491 | /// #dnnl::prop_kind::forward_inference. |
11492 | /// @param aalgorithm resampling algorithm kind: either |
11493 | /// #dnnl::algorithm::resampling_nearest, or |
11494 | /// #dnnl::algorithm::resampling_linear |
11495 | /// @param factors Vector of scaling factors for spatial dimension. |
11496 | /// @param src_desc Source memory descriptor. |
11497 | /// @param dst_desc Destination memory descriptor. |
11498 | /// @param attr Primitive attributes to use. Attributes are optional |
11499 | /// and default to empty attributes. |
11500 | /// @param allow_empty A flag signifying whether construction is |
11501 | /// allowed to fail without throwing an exception. In this case an |
11502 | /// empty object will be produced. This flag is optional and |
11503 | /// defaults to false. |
11504 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11505 | algorithm aalgorithm, const std::vector<float> &factors, |
11506 | const memory::desc &src_desc, const memory::desc &dst_desc, |
11507 | const primitive_attr &attr = default_attr(), |
11508 | bool allow_empty = false) |
11509 | : primitive_desc(aengine, aprop_kind, aalgorithm, &factors, |
11510 | src_desc, &dst_desc, attr, allow_empty) {} |
11511 | |
11512 | /// Constructs a primitive descriptor for a resampling forward |
11513 | /// propagation primitive from a C API primitive descriptor that must |
11514 | /// have a matching kind. |
11515 | /// |
11516 | /// @param pd C API primitive descriptor for a resampling forward |
11517 | /// propagation primitive. |
11518 | primitive_desc(dnnl_primitive_desc_t pd) |
11519 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling, |
11520 | dnnl::prop_kind::forward_training, |
11521 | dnnl::prop_kind::forward_inference) {} |
11522 | |
11523 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
11524 | memory::desc src_desc() const { return base::src_desc(0); } |
11525 | |
11526 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
11527 | memory::desc dst_desc() const { return base::dst_desc(0); } |
11528 | |
11529 | private: |
11530 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11531 | algorithm aalgorithm, const std::vector<float> *factors, |
11532 | const memory::desc &src_desc, const memory::desc *dst_desc, |
11533 | const primitive_attr &attr, bool allow_empty) { |
11534 | |
11535 | if (factors) |
11536 | memory::validate_dims(*factors, src_desc.get_ndims() - 2); |
11537 | |
11538 | dnnl_primitive_desc_t pd = nullptr; |
11539 | dnnl_status_t status |
11540 | = dnnl_resampling_forward_primitive_desc_create(&pd, |
11541 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
11542 | convert_to_c(aalgorithm), optional_arg(factors), |
11543 | src_desc.get(), optional_arg(dst_desc), attr.get()); |
11544 | |
11545 | if (!allow_empty) |
11546 | error::wrap_c_api(status, |
11547 | "could not create a primitive descriptor for a " |
11548 | "resampling forward propagation primitive" ); |
11549 | reset(pd); |
11550 | } |
11551 | }; |
11552 | |
11553 | /// Default constructor. Produces an empty object. |
11554 | resampling_forward() = default; |
11555 | |
11556 | /// Constructs a resampling forward propagation primitive. |
11557 | /// @param pd Primitive descriptor for a resampling forward propagation |
11558 | /// primitive. |
11559 | resampling_forward(const primitive_desc &pd) : primitive(pd) {} |
11560 | |
11561 | /// Constructs a resampling forward propagation primitive from a cache |
11562 | /// blob. |
11563 | /// @param pd Primitive descriptor for a resampling forward propagation |
11564 | /// primitive. |
11565 | /// @param cache_blob Cache blob. |
11566 | resampling_forward( |
11567 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11568 | : primitive(pd, cache_blob) {} |
11569 | }; |
11570 | |
11571 | /// Resampling backward propagation primitive. |
11572 | struct resampling_backward : public primitive { |
11573 | /// Primitive descriptor for resampling backward propagation primitive. |
11574 | struct primitive_desc : public dnnl::primitive_desc { |
11575 | /// Default constructor. Produces an empty object. |
11576 | primitive_desc() = default; |
11577 | |
11578 | /// Constructs a primitive descriptor for a resampling backward |
11579 | /// propagation primitive using source and destination memory |
11580 | /// descriptors. |
11581 | /// |
11582 | /// @param aengine Engine to use. |
11583 | /// @param aalgorithm resampling algorithm kind: either |
11584 | /// #dnnl::algorithm::resampling_nearest, or |
11585 | /// #dnnl::algorithm::resampling_linear |
11586 | /// @param diff_src_desc Diff source memory descriptor. |
11587 | /// @param diff_dst_desc Diff destination memory descriptor. |
11588 | /// @param hint_fwd_pd Primitive descriptor for a resampling |
11589 | /// forward propagation primitive. It is used as a hint for |
11590 | /// deciding which memory format to use. |
11591 | /// @param attr Primitive attributes to use. Attributes are optional |
11592 | /// and default to empty attributes. |
11593 | /// @param allow_empty A flag signifying whether construction is |
11594 | /// allowed to fail without throwing an exception. In this case an |
11595 | /// empty object will be produced. This flag is optional and |
11596 | /// defaults to false. |
11597 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
11598 | const memory::desc &diff_src_desc, |
11599 | const memory::desc &diff_dst_desc, |
11600 | const resampling_forward::primitive_desc &hint_fwd_pd, |
11601 | const primitive_attr &attr = default_attr(), |
11602 | bool allow_empty = false) |
11603 | : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc, |
11604 | diff_dst_desc, hint_fwd_pd, attr, allow_empty) {} |
11605 | |
11606 | /// Constructs a primitive descriptor for resampling backward |
11607 | /// propagation primitive. |
11608 | /// |
11609 | /// @param aengine Engine to use. |
11610 | /// @param aalgorithm resampling algorithm kind: either |
11611 | /// #dnnl::algorithm::resampling_nearest, or |
11612 | /// #dnnl::algorithm::resampling_linear |
11613 | /// @param factors Vector of scaling factors for spatial dimension. |
11614 | /// @param diff_src_desc Diff source memory descriptor. |
11615 | /// @param diff_dst_desc Diff destination memory descriptor. |
11616 | /// @param hint_fwd_pd Primitive descriptor for a resampling |
11617 | /// forward propagation primitive. It is used as a hint for |
11618 | /// deciding which memory format to use. |
11619 | /// @param attr Primitive attributes to use. Attributes are optional |
11620 | /// and default to empty attributes. |
11621 | /// @param allow_empty A flag signifying whether construction is |
11622 | /// allowed to fail without throwing an exception. In this case an |
11623 | /// empty object will be produced. This flag is optional and |
11624 | /// defaults to false. |
11625 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
11626 | const std::vector<float> &factors, |
11627 | const memory::desc &diff_src_desc, |
11628 | const memory::desc &diff_dst_desc, |
11629 | const resampling_forward::primitive_desc &hint_fwd_pd, |
11630 | const primitive_attr &attr = default_attr(), |
11631 | bool allow_empty = false) |
11632 | : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc, |
11633 | diff_dst_desc, hint_fwd_pd, attr, allow_empty) {} |
11634 | |
11635 | /// Constructs a primitive descriptor for a resampling backward |
11636 | /// propagation primitive from a C API primitive descriptor that must |
11637 | /// have a matching kind. |
11638 | /// |
11639 | /// @param pd C API primitive descriptor for a resampling backward |
11640 | /// propagation primitive. |
11641 | primitive_desc(dnnl_primitive_desc_t pd) |
11642 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling, |
11643 | dnnl::prop_kind::backward_data) {} |
11644 | |
11645 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
11646 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
11647 | |
11648 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
11649 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
11650 | |
11651 | private: |
11652 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
11653 | const std::vector<float> *factors, |
11654 | const memory::desc &diff_src_desc, |
11655 | const memory::desc &diff_dst_desc, |
11656 | const resampling_forward::primitive_desc &hint_fwd_pd, |
11657 | const primitive_attr &attr, bool allow_empty) { |
11658 | |
11659 | if (factors) |
11660 | memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2); |
11661 | |
11662 | dnnl_primitive_desc_t pd = nullptr; |
11663 | dnnl_status_t status |
11664 | = dnnl_resampling_backward_primitive_desc_create(&pd, |
11665 | aengine.get(), convert_to_c(aalgorithm), |
11666 | optional_arg(factors), diff_src_desc.get(), |
11667 | diff_dst_desc.get(), hint_fwd_pd.get(), attr.get()); |
11668 | |
11669 | if (!allow_empty) |
11670 | error::wrap_c_api(status, |
11671 | "could not create a primitive descriptor for a " |
11672 | "resampling backward propagation primitive" ); |
11673 | reset(pd); |
11674 | } |
11675 | }; |
11676 | |
11677 | /// Default constructor. Produces an empty object. |
11678 | resampling_backward() = default; |
11679 | |
11680 | /// Constructs a resampling backward propagation primitive. |
11681 | /// @param pd Primitive descriptor for a resampling backward propagation |
11682 | /// primitive. |
11683 | resampling_backward(const primitive_desc &pd) : primitive(pd) {} |
11684 | |
11685 | /// Constructs a resampling backward propagation primitive from a cache |
11686 | /// blob. |
11687 | /// @param pd Primitive descriptor for a resampling backward propagation |
11688 | /// primitive. |
11689 | /// @param cache_blob Cache blob. |
11690 | resampling_backward( |
11691 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11692 | : primitive(pd, cache_blob) {} |
11693 | }; |
11694 | |
11695 | /// @} dnnl_api_resampling |
11696 | |
11697 | /// @addtogroup dnnl_api_pooling Pooling |
11698 | /// |
11699 | /// A primitive to perform max or average pooling with dilation. |
11700 | /// |
11701 | /// @sa @ref dev_guide_pooling in developer guide |
11702 | /// |
11703 | /// @{ |
11704 | |
11705 | /// Pooling forward propagation primitive. |
11706 | struct pooling_forward : public primitive { |
11707 | /// Primitive descriptor for a pooling forward propagation primitive. |
11708 | struct primitive_desc : public dnnl::primitive_desc { |
11709 | /// Default constructor. Produces an empty object. |
11710 | primitive_desc() = default; |
11711 | |
11712 | /// Constructs a primitive descriptor for pooling forward propagation |
11713 | /// primitive. |
11714 | /// |
11715 | /// Arrays @p strides, @p kernel, @p dilation, @p padding_l |
11716 | /// and @p padding_r contain values for spatial dimensions only and |
11717 | /// hence must have the same number of elements as there are spatial |
11718 | /// dimensions. The order of values is the same as in the tensor: |
11719 | /// depth (for 3D tensors), height (for 3D and 2D tensors), and width. |
11720 | /// |
11721 | /// @param aengine Engine to use. |
11722 | /// @param aprop_kind Propagation kind. Possible values are |
11723 | /// #dnnl::prop_kind::forward_training, and |
11724 | /// #dnnl::prop_kind::forward_inference. |
11725 | /// @param aalgorithm Pooling algorithm kind: either |
11726 | /// #dnnl::algorithm::pooling_max, |
11727 | /// #dnnl::algorithm::pooling_avg_include_padding, |
11728 | /// or #dnnl::algorithm::pooling_avg_exclude_padding. |
11729 | /// @param src_desc Source memory descriptor. |
11730 | /// @param dst_desc Destination memory descriptor. |
11731 | /// @param strides Vector of strides for spatial dimension. |
11732 | /// @param kernel Vector of kernel spatial dimensions. |
11733 | /// @param dilation Array of dilations for spatial dimension. |
11734 | /// @param padding_l Vector of padding values for low indices for each |
11735 | /// spatial dimension `([[front,] top,] left)`. |
11736 | /// @param padding_r Vector of padding values for high indices for |
11737 | /// each spatial dimension `([[back,] bottom,] right)`. |
11738 | /// @param attr Primitive attributes to use. Attributes are optional |
11739 | /// and default to empty attributes. |
11740 | /// @param allow_empty A flag signifying whether construction is |
11741 | /// allowed to fail without throwing an exception. In this case an |
11742 | /// empty object will be produced. This flag is optional and |
11743 | /// defaults to false. |
11744 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11745 | algorithm aalgorithm, const memory::desc &src_desc, |
11746 | const memory::desc &dst_desc, const memory::dims &strides, |
11747 | const memory::dims &kernel, const memory::dims &dilation, |
11748 | const memory::dims &padding_l, const memory::dims &padding_r, |
11749 | const primitive_attr &attr = default_attr(), |
11750 | bool allow_empty = false) { |
11751 | |
11752 | memory::validate_dims(strides, src_desc.get_ndims() - 2); |
11753 | memory::validate_dims(kernel, src_desc.get_ndims() - 2); |
11754 | memory::validate_dims(padding_l, src_desc.get_ndims() - 2); |
11755 | memory::validate_dims(padding_r, src_desc.get_ndims() - 2); |
11756 | memory::validate_dims(dilation, src_desc.get_ndims() - 2); |
11757 | |
11758 | dnnl_primitive_desc_t pd = nullptr; |
11759 | dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create( |
11760 | &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), |
11761 | convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(), |
11762 | &strides[0], &kernel[0], &dilation[0], &padding_l[0], |
11763 | &padding_r[0], attr.get()); |
11764 | |
11765 | if (!allow_empty) |
11766 | error::wrap_c_api(status, |
11767 | "could not create a descriptor for a pooling forward " |
11768 | "propagation primitive" ); |
11769 | reset(pd); |
11770 | } |
11771 | |
11772 | /// Constructs a primitive descriptor for a pooling forward propagation |
11773 | /// primitive from a C API primitive descriptor that must have a |
11774 | /// matching kind. |
11775 | /// |
11776 | /// @param pd C API primitive descriptor for a pooling forward |
11777 | /// propagation primitive. |
11778 | primitive_desc(dnnl_primitive_desc_t pd) |
11779 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling, |
11780 | dnnl::prop_kind::forward_training, |
11781 | dnnl::prop_kind::forward_inference) {} |
11782 | |
11783 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
11784 | memory::desc src_desc() const { return base::src_desc(0); } |
11785 | |
11786 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
11787 | memory::desc dst_desc() const { return base::dst_desc(0); } |
11788 | |
11789 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
11790 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
11791 | |
11792 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
11793 | algorithm get_algorithm() const { return base::get_algorithm(); } |
11794 | |
11795 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
11796 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
11797 | |
11798 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
11799 | memory::dims get_strides() const { return base::get_strides(); } |
11800 | |
11801 | /// @copydoc dnnl::primitive_desc_base::get_kernel()const |
11802 | memory::dims get_kernel() const { return base::get_kernel(); } |
11803 | |
11804 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
11805 | memory::dims get_dilations() const { return base::get_dilations(); } |
11806 | |
11807 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
11808 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
11809 | |
11810 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
11811 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
11812 | }; |
11813 | |
11814 | /// Default constructor. Produces an empty object. |
11815 | pooling_forward() = default; |
11816 | |
11817 | /// Constructs a pooling forward propagation primitive. |
11818 | /// |
11819 | /// @param pd Primitive descriptor for a pooling forward propagation |
11820 | /// primitive. |
11821 | pooling_forward(const primitive_desc &pd) : primitive(pd) {} |
11822 | |
11823 | /// Constructs a pooling forward propagation primitive from a cache blob. |
11824 | /// |
11825 | /// @param pd Primitive descriptor for a pooling forward propagation |
11826 | /// primitive. |
11827 | /// @param cache_blob Cache blob. |
11828 | pooling_forward( |
11829 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11830 | : primitive(pd, cache_blob) {} |
11831 | }; |
11832 | |
11833 | /// Pooling backward propagation primitive. |
11834 | struct pooling_backward : public primitive { |
11835 | /// Primitive descriptor for a pooling backward propagation primitive. |
11836 | struct primitive_desc : public dnnl::primitive_desc { |
11837 | /// Default constructor. Produces an empty object. |
11838 | primitive_desc() = default; |
11839 | |
11840 | /// Constructs a primitive descriptor for a pooling backward propagation |
11841 | /// primitive. |
11842 | /// |
11843 | /// Arrays @p strides, @p kernel, @p dilation, @p padding_l |
11844 | /// and @p padding_r contain values for spatial dimensions only and |
11845 | /// hence must have the same number of elements as there are spatial |
11846 | /// dimensions. The order of values is the same as in the tensor: |
11847 | /// depth (for 3D tensors), height (for 3D and 2D tensors), and width. |
11848 | /// |
11849 | /// @param aengine Engine to use. |
11850 | /// @param aalgorithm Pooling algorithm kind: either |
11851 | /// #dnnl::algorithm::pooling_max, |
11852 | /// #dnnl::algorithm::pooling_avg_include_padding, |
11853 | /// or #dnnl::algorithm::pooling_avg_exclude_padding. |
11854 | /// @param diff_src_desc Diff source memory descriptor. |
11855 | /// @param diff_dst_desc Diff destination memory descriptor. |
11856 | /// @param strides Vector of strides for spatial dimension. |
11857 | /// @param kernel Vector of kernel spatial dimensions. |
11858 | /// @param dilation Array of dilations for spatial dimension. |
11859 | /// @param padding_l Vector of padding values for low indices for each |
11860 | /// spatial dimension `([[front,] top,] left)`. |
11861 | /// @param padding_r Vector of padding values for high indices for |
11862 | /// each spatial dimension `([[back,] bottom,] right)`. |
11863 | /// @param hint_fwd_pd Primitive descriptor for a pooling |
11864 | /// forward propagation primitive. It is used as a hint for |
11865 | /// deciding which memory format to use. |
11866 | /// @param attr Primitive attributes to use. Attributes are optional |
11867 | /// and default to empty attributes. |
11868 | /// @param allow_empty A flag signifying whether construction is |
11869 | /// allowed to fail without throwing an exception. In this case an |
11870 | /// empty object will be produced. This flag is optional and |
11871 | /// defaults to false. |
11872 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
11873 | const memory::desc &diff_src_desc, |
11874 | const memory::desc &diff_dst_desc, const memory::dims &strides, |
11875 | const memory::dims &kernel, const memory::dims &dilation, |
11876 | const memory::dims &padding_l, const memory::dims &padding_r, |
11877 | const pooling_forward::primitive_desc &hint_fwd_pd, |
11878 | const primitive_attr &attr = default_attr(), |
11879 | bool allow_empty = false) { |
11880 | |
11881 | memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); |
11882 | memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2); |
11883 | memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); |
11884 | memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); |
11885 | memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2); |
11886 | |
11887 | dnnl_primitive_desc_t pd = nullptr; |
11888 | dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create( |
11889 | &pd, aengine.get(), convert_to_c(aalgorithm), |
11890 | diff_src_desc.get(), diff_dst_desc.get(), &strides[0], |
11891 | &kernel[0], &dilation[0], &padding_l[0], &padding_r[0], |
11892 | hint_fwd_pd.get(), attr.get()); |
11893 | if (!allow_empty) |
11894 | error::wrap_c_api(status, |
11895 | "could not create a descriptor for a pooling backward " |
11896 | "propagation primitive" ); |
11897 | reset(pd); |
11898 | } |
11899 | |
11900 | /// Constructs a primitive descriptor for a pooling backward propagation |
11901 | /// primitive from a C API primitive descriptor that must have a |
11902 | /// matching kind. |
11903 | /// |
11904 | /// @param pd C API primitive descriptor for a pooling backward |
11905 | /// propagation primitive. |
11906 | primitive_desc(dnnl_primitive_desc_t pd) |
11907 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling, |
11908 | dnnl::prop_kind::backward_data) {} |
11909 | |
11910 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
11911 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
11912 | |
11913 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
11914 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
11915 | |
11916 | /// @copydoc dnnl::primitive_desc_base::workspace_desc()const |
11917 | memory::desc workspace_desc() const { return base::workspace_desc(); } |
11918 | |
11919 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
11920 | algorithm get_algorithm() const { return base::get_algorithm(); } |
11921 | |
11922 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
11923 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
11924 | |
11925 | /// @copydoc dnnl::primitive_desc_base::get_strides()const |
11926 | memory::dims get_strides() const { return base::get_strides(); } |
11927 | |
11928 | /// @copydoc dnnl::primitive_desc_base::get_kernel()const |
11929 | memory::dims get_kernel() const { return base::get_kernel(); } |
11930 | |
11931 | /// @copydoc dnnl::primitive_desc_base::get_dilations()const |
11932 | memory::dims get_dilations() const { return base::get_dilations(); } |
11933 | |
11934 | /// @copydoc dnnl::primitive_desc_base::get_padding_l()const |
11935 | memory::dims get_padding_l() const { return base::get_padding_l(); } |
11936 | |
11937 | /// @copydoc dnnl::primitive_desc_base::get_padding_r()const |
11938 | memory::dims get_padding_r() const { return base::get_padding_r(); } |
11939 | }; |
11940 | |
11941 | /// Default constructor. Produces an empty object. |
11942 | pooling_backward() = default; |
11943 | |
11944 | /// Constructs a pooling backward propagation primitive. |
11945 | /// |
11946 | /// @param pd Primitive descriptor for a pooling backward propagation |
11947 | /// primitive. |
11948 | pooling_backward(const primitive_desc &pd) : primitive(pd) {} |
11949 | |
11950 | /// Constructs a pooling backward propagation primitive from a cache blob. |
11951 | /// |
11952 | /// @param pd Primitive descriptor for a pooling backward propagation |
11953 | /// primitive. |
11954 | /// @param cache_blob Cache blob. |
11955 | pooling_backward( |
11956 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
11957 | : primitive(pd, cache_blob) {} |
11958 | }; |
11959 | |
11960 | /// @} dnnl_api_pooling |
11961 | |
11962 | /// @addtogroup dnnl_api_prelu PReLU |
11963 | /// |
/// PReLU primitive.
/// A primitive to perform PReLU (leaky ReLU with a trainable alpha
/// parameter).
11966 | /// |
11967 | /// @sa @ref dev_guide_prelu in developer guide |
11968 | /// |
11969 | /// @{ |
11970 | |
11971 | /// PReLU forward propagation primitive. |
11972 | struct prelu_forward : public primitive { |
11973 | /// Primitive descriptor for a PReLU forward propagation primitive. |
11974 | struct primitive_desc : public dnnl::primitive_desc { |
11975 | /// Default constructor. Produces an empty object. |
11976 | primitive_desc() = default; |
11977 | |
11978 | /// Constructs a primitive descriptor for a PReLU forward propagation |
11979 | /// primitive. |
11980 | /// |
11981 | /// @param aengine Engine to use. |
11982 | /// @param aprop_kind Propagation kind. Possible values are |
11983 | /// #dnnl::prop_kind::forward_training, and |
11984 | /// #dnnl::prop_kind::forward_inference. |
11985 | /// @param src_desc Source memory descriptor. |
11986 | /// @param weight_desc Alpha parameters memory descriptor. |
11987 | /// @param dst_desc Destination memory descriptor. |
11988 | /// @param attr Primitive attributes to use. Attributes are optional |
11989 | /// and default to empty attributes. |
11990 | /// @param allow_empty A flag signifying whether construction is |
11991 | /// allowed to fail without throwing an exception. In this case an |
11992 | /// empty object will be produced. This flag is optional and |
11993 | /// defaults to false. |
11994 | primitive_desc(const engine &aengine, prop_kind aprop_kind, |
11995 | const memory::desc &src_desc, const memory::desc &weight_desc, |
11996 | const memory::desc &dst_desc, |
11997 | const primitive_attr &attr = default_attr(), |
11998 | bool allow_empty = false) { |
11999 | |
12000 | dnnl_primitive_desc_t pd = nullptr; |
12001 | dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd, |
12002 | aengine.get(), dnnl::convert_to_c(aprop_kind), |
12003 | src_desc.get(), weight_desc.get(), dst_desc.get(), |
12004 | attr.get()); |
12005 | |
12006 | if (!allow_empty) |
12007 | error::wrap_c_api(status, |
12008 | "could not create a primitive descriptor for a prelu " |
12009 | "forward propagation primitive" ); |
12010 | reset(pd); |
12011 | } |
12012 | |
12013 | /// Constructs a primitive descriptor for a prelu forward |
12014 | /// propagation primitive from a C API primitive descriptor that must |
12015 | /// have a matching kind. |
12016 | /// |
12017 | /// @param pd C API primitive descriptor for a prelu forward |
12018 | /// propagation primitive. |
12019 | primitive_desc(dnnl_primitive_desc_t pd) |
12020 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu, |
12021 | dnnl::prop_kind::forward_training, |
12022 | dnnl::prop_kind::forward_inference) {} |
12023 | |
12024 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
12025 | memory::desc src_desc() const { return base::src_desc(0); } |
12026 | |
12027 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
12028 | memory::desc dst_desc() const { return base::dst_desc(0); } |
12029 | |
12030 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
12031 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
12032 | }; |
12033 | |
12034 | /// Default constructor. Produces an empty object. |
12035 | prelu_forward() = default; |
12036 | |
12037 | /// Constructs a prelu forward propagation primitive. |
12038 | /// @param pd Primitive descriptor for a prelu forward propagation |
12039 | /// primitive. |
12040 | prelu_forward(const primitive_desc &pd) : primitive(pd) {} |
12041 | |
12042 | /// Constructs a prelu forward propagation primitive from a cache blob. |
12043 | /// @param pd Primitive descriptor for a prelu forward propagation |
12044 | /// primitive. |
12045 | /// @param cache_blob Cache blob. |
12046 | prelu_forward( |
12047 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
12048 | : primitive(pd, cache_blob) {} |
12049 | }; |
12050 | |
12051 | /// PReLU backward propagation primitive. |
12052 | struct prelu_backward : public primitive { |
12053 | /// Primitive descriptor for prelu backward propagation. |
12054 | struct primitive_desc : public dnnl::primitive_desc { |
12055 | /// Default constructor. Produces an empty object. |
12056 | primitive_desc() = default; |
12057 | |
12058 | /// Constructs a descriptor for a PReLU backward propagation |
12059 | /// primitive. |
12060 | /// |
12061 | /// @param aengine Engine to use. |
12062 | /// @param src_desc Source memory descriptor. |
12063 | /// @param weight_desc Alpha parameters memory descriptor. |
12064 | /// @param diff_src_desc Diff source memory descriptor. |
12065 | /// @param diff_weights_desc Diff alpha parameters memory descriptor. |
12066 | /// @param diff_dst_desc Diff destination memory descriptor. |
12067 | /// @param hint_fwd_pd Primitive descriptor for a PReLU |
12068 | /// forward propagation primitive. It is used as a hint for |
12069 | /// deciding which memory format to use. |
12070 | /// @param attr Primitive attributes to use. Attributes are optional |
12071 | /// and default to empty attributes. |
12072 | /// @param allow_empty A flag signifying whether construction is |
12073 | /// allowed to fail without throwing an exception. In this case an |
12074 | /// empty object will be produced. This flag is optional and |
12075 | /// defaults to false. |
12076 | primitive_desc(const engine &aengine, const memory::desc &src_desc, |
12077 | const memory::desc &weight_desc, |
12078 | const memory::desc &diff_src_desc, |
12079 | const memory::desc &diff_weights_desc, |
12080 | const memory::desc &diff_dst_desc, |
12081 | const prelu_forward::primitive_desc &hint_fwd_pd, |
12082 | const primitive_attr &attr = default_attr(), |
12083 | bool allow_empty = false) { |
12084 | |
12085 | dnnl_primitive_desc_t pd = nullptr; |
12086 | dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create( |
12087 | &pd, aengine.get(), src_desc.get(), weight_desc.get(), |
12088 | diff_src_desc.get(), diff_weights_desc.get(), |
12089 | diff_dst_desc.get(), hint_fwd_pd.get(), attr.get()); |
12090 | |
12091 | if (!allow_empty) |
12092 | error::wrap_c_api(status, |
12093 | "could not create a primitive descriptor for a prelu " |
12094 | "backward propagation primitive" ); |
12095 | reset(pd); |
12096 | } |
12097 | |
12098 | /// Constructs a primitive descriptor for a prelu backward |
12099 | /// propagation primitive from a C API primitive descriptor that must |
12100 | /// have a matching kind. |
12101 | /// |
12102 | /// @param pd C API primitive descriptor for a prelu backward |
12103 | /// propagation primitive. |
12104 | primitive_desc(dnnl_primitive_desc_t pd) |
12105 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu, |
12106 | dnnl::prop_kind::backward) {} |
12107 | |
12108 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
12109 | memory::desc src_desc() const { return base::src_desc(0); } |
12110 | |
12111 | /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const |
12112 | memory::desc diff_src_desc() const { return base::diff_src_desc(0); } |
12113 | |
12114 | /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const |
12115 | memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } |
12116 | |
12117 | /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const |
12118 | prop_kind get_prop_kind() const { return base::get_prop_kind(); } |
12119 | }; |
12120 | |
12121 | /// Default constructor. Produces an empty object. |
12122 | prelu_backward() = default; |
12123 | |
12124 | /// Constructs a prelu backward propagation primitive. |
12125 | /// @param pd Primitive descriptor for a prelu backward propagation |
12126 | /// primitive. |
12127 | prelu_backward(const primitive_desc &pd) : primitive(pd) {} |
12128 | |
12129 | /// Constructs a prelu backward propagation primitive from a cache blob. |
12130 | /// @param pd Primitive descriptor for a prelu backward propagation |
12131 | /// primitive. |
12132 | /// @param cache_blob Cache blob. |
12133 | prelu_backward( |
12134 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
12135 | : primitive(pd, cache_blob) {} |
12136 | }; |
12137 | |
12138 | /// @} dnnl_api_prelu |
12139 | |
12140 | /// @addtogroup dnnl_api_reduction Reduction |
12141 | /// |
/// A primitive to compute a reduction operation on a data tensor
/// using one of the min, max, mul, sum, mean, or norm_lp operations.
12144 | /// |
12145 | /// @sa @ref dev_guide_reduction in developer guide |
12146 | /// |
12147 | /// @{ |
12148 | |
12149 | /// Reduction. |
12150 | struct reduction : public primitive { |
12151 | /// Primitive descriptor for a reduction primitive. |
12152 | struct primitive_desc : public dnnl::primitive_desc { |
12153 | /// Default constructor. Produces an empty object. |
12154 | primitive_desc() = default; |
12155 | |
12156 | /// Constructs a primitive descriptor for a reduction primitive using |
12157 | /// algorithm specific parameters, source and destination memory |
12158 | /// descriptors. |
12159 | /// |
12160 | /// @note |
12161 | /// Destination memory descriptor may be initialized with |
12162 | /// #dnnl::memory::format_tag::any value of @p format_tag. |
12163 | /// |
12164 | /// @param aengine Engine to use. |
12165 | /// @param aalgorithm reduction algorithm kind. Possible values: |
12166 | /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum, |
12167 | /// #dnnl_reduction_mul, #dnnl_reduction_mean, |
12168 | /// #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum, |
12169 | /// #dnnl_reduction_norm_lp_power_p_max, |
12170 | /// #dnnl_reduction_norm_lp_power_p_sum. |
12171 | /// @param p algorithm specific parameter. |
12172 | /// @param eps algorithm specific parameter. |
12173 | /// @param src_desc Source memory descriptor. |
12174 | /// @param dst_desc Destination memory descriptor. |
12175 | /// @param attr Primitive attributes to use. Attributes are optional |
12176 | /// and default to empty attributes. |
12177 | /// @param allow_empty A flag signifying whether construction is |
12178 | /// allowed to fail without throwing an exception. In this case an |
12179 | /// empty object will be produced. This flag is optional and |
12180 | /// defaults to false. |
12181 | primitive_desc(const engine &aengine, algorithm aalgorithm, |
12182 | const memory::desc &src_desc, const memory::desc &dst_desc, |
12183 | float p, float eps, const primitive_attr &attr = default_attr(), |
12184 | bool allow_empty = false) { |
12185 | |
12186 | dnnl_primitive_desc_t pd = nullptr; |
12187 | dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd, |
12188 | aengine.get(), convert_to_c(aalgorithm), src_desc.get(), |
12189 | dst_desc.get(), p, eps, attr.get()); |
12190 | |
12191 | if (!allow_empty) |
12192 | error::wrap_c_api(status, |
12193 | "could not create a primitive descriptor for a " |
12194 | "reduction primitive descriptor" ); |
12195 | reset(pd); |
12196 | } |
12197 | |
12198 | /// Constructs a primitive descriptor for a reduction primitive from a C |
12199 | /// API primitive descriptor that must have a matching kind. |
12200 | /// |
12201 | /// @param pd C API primitive descriptor for a reduction primitive. |
12202 | primitive_desc(dnnl_primitive_desc_t pd) |
12203 | : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {} |
12204 | |
12205 | /// @copydoc dnnl::primitive_desc_base::src_desc()const |
12206 | memory::desc src_desc() const { return base::src_desc(0); } |
12207 | |
12208 | /// @copydoc dnnl::primitive_desc_base::dst_desc()const |
12209 | memory::desc dst_desc() const { return base::dst_desc(0); } |
12210 | |
12211 | /// @copydoc dnnl::primitive_desc_base::get_p()const |
12212 | float get_p() const { return base::get_p(); } |
12213 | |
12214 | /// @copydoc dnnl::primitive_desc_base::get_epsilon()const |
12215 | float get_epsilon() const { return base::get_epsilon(); } |
12216 | |
12217 | /// @copydoc dnnl::primitive_desc_base::get_algorithm()const |
12218 | algorithm get_algorithm() const { return base::get_algorithm(); } |
12219 | }; |
12220 | |
12221 | /// Default constructor. Produces an empty object. |
12222 | reduction() = default; |
12223 | |
12224 | /// Constructs a reduction primitive. |
12225 | /// @param pd Primitive descriptor for a reduction primitive. |
12226 | reduction(const primitive_desc &pd) : primitive(pd) {} |
12227 | |
12228 | /// Constructs a reduction primitive from a cache blob. |
12229 | /// @param pd Primitive descriptor for a reduction primitive. |
12230 | /// @param cache_blob Cache blob. |
12231 | reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
12232 | : primitive(pd, cache_blob) {} |
12233 | }; |
12234 | |
12235 | /// @} dnnl_api_reduction |
12236 | |
12237 | /// @} dnnl_api_primitives |
12238 | |
12239 | /// @addtogroup dnnl_api_service Service |
12240 | /// |
12241 | /// A set of functions that aid in oneDNN debugging and profiling. |
12242 | /// |
12243 | /// @{ |
12244 | |
// C++ spelling of the C library's version-information struct; it is the
// pointee type returned by dnnl::version().
/// @copydoc dnnl_version_t
using version_t = dnnl_version_t;
12247 | |
/// Status values returned by the library functions.
///
/// Each enumerator is defined with the numeric value of the corresponding
/// dnnl_status_t constant, so C status codes can be converted to this enum
/// with a static_cast.
enum class status {
    /// @copydoc dnnl_success
    success = dnnl_success,
    /// @copydoc dnnl_out_of_memory
    out_of_memory = dnnl_out_of_memory,
    /// @copydoc dnnl_invalid_arguments
    invalid_arguments = dnnl_invalid_arguments,
    /// @copydoc dnnl_unimplemented
    unimplemented = dnnl_unimplemented,
    /// @copydoc dnnl_last_impl_reached
    last_impl_reached = dnnl_last_impl_reached,
    /// @copydoc dnnl_runtime_error
    runtime_error = dnnl_runtime_error,
    /// @copydoc dnnl_not_required
    not_required = dnnl_not_required,
};
12265 | |
12266 | /// @copydoc dnnl_set_verbose() |
12267 | inline status set_verbose(int level) { |
12268 | return static_cast<status>(dnnl_set_verbose(level)); |
12269 | } |
12270 | |
12271 | /// @copydoc dnnl_version() |
12272 | inline const version_t *version() { |
12273 | return dnnl_version(); |
12274 | } |
12275 | |
12276 | /// Returns the floating-point math mode that will be used by default |
12277 | /// for all subsequently created primitives. |
12278 | /// |
12279 | /// @returns Output FP math mode. |
12280 | inline fpmath_mode get_default_fpmath_mode() { |
12281 | dnnl_fpmath_mode_t mode; |
12282 | error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode), |
12283 | "could not get a default fpmath mode" ); |
12284 | return static_cast<fpmath_mode>(mode); |
12285 | } |
12286 | |
12287 | /// @copydoc dnnl_set_default_fpmath_mode() |
12288 | inline status set_default_fpmath_mode(fpmath_mode mode) { |
12289 | return static_cast<status>( |
12290 | dnnl_set_default_fpmath_mode(convert_to_c(mode))); |
12291 | } |
12292 | |
12293 | /// @copydoc dnnl_set_jit_dump() |
12294 | inline status set_jit_dump(int enable) { |
12295 | return static_cast<status>(dnnl_set_jit_dump(enable)); |
12296 | } |
12297 | |
12298 | /// @copydoc dnnl_set_jit_profiling_flags() |
12299 | inline status set_jit_profiling_flags(unsigned flags) { |
12300 | return static_cast<status>(dnnl_set_jit_profiling_flags(flags)); |
12301 | } |
12302 | |
12303 | /// @copydoc dnnl_set_jit_profiling_jitdumpdir() |
12304 | inline status set_jit_profiling_jitdumpdir(const std::string &dir) { |
12305 | return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str())); |
12306 | } |
12307 | |
// Every enumerator takes its value verbatim from dnnl_cpu_isa_t, which is
// what lets set_max_cpu_isa() and get_effective_cpu_isa() convert between
// the C and C++ types with a plain static_cast.
/// @copydoc dnnl_cpu_isa_t
enum class cpu_isa {
    /// @copydoc dnnl_cpu_isa_default
    isa_default = dnnl_cpu_isa_default,
    /// @copydoc dnnl_cpu_isa_sse41
    sse41 = dnnl_cpu_isa_sse41,
    /// @copydoc dnnl_cpu_isa_avx
    avx = dnnl_cpu_isa_avx,
    /// @copydoc dnnl_cpu_isa_avx2
    avx2 = dnnl_cpu_isa_avx2,
    /// @copydoc dnnl_cpu_isa_avx2_vnni
    avx2_vnni = dnnl_cpu_isa_avx2_vnni,
    /// @copydoc dnnl_cpu_isa_avx2_vnni_2
    avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
    /// @copydoc dnnl_cpu_isa_avx512_core
    avx512_core = dnnl_cpu_isa_avx512_core,
    /// @copydoc dnnl_cpu_isa_avx512_core_vnni
    avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
    /// @copydoc dnnl_cpu_isa_avx512_core_bf16
    avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
    /// @copydoc dnnl_cpu_isa_avx512_core_fp16
    avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx
    avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
    avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
};
12335 | |
12336 | /// @copydoc dnnl_set_max_cpu_isa() |
12337 | inline status set_max_cpu_isa(cpu_isa isa) { |
12338 | return static_cast<status>( |
12339 | dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa))); |
12340 | } |
12341 | |
12342 | /// @copydoc dnnl_get_effective_cpu_isa() |
12343 | inline cpu_isa get_effective_cpu_isa() { |
12344 | return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa()); |
12345 | } |
12346 | |
// Enumerator values mirror dnnl_cpu_isa_hints_t one-to-one, so
// set_cpu_isa_hints() and get_cpu_isa_hints() convert with static_cast.
/// @copydoc dnnl_cpu_isa_hints_t
enum class cpu_isa_hints {
    /// @copydoc dnnl_cpu_isa_no_hints
    no_hints = dnnl_cpu_isa_no_hints,
    /// @copydoc dnnl_cpu_isa_prefer_ymm
    prefer_ymm = dnnl_cpu_isa_prefer_ymm,
};
12354 | |
12355 | /// @copydoc dnnl_set_cpu_isa_hints() |
12356 | inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) { |
12357 | return static_cast<status>(dnnl_set_cpu_isa_hints( |
12358 | static_cast<dnnl_cpu_isa_hints_t>(isa_hints))); |
12359 | } |
12360 | |
12361 | /// @copydoc dnnl_get_cpu_isa_hints() |
12362 | inline cpu_isa_hints get_cpu_isa_hints() { |
12363 | return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints()); |
12364 | } |
12365 | |
12366 | /// @} dnnl_api_service |
12367 | |
12368 | /// @addtogroup dnnl_api_primitive_cache Primitive Cache |
12369 | /// |
12370 | /// A set of functions that provide primitive cache control. |
12371 | /// |
12372 | /// @{ |
12373 | |
12374 | /// Returns the number of primitives that can be held in the primitive cache |
12375 | /// at the same time. |
12376 | inline int get_primitive_cache_capacity() { |
12377 | int result = 0; |
12378 | error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result), |
12379 | "could not get primitive cache capacity" ); |
12380 | return result; |
12381 | } |
12382 | |
12383 | /// @copydoc dnnl_set_primitive_cache_capacity(int capacity) |
12384 | inline void set_primitive_cache_capacity(int capacity) { |
12385 | error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity), |
12386 | "could not set primitive cache capacity" ); |
12387 | } |
12388 | |
12389 | /// @} dnnl_api_primitive_cache |
12390 | |
12391 | /// @addtogroup dnnl_api_blas BLAS functions |
12392 | /// |
12393 | /// A subset of Basic Linear Algebra (BLAS) functions that perform |
12394 | /// matrix-matrix multiplication. |
12395 | /// |
12396 | /// @{ |
12397 | |
12398 | /// @copydoc dnnl_sgemm() |
12399 | inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, |
12400 | dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, |
12401 | const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) { |
12402 | return static_cast<status>(dnnl_sgemm( |
12403 | transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)); |
12404 | } |
12405 | |
12406 | /// @copydoc dnnl_gemm_u8s8s32() |
12407 | inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, |
12408 | dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, |
12409 | dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, |
12410 | float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) { |
12411 | return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N, |
12412 | K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co)); |
12413 | } |
12414 | |
12415 | /// @copydoc dnnl_gemm_s8s8s32() |
12416 | inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, |
12417 | dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, |
12418 | dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, |
12419 | float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) { |
12420 | return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N, |
12421 | K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co)); |
12422 | } |
12423 | |
12424 | /// @} dnnl_api_blas |
12425 | |
12426 | // implementation section |
12427 | |
12428 | /// @cond DO_NOT_DOCUMENT_THIS |
12429 | inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) { |
12430 | dnnl_primitive_t result; |
12431 | error::wrap_c_api(dnnl_primitive_create(&result, c_pd), |
12432 | "could not create a primitive" ); |
12433 | reset(result); |
12434 | } |
12435 | |
12436 | inline primitive::primitive(const_dnnl_primitive_desc_t c_pd, |
12437 | const std::vector<uint8_t> &cache_blob) { |
12438 | dnnl_primitive_t result; |
12439 | size_t size = cache_blob.size(); |
12440 | const uint8_t *cache_blob_data = cache_blob.data(); |
12441 | error::wrap_c_api(dnnl_primitive_create_from_cache_blob( |
12442 | &result, c_pd, size, cache_blob_data), |
12443 | "could not create a primitive from a cache blob" ); |
12444 | reset(result); |
12445 | } |
12446 | |
12447 | inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {} |
12448 | inline primitive::primitive( |
12449 | const primitive_desc &pd, const std::vector<uint8_t> &cache_blob) |
12450 | : primitive(pd.get(), cache_blob) {} |
12451 | |
12452 | inline void primitive::execute(const stream &astream, |
12453 | const std::unordered_map<int, memory> &args) const { |
12454 | std::vector<dnnl_exec_arg_t> c_args; |
12455 | c_args.reserve(args.size()); |
12456 | for (const auto &a : args) |
12457 | c_args.push_back({a.first, a.second.get(true)}); |
12458 | |
12459 | error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(), |
12460 | (int)c_args.size(), c_args.data()), |
12461 | "could not execute a primitive" ); |
12462 | } |
12463 | |
12464 | /// @endcond |
12465 | |
// Drop the internal DNNL_DEFINE_BITMASK_OPS helper macro (presumably defined
// earlier in this header; not visible in this chunk) so it is no longer
// defined for code that includes this header.
#undef DNNL_DEFINE_BITMASK_OPS
12467 | |
12468 | } // namespace dnnl |
12469 | |
/// The oneAPI namespace.
///
/// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace,
/// so the same API can be spelled either `dnnl::` or `oneapi::dnnl::`.
namespace oneapi {
// Note: without this guard, doxygen warns of potentially recursive namespace
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/// oneDNN alias namespace
namespace dnnl = ::dnnl;
#endif
} // namespace oneapi
12481 | |
12482 | /// @} dnnl_api |
12483 | |
12484 | #endif /* ONEAPI_DNNL_DNNL_HPP */ |
12485 | |