1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_OPDESC_HPP |
18 | #define COMMON_OPDESC_HPP |
19 | |
20 | #include <vector> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/gemm_types.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | |
// A descriptor of a reorder operation.
//
// The memory descriptors are held through raw pointers. This struct neither
// copies nor frees the pointees, so they must remain valid for as long as
// the descriptor is used.
struct reorder_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor.
    primitive_kind_t primitive_kind;
    // Source memory descriptor (non-owning).
    const memory_desc_t *src_md;
    // Destination memory descriptor (non-owning).
    const memory_desc_t *dst_md;
    // Engine kind on the source side.
    engine_kind_t src_engine_kind;
    // Engine kind on the destination side.
    engine_kind_t dst_engine_kind;
    // NOTE(review): presumably true when source and destination live on
    // different engines — confirm against the code that fills this field.
    bool is_cross_engine;
};
36 | |
37 | struct concat_desc_t { |
38 | concat_desc_t() = default; |
39 | concat_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md, |
40 | dim_t n, dim_t concat_dimension, |
41 | const memory_desc_t *const *src_mds) |
42 | : primitive_kind(primitive_kind) |
43 | , dst_md(dst_md) |
44 | , n(n) |
45 | , concat_dimension(concat_dimension) { |
46 | for (dim_t i = 0; i < n; i++) |
47 | this->src_mds.push_back(src_mds[i]); |
48 | } |
49 | |
50 | primitive_kind_t primitive_kind; |
51 | const memory_desc_t *dst_md; |
52 | dim_t n; |
53 | dim_t concat_dimension; |
54 | std::vector<const memory_desc_t *> src_mds; |
55 | }; |
56 | |
57 | struct sum_desc_t { |
58 | sum_desc_t() = default; |
59 | sum_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md, |
60 | dim_t n, const float *scales, const memory_desc_t *const *src_mds) |
61 | : primitive_kind(primitive_kind), dst_md(dst_md), n(n), scales(scales) { |
62 | for (dim_t i = 0; i < n; i++) |
63 | this->src_mds.push_back(src_mds[i]); |
64 | } |
65 | |
66 | primitive_kind_t primitive_kind; |
67 | const memory_desc_t *dst_md; |
68 | dim_t n; |
69 | const float *scales; |
70 | std::vector<const memory_desc_t *> src_mds; |
71 | }; |
72 | |
// A descriptor of a zero-padding operation. It carries nothing but the
// primitive kind, which occupies the same leading position as the `kind`
// member of op_desc_t.
struct zero_pad_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor.
    primitive_kind_t primitive_kind;
};
76 | |
// A descriptor of an inner product operation.
struct inner_product_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_inner_product.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward_data,
    // #dnnl_backward_weights, and #dnnl_backward_bias.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Weights gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Bias gradient memory descriptor.
    memory_desc_t diff_bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // The accumulator data type.
    data_type_t accum_data_type;
};
104 | |
// A descriptor of a convolution operation (also used for deconvolution via
// the deconvolution_desc_t alias).
struct convolution_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_convolution.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward_data,
    // #dnnl_backward_weights, and #dnnl_backward_bias.
    prop_kind_t prop_kind;
    // The kind of the convolution algorithm. Possible values:
    // #dnnl_convolution_direct.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Weights gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Bias gradient memory descriptor.
    memory_desc_t diff_bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Convolution strides in each spatial dimension.
    dims_t strides;
    // Convolution dilations in each spatial dimension.
    dims_t dilates;
    // Padding in each spatial dimension. padding[0] is the padding at the
    // beginning (@p padding_l), padding[1] is the padding at the end
    // (@p padding_r).
    dims_t padding[2];
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
};
143 | |
// A descriptor of a deconvolution operation. Deconvolution shares the
// convolution descriptor type as-is; the two are distinguished by the
// primitive_kind field.
using deconvolution_desc_t = convolution_desc_t;
146 | |
// A descriptor of a shuffle operation.
struct shuffle_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_shuffle.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source or source gradient memory descriptor.
    memory_desc_t src_desc;
    // Destination or destination gradient memory descriptor.
    memory_desc_t dst_desc;
    // Axis for shuffling.
    int axis;
    // Shuffle group parameter. NOTE(review): previously documented as
    // "Number of groups", but the field name says group *size* — confirm
    // which meaning the implementations use.
    dim_t group_size;
};
164 | |
// A descriptor of a resampling operation.
struct resampling_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_resampling.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of the resampling algorithm. Possible values:
    // #dnnl_resampling_nearest, #dnnl_resampling_linear.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Resampling factor in each spatial dimension (room for up to
    // DNNL_MAX_NDIMS entries).
    float factors[DNNL_MAX_NDIMS];
};
187 | |
// A descriptor of a matrix multiplication operation.
//
// 2D case:
//     dst[m, n] = src[m, k] * weights[k, n] + bias[m, n]
//
// 3D case:
//     dst[mb, m, n] = src[mb, m, k] * weights[mb, k, n] + bias[mb, m, n]
//
// Unlike most other descriptors, matmul carries no propagation kind and no
// gradient memory descriptors.
struct matmul_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_matmul.
    primitive_kind_t primitive_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Weights memory descriptor.
    memory_desc_t weights_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
};
210 | |
// A descriptor of an element-wise operation.
struct eltwise_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_eltwise.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of eltwise algorithm. Possible values: #dnnl_eltwise_relu,
    // #dnnl_eltwise_tanh, #dnnl_eltwise_elu, #dnnl_eltwise_square,
    // #dnnl_eltwise_abs, #dnnl_eltwise_sqrt, #dnnl_eltwise_linear,
    // #dnnl_eltwise_soft_relu, #dnnl_eltwise_logistic, #dnnl_eltwise_exp,
    // #dnnl_eltwise_gelu_tanh, #dnnl_eltwise_swish, #dnnl_eltwise_log,
    // #dnnl_eltwise_clip, #dnnl_eltwise_clip_v2, #dnnl_eltwise_pow,
    // #dnnl_eltwise_gelu_erf, #dnnl_eltwise_round,
    // #dnnl_eltwise_mish, #dnnl_eltwise_hardswish, #dnnl_eltwise_hardsigmoid.
    // Possible values for passing destination memory on backward:
    // #dnnl_eltwise_relu_use_dst_for_bwd, #dnnl_eltwise_tanh_use_dst_for_bwd,
    // #dnnl_eltwise_elu_use_dst_for_bwd, #dnnl_eltwise_sqrt_use_dst_for_bwd,
    // #dnnl_eltwise_logistic_use_dst_for_bwd,
    // #dnnl_eltwise_exp_use_dst_for_bwd,
    // #dnnl_eltwise_clip_v2_use_dst_for_bwd.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Algorithm-specific parameters.
    // Accordance table:
    //  - #dnnl_eltwise_relu: @p alpha -- negative slope, @p beta ignored
    //  - #dnnl_eltwise_tanh: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_elu: @p alpha -- negative slope, @p beta ignored
    //  - #dnnl_eltwise_square: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_abs: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_sqrt: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_linear: @p alpha -- scale, @p beta -- shift
    //  - #dnnl_eltwise_soft_relu: @p alpha -- soft_relu arg scaling,
    //    @p beta ignored
    //  - #dnnl_eltwise_logistic: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_exp: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_gelu_tanh: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_swish: @p alpha -- sigmoid arg scaling, @p beta ignored
    //  - #dnnl_eltwise_log: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_clip: @p alpha -- lower bound, @p beta -- upper bound
    //  - #dnnl_eltwise_clip_v2: @p alpha -- lower bound, @p beta -- upper bound
    //  - #dnnl_eltwise_pow: @p alpha -- scale, @p beta -- exponent
    //  - #dnnl_eltwise_gelu_erf: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_round: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_mish: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_hardswish: @p alpha and @p beta ignored
    //  - #dnnl_eltwise_hardsigmoid: @p alpha -- scale, @p beta -- shift
    float alpha, beta;
};
267 | |
// A descriptor of a Batch Normalization operation.
struct batch_normalization_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_batch_normalization.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Scale and/or shift data memory descriptor.
    // Scaleshift memory descriptor uses 1D #dnnl_x format[Channels].
    memory_desc_t scaleshift_desc;
    // Scale and/or shift gradient memory descriptor.
    memory_desc_t diff_scaleshift_desc;
    // Statistics memory descriptor.
    //
    // Statistics (mean or variance) descriptor uses 1D #dnnl_x
    // format[Channels].
    memory_desc_t stat_desc;
    // Batch normalization epsilon parameter.
    float batch_norm_epsilon;
    // Normalization flags. NOTE(review): presumably a bitwise OR of the
    // library's normalization flag values — confirm.
    unsigned flags;
};
296 | |
// A descriptor of a Layer Normalization operation.
struct layer_normalization_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_layer_normalization.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Scale and shift data memory descriptor.
    // Scaleshift memory descriptor uses 1D #dnnl_x format[normalized_dim].
    // normalized_dim is equal to the last logical dimension of the source
    // tensor, across which normalization is performed.
    memory_desc_t data_scaleshift_desc;
    // Scale and shift gradient memory descriptor.
    memory_desc_t diff_data_scaleshift_desc;
    // Mean and variance data memory descriptor.
    //
    // The statistics (mean and variance) memory descriptor is a
    // k-dimensional tensor, where k is equal to data_tensor_ndims - 1, and
    // may have any plain (stride[last_dim] == 1) user-provided format.
    memory_desc_t stat_desc;
    // Layer normalization epsilon parameter.
    float layer_norm_epsilon;
    // Normalization flags. NOTE(review): presumably a bitwise OR of the
    // library's normalization flag values — confirm.
    unsigned flags;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
329 | |
// A descriptor of a Local Response Normalization (LRN) operation.
struct lrn_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_lrn.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // LRN algorithm. Possible values: #dnnl_lrn_within_channel and
    // #dnnl_lrn_across_channels.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // The number of channels to sum over (for cross-channel LRN) or the side
    // length of the square region to sum over (for within-channel LRN).
    dim_t local_size;
    // LRN alpha parameter.
    float lrn_alpha;
    // LRN beta parameter.
    float lrn_beta;
    // LRN k parameter.
    float lrn_k;
};
359 | |
// A descriptor of a reduction operation.
struct reduction_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_reduction.
    primitive_kind_t primitive_kind;
    // The kind of reduction algorithm. Possible values:
    // #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
    // #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
    // #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
    // #dnnl_reduction_norm_lp_power_p_sum.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Algorithm-specific parameters.
    // Accordance table:
    //  #dnnl_reduction_max: @p p and @p eps are ignored
    //  #dnnl_reduction_min: @p p and @p eps are ignored
    //  #dnnl_reduction_norm_lp_max: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_sum: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_power_p_max: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_norm_lp_power_p_sum: @p p -- power, @p eps -- epsilon
    //  #dnnl_reduction_sum: @p p and @p eps are ignored
    //  #dnnl_reduction_mul: @p p and @p eps are ignored
    //  #dnnl_reduction_mean: @p p and @p eps are ignored
    float p, eps;
};
388 | |
// A descriptor of a Softmax operation.
struct softmax_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_softmax.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // The axis along which to perform the softmax.
    int softmax_axis;
    // Softmax algorithm. Possible values: #dnnl_softmax_accurate and
    // #dnnl_softmax_log.
    alg_kind_t alg_kind;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
411 | |
// A descriptor of a binary operation.
struct binary_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_binary.
    primitive_kind_t primitive_kind;
    // The kind of the binary algorithm. Possible values:
    // #dnnl_binary_add, #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min,
    // #dnnl_binary_div and #dnnl_binary_sub.
    alg_kind_t alg_kind;
    // Source memory descriptors of the two inputs: src_desc[0] is the first
    // operand, src_desc[1] the second.
    memory_desc_t src_desc[2];
    // Destination memory descriptor.
    memory_desc_t dst_desc;
};
426 | |
// A descriptor of a PReLU operation.
struct prelu_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_prelu.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward.
    prop_kind_t prop_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Learnable parameter alpha memory descriptor.
    // Alpha describes the negative slope.
    memory_desc_t weights_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Learnable parameter alpha gradient memory descriptor.
    memory_desc_t diff_weights_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
};
449 | |
// A descriptor of a pooling operation.
struct pooling_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_pooling.
    primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
    prop_kind_t prop_kind;
    // The kind of pooling algorithm.
    // Possible values: #dnnl_pooling_max,
    // #dnnl_pooling_avg_include_padding, and
    // #dnnl_pooling_avg_exclude_padding.
    alg_kind_t alg_kind;
    // Source memory descriptor.
    memory_desc_t src_desc;
    // Source gradient memory descriptor.
    memory_desc_t diff_src_desc;
    // Destination memory descriptor.
    memory_desc_t dst_desc;
    // Destination gradient memory descriptor.
    memory_desc_t diff_dst_desc;
    // Pooling kernel strides for spatial dimensions.
    dims_t strides;
    // Pooling kernel spatial dimensions.
    dims_t kernel;
    // Padding in each spatial dimension. padding[0] is the padding at the
    // beginning (@p padding_l), padding[1] is the padding at the end
    // (@p padding_r).
    dims_t padding[2];
    // The accumulator data type. Initialized automatically.
    data_type_t accum_data_type;
    // Pooling dilations for spatial dimensions. Declared after
    // accum_data_type rather than next to the other spatial parameters.
    dims_t dilation;
};
484 | |
// A descriptor for an RNN operation.
struct rnn_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_rnn.
    // NOTE(review): this is the only descriptor here that spells the type as
    // dnnl_primitive_kind_t rather than the primitive_kind_t alias used by
    // its siblings and by op_desc_t::kind; presumably the same type — confirm.
    dnnl_primitive_kind_t primitive_kind;
    // The kind of propagation. Possible values: #dnnl_forward_training,
    // #dnnl_forward_inference, and #dnnl_backward.
    prop_kind_t prop_kind;
    // RNN cell kind. Must be one of #dnnl_vanilla_rnn,
    // #dnnl_vanilla_lstm, #dnnl_vanilla_gru, or #dnnl_lbr_gru.
    alg_kind_t cell_kind;
    // The direction of RNN primitive execution.
    rnn_direction_t direction;
    // Source layer memory descriptor.
    memory_desc_t src_layer_desc;
    // Source iteration memory descriptor for hidden state.
    memory_desc_t src_iter_desc;
    // Source iteration memory descriptor for cell state.
    memory_desc_t src_iter_c_desc;
    // Weights layer memory descriptor.
    memory_desc_t weights_layer_desc;
    // Weights iteration memory descriptor.
    memory_desc_t weights_iter_desc;
    // Bias memory descriptor.
    memory_desc_t bias_desc;
    // Destination layer memory descriptor.
    memory_desc_t dst_layer_desc;
    // Destination iteration memory descriptor for hidden state.
    memory_desc_t dst_iter_desc;
    // Destination iteration memory descriptor for cell state.
    memory_desc_t dst_iter_c_desc;
    // Weights peephole memory descriptor.
    // This memory descriptor is equal to the zero memory descriptor in case
    // of non-peephole LSTMs and other non-LSTM RNNs.
    memory_desc_t weights_peephole_desc;
    // Weights projection memory descriptor.
    // This memory descriptor is equal to the zero memory descriptor in case
    // of non-projection LSTMs and other non-LSTM RNNs.
    memory_desc_t weights_projection_desc;

    // Source gradient layer memory descriptor.
    memory_desc_t diff_src_layer_desc;
    // Source gradient iteration memory descriptor for hidden state.
    memory_desc_t diff_src_iter_desc;
    // Source gradient iteration memory descriptor for cell state.
    memory_desc_t diff_src_iter_c_desc;
    // Weights gradient layer memory descriptor.
    memory_desc_t diff_weights_layer_desc;
    // Weights gradient iteration memory descriptor.
    memory_desc_t diff_weights_iter_desc;
    // Bias gradient memory descriptor.
    memory_desc_t diff_bias_desc;
    // Destination gradient layer memory descriptor.
    memory_desc_t diff_dst_layer_desc;
    // Destination gradient iteration memory descriptor for hidden state.
    memory_desc_t diff_dst_iter_desc;
    // Destination gradient iteration memory descriptor for cell state.
    memory_desc_t diff_dst_iter_c_desc;
    // Weights gradient peephole memory descriptor.
    // This memory descriptor is equal to the zero memory descriptor in case
    // of non-peephole LSTMs and other non-LSTM RNNs.
    memory_desc_t diff_weights_peephole_desc;
    // Weights gradient projection memory descriptor.
    // This memory descriptor is equal to the zero memory descriptor in case
    // of non-projection LSTMs and other non-LSTM RNNs.
    memory_desc_t diff_weights_projection_desc;

    // RNN cell flags. NOTE(review): presumably a bitwise OR of RNN flag
    // values — confirm.
    unsigned int flags;
    // Activation function used for vanilla_rnn cell kind.
    // Must be either #dnnl_eltwise_relu or #dnnl_eltwise_tanh.
    alg_kind_t activation_kind;
    // NOTE(review): presumably the alpha/beta parameters of activation_kind,
    // following the eltwise_desc_t convention — confirm.
    float alpha;
    float beta;
};
560 | |
561 | struct op_desc_t { |
562 | union { |
563 | primitive_kind_t kind; |
564 | convolution_desc_t convolution; |
565 | deconvolution_desc_t deconvolution; |
566 | shuffle_desc_t shuffle; |
567 | pooling_desc_t pooling; |
568 | prelu_desc_t prelu; |
569 | eltwise_desc_t eltwise; |
570 | softmax_desc_t softmax; |
571 | lrn_desc_t lrn; |
572 | batch_normalization_desc_t batch_normalization; |
573 | layer_normalization_desc_t layer_normalization; |
574 | inner_product_desc_t inner_product; |
575 | rnn_desc_t rnn; |
576 | gemm_desc_t gemm; |
577 | concat_desc_t concat; |
578 | reorder_desc_t reorder; |
579 | sum_desc_t sum; |
580 | binary_desc_t binary; |
581 | matmul_desc_t matmul; |
582 | resampling_desc_t resampling; |
583 | zero_pad_desc_t zero_pad; |
584 | reduction_desc_t reduction; |
585 | }; |
586 | |
587 | #define DECL_CTOR_AND_CONVERTERS(c_type) \ |
588 | op_desc_t(const c_type &) = delete; \ |
589 | static op_desc_t *convert_from_c(c_type *_) { \ |
590 | return reinterpret_cast<op_desc_t *>(_); \ |
591 | } \ |
592 | static const op_desc_t *convert_from_c(const c_type *_) { \ |
593 | return reinterpret_cast<const op_desc_t *>(_); \ |
594 | } |
595 | |
596 | DECL_CTOR_AND_CONVERTERS(convolution_desc_t); |
597 | DECL_CTOR_AND_CONVERTERS(shuffle_desc_t); |
598 | DECL_CTOR_AND_CONVERTERS(pooling_desc_t); |
599 | DECL_CTOR_AND_CONVERTERS(prelu_desc_t); |
600 | DECL_CTOR_AND_CONVERTERS(eltwise_desc_t); |
601 | DECL_CTOR_AND_CONVERTERS(softmax_desc_t); |
602 | DECL_CTOR_AND_CONVERTERS(lrn_desc_t); |
603 | DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t); |
604 | DECL_CTOR_AND_CONVERTERS(layer_normalization_desc_t); |
605 | DECL_CTOR_AND_CONVERTERS(inner_product_desc_t); |
606 | DECL_CTOR_AND_CONVERTERS(rnn_desc_t); |
607 | DECL_CTOR_AND_CONVERTERS(gemm_desc_t); |
608 | DECL_CTOR_AND_CONVERTERS(concat_desc_t); |
609 | DECL_CTOR_AND_CONVERTERS(reorder_desc_t); |
610 | DECL_CTOR_AND_CONVERTERS(sum_desc_t); |
611 | DECL_CTOR_AND_CONVERTERS(binary_desc_t); |
612 | DECL_CTOR_AND_CONVERTERS(matmul_desc_t); |
613 | DECL_CTOR_AND_CONVERTERS(resampling_desc_t); |
614 | DECL_CTOR_AND_CONVERTERS(zero_pad_desc_t); |
615 | DECL_CTOR_AND_CONVERTERS(reduction_desc_t); |
616 | |
617 | // concat_desc_t and sum_desc_t have data members which have non-trivial |
618 | // special member functions hence the default destructor is implicitly |
619 | // deleted by the compiler which causes a warning on Windows so we should |
620 | // delete the destructor explicitly. |
621 | ~op_desc_t() = delete; |
622 | |
623 | #undef DECL_CTOR_AND_CONVERTERS |
624 | }; |
625 | |
626 | } // namespace impl |
627 | } // namespace dnnl |
628 | |
629 | #endif |
630 | |