1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_OPDESC_HPP
18#define COMMON_OPDESC_HPP
19
20#include <vector>
21
22#include "common/c_types_map.hpp"
23#include "common/gemm_types.hpp"
24
25namespace dnnl {
26namespace impl {
27
// A descriptor of a reorder operation.
//
// Memory descriptors are referenced through raw pointers and are not copied
// into this struct.
// NOTE(review): the pointees presumably must outlive the descriptor --
// confirm against the code that fills it in.
struct reorder_desc_t {
    // The kind of primitive. Presumably always #dnnl_reorder -- TODO confirm.
    primitive_kind_t primitive_kind;
    // Source memory descriptor.
    const memory_desc_t *src_md;
    // Destination memory descriptor.
    const memory_desc_t *dst_md;
    // Engine kind associated with the source (presumably the engine that
    // holds the source memory -- confirm).
    engine_kind_t src_engine_kind;
    // Engine kind associated with the destination (presumably the engine
    // that holds the destination memory -- confirm).
    engine_kind_t dst_engine_kind;
    // Whether the reorder crosses engines (as named; how it is derived from
    // the engine kinds is not visible here -- confirm).
    bool is_cross_engine;
};
36
37struct concat_desc_t {
38 concat_desc_t() = default;
39 concat_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md,
40 dim_t n, dim_t concat_dimension,
41 const memory_desc_t *const *src_mds)
42 : primitive_kind(primitive_kind)
43 , dst_md(dst_md)
44 , n(n)
45 , concat_dimension(concat_dimension) {
46 for (dim_t i = 0; i < n; i++)
47 this->src_mds.push_back(src_mds[i]);
48 }
49
50 primitive_kind_t primitive_kind;
51 const memory_desc_t *dst_md;
52 dim_t n;
53 dim_t concat_dimension;
54 std::vector<const memory_desc_t *> src_mds;
55};
56
57struct sum_desc_t {
58 sum_desc_t() = default;
59 sum_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md,
60 dim_t n, const float *scales, const memory_desc_t *const *src_mds)
61 : primitive_kind(primitive_kind), dst_md(dst_md), n(n), scales(scales) {
62 for (dim_t i = 0; i < n; i++)
63 this->src_mds.push_back(src_mds[i]);
64 }
65
66 primitive_kind_t primitive_kind;
67 const memory_desc_t *dst_md;
68 dim_t n;
69 const float *scales;
70 std::vector<const memory_desc_t *> src_mds;
71};
72
// A descriptor of a zero-pad operation. It carries only the primitive kind;
// presumably the memory to pad is supplied elsewhere (e.g. at execution) --
// confirm.
struct zero_pad_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor.
    primitive_kind_t primitive_kind;
};
76
// A descriptor of an inner product operation.
struct inner_product_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_inner_product.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward_data,
    // #dnnl_backward_weights, and #dnnl_backward_bias.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Weights gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Bias gradient memory descriptor.
    memory_desc_t diff_bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // The accumulator data type.
    data_type_t accum_data_type;
};
104
// A descriptor of a convolution operation. Also used for deconvolution via
// the deconvolution_desc_t alias.
struct convolution_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_convolution.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward_data,
    // #dnnl_backward_weights, and #dnnl_backward_bias.
    prop_kind_t prop_kind;
    // The kind of the convolution algorithm. Possible values:
    // #dnnl_convolution_direct.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Weights gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Bias gradient memory descriptor.
    memory_desc_t diff_bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Convolution strides in each spatial dimension.
    dims_t strides;
    // Convolution dilations in each spatial dimension.
    dims_t dilates;
    // Padding in each spatial dimension. padding[0] is the padding at the
    // beginning (@p padding_l), padding[1] is the padding at the end
    // (@p padding_r).
    dims_t padding[2];
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
};
143
// A descriptor of a deconvolution operation. It shares the convolution
// descriptor layout; the primitive_kind field distinguishes the two
// (presumably #dnnl_deconvolution for deconvolution -- confirm).
using deconvolution_desc_t = convolution_desc_t;
146
// A descriptor of a shuffle operation.
struct shuffle_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_shuffle.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source or source gradient memory descriptor.
    memory_desc_t src_desc;
    // Destination or destination gradient memory descriptor.
    memory_desc_t dst_desc;
    // Axis for shuffling.
    int axis;
    // Number of groups.
    // NOTE(review): the field is named group_size, which reads as "elements
    // per group" rather than "number of groups" -- confirm which meaning the
    // implementations use.
    dim_t group_size;
};
164
// A descriptor of a resampling operation.
struct resampling_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_resampling.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of the resampling algorithm. Possible values:
    // #dnnl_resampling_nearest, #dnnl_resampling_linear.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Resampling factor in each spatial dimension (fixed-size array of
    // DNNL_MAX_NDIMS entries; presumably only the spatial entries are used
    // -- confirm).
    float factors[DNNL_MAX_NDIMS];
};
187
// A descriptor of a matrix multiplication operation.
//
// 2D case:
// dst[m, n] = src[m, k] * weights[k, n] + bias[m, n]
//
// 3D case:
// dst[mb, m, n] = src[mb, m, k] * weights[mb, k, n] + bias[mb, m, n]
//
// Unlike most descriptors here, matmul has no prop_kind field.
struct matmul_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_matmul.
    primitive_kind_t primitive_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
};
210
// A descriptor of an element-wise operation.
struct eltwise_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_eltwise.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of eltwise algorithm. Possible values: #dnnl_eltwise_relu,
    // #dnnl_eltwise_tanh, #dnnl_eltwise_elu, #dnnl_eltwise_square,
    // #dnnl_eltwise_abs, #dnnl_eltwise_sqrt, #dnnl_eltwise_linear,
    // #dnnl_eltwise_soft_relu, #dnnl_eltwise_logistic, #dnnl_eltwise_exp,
    // #dnnl_eltwise_gelu_tanh, #dnnl_eltwise_swish, #dnnl_eltwise_log,
    // #dnnl_eltwise_clip, #dnnl_eltwise_clip_v2, #dnnl_eltwise_pow,
    // #dnnl_eltwise_gelu_erf, #dnnl_eltwise_round,
    // #dnnl_eltwise_mish, #dnnl_eltwise_hardswish, #dnnl_eltwise_hardsigmoid.
    // Possible values for passing destination memory on backward:
    // #dnnl_eltwise_relu_use_dst_for_bwd, #dnnl_eltwise_tanh_use_dst_for_bwd,
    // #dnnl_eltwise_elu_use_dst_for_bwd, #dnnl_eltwise_sqrt_use_dst_for_bwd,
    // #dnnl_eltwise_logistic_use_dst_for_bwd,
    // #dnnl_eltwise_exp_use_dst_for_bwd,
    // #dnnl_eltwise_clip_v2_use_dst_for_bwd.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Algorithm-specific parameters.
    // Accordance table:
    //  - #dnnl_eltwise_relu: @p alpha -- negative slope, @p beta ignored
    //  - #dnnl_eltwise_tanh: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_elu: @p alpha -- negative slope, @p beta ignored
    //  - #dnnl_eltwise_square: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_abs: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_sqrt: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_linear: @p alpha -- scale, @p beta -- shift
    //  - #dnnl_eltwise_soft_relu: @p alpha -- soft_relu arg scaling, @p beta ignored
    //  - #dnnl_eltwise_logistic: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_exp: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_gelu_tanh: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_swish: @p alpha -- sigmoid arg scaling, @p beta ignored
    //  - #dnnl_eltwise_log: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_clip: @p alpha -- lower bound, @p beta -- upper bound
    //  - #dnnl_eltwise_clip_v2: @p alpha -- lower bound, @p beta -- upper bound
    //  - #dnnl_eltwise_pow: @p alpha -- scale, @p beta -- exponent
    //  - #dnnl_eltwise_gelu_erf: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_round: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_mish: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_hardswish: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_hardsigmoid: @p alpha -- scale, @p beta -- shift
    float alpha, beta;
};
267
// A descriptor of a Batch Normalization operation.
struct batch_normalization_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_batch_normalization.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Scale and/or shift data and gradient memory descriptors.
    // Scaleshift memory descriptors use 1D #dnnl_x format[Channels].
    memory_desc_t scaleshift_desc;
    memory_desc_t diff_scaleshift_desc;
    // Statistics memory descriptor.
    //
    // Statistics (mean or variance) descriptor uses 1D #dnnl_x format[Channels].
    memory_desc_t stat_desc;
    // Batch normalization epsilon parameter.
    float batch_norm_epsilon;
    // Normalization flags, stored as a bit set (presumably values of
    // dnnl_normalization_flags_t -- confirm).
    unsigned flags;
};
296
// A descriptor of a Layer Normalization operation.
//
// Note that the destination descriptors come last (after flags), unlike in
// most other descriptors here. Field order defines the layout, so keep it.
struct layer_normalization_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_layer_normalization.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Scale and shift data and gradient memory descriptors.
    // Scaleshift memory descriptor uses 1D #dnnl_x format[normalized_dim].
    // Normalized_dim is equal to the last logical dimension of the source
    // tensor across which normalization is performed.
    memory_desc_t data_scaleshift_desc;
    memory_desc_t diff_data_scaleshift_desc;
    // Mean and variance data memory descriptors.
    //
    // Statistics (mean and variance) memory descriptor is the k-dimensional
    // tensor where k is equal to data_tensor_ndims - 1 and may have any plain
    // (stride[last_dim] == 1) user-provided format.
    memory_desc_t stat_desc;
    // Layer normalization epsilon parameter.
    float layer_norm_epsilon;
    // Normalization flags, stored as a bit set (presumably values of
    // dnnl_normalization_flags_t -- confirm).
    unsigned flags;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
329
// A descriptor of a Local Response Normalization (LRN) operation.
struct lrn_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_lrn.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // LRN algorithm. Possible values: #dnnl_lrn_within_channel and
    // #dnnl_lrn_across_channels.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // The number of channels to sum over (for cross-channel LRN) or the side
    // length of the square region to sum over (for within-channel LRN).
    dim_t local_size;
    // LRN alpha parameter.
    float lrn_alpha;
    // LRN beta parameter.
    float lrn_beta;
    // LRN k parameter.
    float lrn_k;
};
359
// A descriptor of a reduction operation. Reduction has no prop_kind field.
struct reduction_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_reduction.
    primitive_kind_t primitive_kind;
    // The kind of reduction algorithm. Possible values:
    // #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
    // #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
    // #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
    // #dnnl_reduction_norm_lp_power_p_sum.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Algorithm-specific parameters.
    // Accordance table:
    //  #dnnl_reduction_max: @p p and @p eps are ignored
    //  #dnnl_reduction_min: @p p and @p eps are ignored
    //  #dnnl_reduction_norm_lp_max: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_sum: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_power_p_max: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_power_p_sum: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_sum: @p p and @p eps are ignored
    //  #dnnl_reduction_mul: @p p and @p eps are ignored
    //  #dnnl_reduction_mean: @p p and @p eps are ignored
    float p, eps;
};
388
// A descriptor of a Softmax operation.
//
// Note that the destination descriptors come after softmax_axis and
// alg_kind. Field order defines the layout, so keep it.
struct softmax_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_softmax.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // The axis along which to perform the softmax.
    int softmax_axis;
    // Softmax algorithm. Possible values: #dnnl_softmax_accurate and
    // #dnnl_softmax_log.
    alg_kind_t alg_kind;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
411
// A descriptor of a binary operation. Binary has no prop_kind field.
struct binary_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_binary.
    primitive_kind_t primitive_kind;
    // The kind of the binary algorithm. Possible values:
    // #dnnl_binary_add, #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min,
    // #dnnl_binary_div and #dnnl_binary_sub.
    alg_kind_t alg_kind;
    // Source memory descriptors: src_desc[0] is the first operand and
    // src_desc[1] is the second operand.
    memory_desc_t src_desc[2];
    // Destination memory descriptor.
    memory_desc_t dst_desc;
};
426
// A descriptor of a PReLU operation.
struct prelu_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_prelu.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Learnable parameter alpha memory descriptor.
    // Alpha describes the negative slope.
    memory_desc_t weights_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Learnable parameter alpha gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
449
// A descriptor of a pooling operation.
//
// Note that dilation comes last (after accum_data_type). Field order
// defines the layout, so keep it.
struct pooling_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_pooling.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of pooling algorithm.
    // Possible values: #dnnl_pooling_max,
    // #dnnl_pooling_avg_include_padding, and
    // #dnnl_pooling_avg_exclude_padding.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Pooling kernel strides for spatial dimensions.
    dims_t strides;
    // Pooling kernel spatial dimensions.
    dims_t kernel;
    // Padding in each spatial dimension. padding[0] is the padding at the
    // beginning (@p padding_l), padding[1] is the padding at the end
    // (@p padding_r).
    dims_t padding[2];
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
    // Pooling dilations for spatial dimensions.
    dims_t dilation;
};
484
485// A descriptor for an RNN operation.
486struct rnn_desc_t {
487 // The kind of primitive. Used for self-identifying the primitive
488 // descriptor. Must be #dnnl_rnn.
489 dnnl_primitive_kind_t primitive_kind;
490 // The kind of propagation. Possible values: #dnnl_forward_training,
491 // #dnnl_forward_inference, and #dnnl_backward.
492 prop_kind_t prop_kind;
493 // RNN cell kind. Must be one of #dnnl_vanilla_rnn,
494 // #dnnl_vanilla_lstm, #dnnl_vanilla_gru, or #dnnl_lbr_gru.
495 alg_kind_t cell_kind;
496 // The direction of RNN primitive execution.
497 rnn_direction_t direction;
498 // Source layer memory descriptor.
499 memory_desc_t src_layer_desc;
500 // Source iteration memory descriptor for hidden state.
501 memory_desc_t src_iter_desc;
502 // Source iteration memory descriptor for cell state.
503 memory_desc_t src_iter_c_desc;
504 // Weights layer memory descriptor.
505 memory_desc_t weights_layer_desc;
506 // Weights iteration memory descriptor.
507 memory_desc_t weights_iter_desc;
508 // Bias memory descriptor.
509 memory_desc_t bias_desc;
510 // Destination layer memory descriptor.
511 memory_desc_t dst_layer_desc;
512 // Destination iter memory descriptor for hidden state.
513 memory_desc_t dst_iter_desc;
514 // Destination iter memory descriptor for cell state.
515 memory_desc_t dst_iter_c_desc;
516 // Weights peephole memory descriptor.
517 // This memory descriptor is equal to zero memory descriptor in case of
518 // non-peephole LSTMs and other non-LSTM RNNs.
519 memory_desc_t weights_peephole_desc;
520 // Weights projection memory descriptor.
521 // This memory descriptor is equal to zero memory descriptor in case of
522 // non-projection LSTMs and other non-LSTM RNNs.
523 memory_desc_t weights_projection_desc;
524
525 // Source gradient layer memory descriptor.
526 memory_desc_t diff_src_layer_desc;
527 // Source gradient iter memory descriptor for hidden state.
528 memory_desc_t diff_src_iter_desc;
529 // Source gradient iter memory descriptor for cell state.
530 memory_desc_t diff_src_iter_c_desc;
531 // Weights gradient layer memory descriptor.
532 memory_desc_t diff_weights_layer_desc;
533 // Weights gradient iter memory descriptor.
534 memory_desc_t diff_weights_iter_desc;
535 // Bias gradient memory descriptor.
536 memory_desc_t diff_bias_desc;
537 // Destination gradient layer memory descriptor.
538 memory_desc_t diff_dst_layer_desc;
539 // Destination gradient iteration memory descriptor for hidden state.
540 memory_desc_t diff_dst_iter_desc;
541 // Destination gradient iteration memory descriptor for cell state.
542 memory_desc_t diff_dst_iter_c_desc;
543 // Weights gradient peephole memory descriptor.
544 // This memory descriptor is equal to zero memory descriptor in case of
545 // non-peephole LSTMs and other non-LSTM RNNs.
546 memory_desc_t diff_weights_peephole_desc;
547 // Weights gradient projection memory descriptor.
548 // This memory descriptor is equal to zero memory descriptor in case of
549 // non-projection LSTMs and other non-LSTM RNNs.
550 memory_desc_t diff_weights_projection_desc;
551
552 // RNN cell flags
553 unsigned int flags;
554 // Activation function used for vanilla_rnn cell kind.
555 // Must be either #dnnl_eltwise_relu or #dnnl_eltwise_tanh.
556 alg_kind_t activation_kind;
557 float alpha;
558 float beta;
559};
560
// A type-erased view over any operation descriptor.
//
// The anonymous union overlays every descriptor type. `kind` shares storage
// with the leading field of each member. Every descriptor defined in this
// file starts with a primitive-kind field; gemm_desc_t is declared in
// gemm_types.hpp and presumably follows the same rule -- confirm.
// NOTE(review): reading `kind` while another member is active relies on
// that layout compatibility, not on a strict C++ guarantee.
//
// op_desc_t objects are never created directly. The converting constructors
// and the destructor are deleted, so code only handles op_desc_t through
// pointers produced by convert_from_c().
struct op_desc_t {
    union {
        primitive_kind_t kind;
        convolution_desc_t convolution;
        deconvolution_desc_t deconvolution;
        shuffle_desc_t shuffle;
        pooling_desc_t pooling;
        prelu_desc_t prelu;
        eltwise_desc_t eltwise;
        softmax_desc_t softmax;
        lrn_desc_t lrn;
        batch_normalization_desc_t batch_normalization;
        layer_normalization_desc_t layer_normalization;
        inner_product_desc_t inner_product;
        rnn_desc_t rnn;
        gemm_desc_t gemm;
        concat_desc_t concat;
        reorder_desc_t reorder;
        sum_desc_t sum;
        binary_desc_t binary;
        matmul_desc_t matmul;
        resampling_desc_t resampling;
        zero_pad_desc_t zero_pad;
        reduction_desc_t reduction;
    };

// For a concrete descriptor type c_type, this macro does two things:
//  - deletes construction of an op_desc_t from a c_type value;
//  - adds convert_from_c() overloads that reinterpret a (const) c_type
//    pointer as a (const) op_desc_t pointer, for type-erased handling.
// deconvolution_desc_t needs no separate entry: it is the same type as
// convolution_desc_t.
#define DECL_CTOR_AND_CONVERTERS(c_type) \
    op_desc_t(const c_type &) = delete; \
    static op_desc_t *convert_from_c(c_type *_) { \
        return reinterpret_cast<op_desc_t *>(_); \
    } \
    static const op_desc_t *convert_from_c(const c_type *_) { \
        return reinterpret_cast<const op_desc_t *>(_); \
    }

    DECL_CTOR_AND_CONVERTERS(convolution_desc_t);
    DECL_CTOR_AND_CONVERTERS(shuffle_desc_t);
    DECL_CTOR_AND_CONVERTERS(pooling_desc_t);
    DECL_CTOR_AND_CONVERTERS(prelu_desc_t);
    DECL_CTOR_AND_CONVERTERS(eltwise_desc_t);
    DECL_CTOR_AND_CONVERTERS(softmax_desc_t);
    DECL_CTOR_AND_CONVERTERS(lrn_desc_t);
    DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t);
    DECL_CTOR_AND_CONVERTERS(layer_normalization_desc_t);
    DECL_CTOR_AND_CONVERTERS(inner_product_desc_t);
    DECL_CTOR_AND_CONVERTERS(rnn_desc_t);
    DECL_CTOR_AND_CONVERTERS(gemm_desc_t);
    DECL_CTOR_AND_CONVERTERS(concat_desc_t);
    DECL_CTOR_AND_CONVERTERS(reorder_desc_t);
    DECL_CTOR_AND_CONVERTERS(sum_desc_t);
    DECL_CTOR_AND_CONVERTERS(binary_desc_t);
    DECL_CTOR_AND_CONVERTERS(matmul_desc_t);
    DECL_CTOR_AND_CONVERTERS(resampling_desc_t);
    DECL_CTOR_AND_CONVERTERS(zero_pad_desc_t);
    DECL_CTOR_AND_CONVERTERS(reduction_desc_t);

    // concat_desc_t and sum_desc_t have std::vector members, whose special
    // member functions are non-trivial. As a result, the destructor of the
    // union (and thus of op_desc_t) is implicitly deleted by the compiler.
    // That triggers a warning on Windows, so the destructor is deleted
    // explicitly here.
    ~op_desc_t() = delete;

#undef DECL_CTOR_AND_CONVERTERS
};
625
626} // namespace impl
627} // namespace dnnl
628
629#endif
630