1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @file
18/// C API
19
20#ifndef ONEAPI_DNNL_DNNL_H
21#define ONEAPI_DNNL_DNNL_H
22
23#include "oneapi/dnnl/dnnl_config.h"
24#include "oneapi/dnnl/dnnl_types.h"
25#include "oneapi/dnnl/dnnl_version.h"
26
27#ifdef __cplusplus
28extern "C" {
29#endif
30
31/// @addtogroup dnnl_api
32/// @{
33
34/// @addtogroup dnnl_api_primitives
35/// @{
36
37/// @addtogroup dnnl_api_primitives_common
38/// @{
39
/// Creates a primitive descriptor iterator.
///
/// @note
/// The resulting iterator must be destroyed using
/// dnnl_primitive_desc_iterator_destroy() once it is no longer needed.
///
/// @param iterator Output primitive descriptor iterator.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(
        dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc,
        const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
        const_dnnl_primitive_desc_t hint_forward_primitive_desc);

/// Advances the primitive descriptor iterator to point to the next available
/// implementation.
///
/// @param iterator A primitive descriptor iterator to advance.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @returns #dnnl_iterator_ends if no more implementations are available.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(
        dnnl_primitive_desc_iterator_t iterator);

/// Fetches the current primitive descriptor from a primitive descriptor
/// iterator.
///
/// @note
/// The user is responsible for deleting the resulting primitive
/// descriptor using dnnl_primitive_desc_destroy().
///
/// @param iterator A primitive descriptor iterator.
/// @returns A primitive descriptor.
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(
        const_dnnl_primitive_desc_iterator_t iterator);

/// Destroys a primitive descriptor iterator.
///
/// @param iterator Primitive descriptor iterator to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(
        dnnl_primitive_desc_iterator_t iterator);
85
/// Creates a primitive descriptor. This function is equivalent to a sequence
/// of #dnnl_primitive_desc_iterator_create() and
/// #dnnl_primitive_desc_iterator_fetch(). In other words, the library will
/// pick the first suitable implementation.
///
/// @note
/// The resulting primitive descriptor must be destroyed using
/// dnnl_primitive_desc_destroy() once it is no longer needed.
///
/// @param primitive_desc Output primitive descriptor.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, const_dnnl_op_desc_t op_desc,
        const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
        const_dnnl_primitive_desc_t hint_forward_primitive_desc);

/// Clones a primitive descriptor. The resulting primitive descriptor must be
/// destroyed separately.
///
/// @param primitive_desc Output primitive descriptor.
/// @param existing_primitive_desc Primitive descriptor to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
        dnnl_primitive_desc_t *primitive_desc,
        const_dnnl_primitive_desc_t existing_primitive_desc);

/// Returns a constant reference to the attributes of a primitive descriptor.
///
/// @warning
/// It is an error to destroy the resulting @p attr.
///
/// @warning
/// The lifetime of an @p attr is the same as that of a @p
/// primitive_desc, so it is an error to use the @p attr once the @p
/// primitive_desc has been destroyed.
///
/// @param primitive_desc Primitive descriptor.
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
        const_dnnl_primitive_desc_t primitive_desc,
        const_dnnl_primitive_attr_t *attr);

/// Destroys a primitive descriptor.
///
/// @param primitive_desc Primitive descriptor to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
        dnnl_primitive_desc_t primitive_desc);
141
/// Queries a primitive descriptor for various pieces of information.
///
/// The most common use case is to query a primitive descriptor, created with
/// source, weights, and destination memory descriptors with format tags set
/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
/// #dnnl_query_dst_md respectively) so that it is possible to create memory
/// objects and reorder primitives if necessary.
///
/// Another typical use case is to query a primitive descriptor for workspace
/// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
/// query returns #dnnl_not_required status, then workspace memory is not
/// required.
///
/// @note
/// When querying for a memory descriptor for a scratchpad, a workspace,
/// or an optional parameter, the query will return a pointer to a zero
/// memory descriptor if the parameter is not needed.
///
/// A few other use cases:
/// - query a primitive descriptor for the underlying operation descriptor
/// (#dnnl_query_convolution_d, #dnnl_query_eltwise_d, #dnnl_query_rnn_d,
/// etc.)
/// - query a primitive descriptor for the implementation information string
/// (#dnnl_query_impl_info_str)
/// - query a primitive descriptor for the number of inputs and outputs
/// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
/// respectively)
///
/// @sa dnnl_query_t for more options
///
/// @param primitive_desc Primitive descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for.
/// @param result Output result. The type depends on the query. For example,
/// it must be a @c dnnl_memory_desc_t* if querying for a memory
/// descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_query(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index, void *result);

/// Queries primitive descriptor for a memory descriptor.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of memory descriptor parameter to query for.
/// @param index Index of the parameter to query.
/// @returns A pointer to the requested memory descriptor.
/// @returns A pointer to a zero memory descriptor if the parameter is not
/// needed.
/// @returns NULL in case of any error.
const dnnl_memory_desc_t DNNL_API *dnnl_primitive_desc_query_md(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);

/// Queries primitive descriptor for a signed 32-bit integer.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of the value to query for.
/// @param index Index of the parameter to query.
/// @returns The requested value.
/// @returns 0 in case of any error (in particular if the queried entity is
/// not of type int32_t). Note that 0 may also be the actual returned
/// value.
int DNNL_API dnnl_primitive_desc_query_s32(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);
219
/// Creates a primitive.
///
/// @note
/// The resulting primitive must be destroyed using
/// dnnl_primitive_destroy() once it is no longer needed.
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
        const_dnnl_primitive_desc_t primitive_desc);

/// Creates a primitive from a cache blob.
///
/// @sa dnnl_primitive_get_cache_blob()
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
        dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
        size_t size, const uint8_t *cache_blob);
240
241/// Executes a primitive.
242///
243/// @param primitive Primitive to execute.
244/// @param stream Stream to use.
245/// @param nargs Number of arguments.
246/// @param args Array of arguments. Each argument is an
247/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
248/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
249/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
250/// descriptor as that returned by
251/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
252/// @returns #dnnl_success on success and a status describing the error
253/// otherwise.
254
255/// @note If any argument in @param args is padded (padded_dims >
256/// dims), the primitive execution will assume properly zero-padded
257/// input arguments, and produce zero-padded output arguments.
258dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
259 dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
260
/// Retrieves a constant reference to the primitive descriptor of a given
/// primitive.
///
/// @warning
/// It is an error to destroy the returned object. It is owned by the
/// primitive. The @c const qualifier of the returned object prevents
/// such attempts.
///
/// @param primitive Primitive to query for the primitive descriptor.
/// @param primitive_desc Output primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
        const_dnnl_primitive_t primitive,
        const_dnnl_primitive_desc_t *primitive_desc);

/// Retrieves a cache blob associated with the given primitive.
///
/// @param primitive Primitive to query for the cache blob.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
/// NULL then the size of the cache blob is returned in @p size.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
///
/// @note The cache blob can be empty. It's the user's responsibility to check
/// whether it's empty prior to passing it to
/// #dnnl_primitive_create_from_cache_blob().
dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
        const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);

/// Destroys a primitive.
///
/// @param primitive The primitive to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
298
299/// @} dnnl_api_primitives_common
300
301/// @addtogroup dnnl_api_attributes
302/// @{
303
/// Creates an empty (default) primitive attributes with all the parameters
/// set to their default values.
///
/// Empty attributes are implied whenever the respective argument is NULL.
///
/// @note
/// The resulting attributes must be destroyed using
/// dnnl_primitive_attr_destroy() once they are no longer needed.
///
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);

/// Clones primitive attributes.
///
/// @param attr Output primitive attributes.
/// @param existing_attr Primitive attributes to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
        dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);

/// Destroys primitive attributes.
///
/// @param attr Primitive attributes to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);

/// Returns the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output FP math mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
        const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);

/// Sets the floating-point math mode primitive attributes.
///
/// @param attr Primitive attributes.
/// @param mode FP math mode. The possible values are:
/// #dnnl_fpmath_mode_strict (default),
/// #dnnl_fpmath_mode_bf16,
/// #dnnl_fpmath_mode_f16,
/// #dnnl_fpmath_mode_tf32,
/// #dnnl_fpmath_mode_any.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
        dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);

/// Returns the primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Output scratchpad mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
        const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);

/// Sets primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Scratchpad mode. The possible values are:
/// #dnnl_scratchpad_mode_library (default) and
/// #dnnl_scratchpad_mode_user.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
        dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
372
/// Returns primitive attributes output scaling factors correspondence mask
/// and values.
///
/// @warning
/// The @p scales array is an internal part of the primitive attributes
/// @p attr, so it is an error to modify or destroy the @p scales array.
///
/// @warning
/// The lifetime of @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use
/// @p scales after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// vector. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);

/// Sets output scaling factors correspondence mask and values.
///
/// @note
/// The order of dimensions does not depend on how elements are laid
/// out in memory. For example:
/// - for a 2D CNN activations tensor the order is always (n, c)
/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
/// - for a 5D CNN weights tensor the order is always
/// (g, oc, ic, kh, kw)
///
/// Example usage:
/// @code
/// int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
/// // The scales array has a constant length equal to oc (a variable-length
/// // array cannot have an initializer in C).
/// float scales[32] = { ... }; // unique output scales per output channel
/// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
///
/// dnnl_convolution_desc_t conv_d; // create a convolution descriptor
///
/// dnnl_primitive_attr_t attr;
/// dnnl_primitive_attr_create(&attr); // create primitive attributes
/// dnnl_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
///
/// dnnl_primitive_desc_t conv_pd;
/// dnnl_primitive_desc_create(&conv_pd, &conv_d, attr, engine, NULL);
/// @endcode
///
/// @param attr Primitive attributes.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Array of output scaling factors. If the output scaling
/// factors are known at the time of this call, this array must contain @p
/// count values and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// Violations can only be detected when the attributes are used to create
/// a primitive descriptor.
/// If the output scaling factors are not known at the time of the call,
/// this array must contain a single #DNNL_RUNTIME_F32_VAL value and the
/// output scaling factors must be passed at execution time as an argument
/// with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
447
/// Returns primitive attributes scaling factors correspondence mask and values
/// for a given memory argument.
///
/// @warning
/// The output @p scales array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// scales array.
///
/// @warning
/// The lifetime of the @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use @p
/// scales after @p attr is destroyed.
///
/// @note
/// NOTE(review): unlike the sibling getters
/// (dnnl_primitive_attr_get_output_scales() and
/// dnnl_primitive_attr_get_zero_points()), this function takes a
/// non-const @p attr handle even though it appears to be a read-only
/// query — confirm against the implementation whether the parameter can
/// be const-qualified in a future API revision.
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of 0
/// implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
        const float **scales);
477
/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales array.
/// The set i-th bit indicates that a dedicated scaling factor is used for
/// each index along that dimension. Set the mask to 0 to use a common
/// scaling factor for the whole output tensor.
/// @param scales Constant array of float scaling factors. This array must
/// contain @p count scales and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
        const float *scales);

/// Returns @p count, correspondence zero point @p mask, and a pointer to a
/// constant int32_t array of @p zero_points for given @p attr and memory
/// argument (index), previously set by dnnl_primitive_attr_set_zero_points.
///
/// @warning
/// The output @p zero_points array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// zero_points array.
///
/// @warning
/// The lifetime of @p zero_points array is the same as that of the
/// primitive attributes @p attr to which it belongs, so it is an error
/// to use @p zero_points after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of zero points @p zero_points.
/// @param mask Output zero points correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated output
/// zero point is used for each index along that dimension. The mask
/// value of 0 implies a common zero point for the whole output tensor.
/// @param zero_points Output pointer to a constant array of int32_t zero
/// points.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(
        const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
        const int32_t **zero_points);

/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of zero points @p zero_points.
/// @param mask Zero point correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated
/// zero point is used for each index along that dimension. Set the
/// mask to 0 to use a common zero point for the whole output tensor.
/// @param zero_points Constant array of int32_t zero points. If the zero
/// points are known at the time of this call, this array must contain @p
/// count zero points and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// If the zero points are not known at the time of the call, this array
/// must contain a single #DNNL_RUNTIME_S32_VAL and the zero points must
/// be passed at execution time as an argument with index
/// #DNNL_ARG_ATTR_ZERO_POINTS.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
        const int32_t *zero_points);
562
/// Returns primitive attributes post-ops.
///
/// @warning
/// The output @p post_ops points to the internal @p attr field, so it is
/// an error to modify or destroy them. The lifetime of @p post_ops is
/// the same as that of the @p attr it belongs to, so it is an error to
/// use @p post_ops after @p attr has been destroyed.
///
/// @param attr Primitive attributes.
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
        const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);

/// Sets primitive attributes post-ops.
///
/// @note
/// There is no way to check whether the post-ops would be supported by
/// the target primitive. Any error will be reported by the
/// dnnl_primitive_desc_create() function call.
///
/// @param attr Primitive attributes.
/// @param post_ops Post-ops to set.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
        dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);

/// Creates empty post-ops sequence.
///
/// @note
/// The resulting post-ops sequence must be destroyed using
/// dnnl_post_ops_destroy() once it is no longer needed.
///
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);

/// Clones post-ops primitive attribute.
///
/// @param post_ops Output post-ops primitive attribute.
/// @param existing_post_ops Post-ops primitive attribute to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_clone(
        dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);

/// Destroys post-ops.
///
/// @param post_ops Post-ops to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);

/// Returns the length of post-ops.
///
/// @param post_ops Post-ops.
/// @returns The number of post-ops entries.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);

/// Returns the kind of a post-op entry.
///
/// @param post_ops Post-ops.
/// @param index Post-op entry index.
/// @returns The kind of the post-op with the specified index.
/// @returns #dnnl_undefined_primitive if there is no post-op at the specified
/// index.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
        const_dnnl_post_ops_t post_ops, int index);
630
/// Appends an accumulation (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(
        dnnl_post_ops_t post_ops, float scale);

/// Appends an accumulation v2 (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
/// dst[:] <- scale * as_data_type(dst[:]) + op(...)
/// // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(
        dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type);

/// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the
/// result, a zero point is subtracted from the previous value and is
/// multiplied by the scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like dequantize the
/// asymmetrically quantized sum's src1 tensor to f32 domain before performing
/// the sum operation by subtracting the @p zero_point before the scaling.
///
/// In the simplest case where accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * (dst[:] - zero_point) + op(...)
/// // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
/// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
/// // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param zero_point Single scalar int32_t value of zero point.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v3(dnnl_post_ops_t post_ops,
        float scale, int32_t zero_point, dnnl_data_type_t data_type);
728
729/// Returns the parameters of an accumulation (sum) post-op.
730///
731/// @param post_ops Post-ops.
732/// @param index Index of the sum post-op.
733/// @param scale Output accumulation scaling factor.
734/// @returns #dnnl_success on success and a status describing the error
735/// otherwise.
736/// @returns #dnnl_invalid_arguments if @p index does not refer to a sum
737/// post-op.
738dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
739 const_dnnl_post_ops_t post_ops, int index, float *scale);
740
741/// Returns the parameters of an accumulation (sum) post-op with
742/// a data type parameter.
743///
744/// @param post_ops Post-ops.
745/// @param index Index of the sum post-op.
746/// @param scale Output accumulation scaling factor.
747/// @param data_type Data type for accumulation.
748/// @returns #dnnl_success on success and a status describing the error
749/// otherwise.
750dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(
751 const_dnnl_post_ops_t post_ops, int index, float *scale,
752 dnnl_data_type_t *data_type);
753
754/// Returns the parameters of an accumulation (sum) post-op with
755/// zero point and data type parameter.
756///
757/// @param post_ops Post-ops.
758/// @param index Index of the sum post-op.
759/// @param scale Output accumulation scaling factor.
760/// @param zero_point Zero point.
761/// @param data_type Data type for accumulation.
762/// @returns #dnnl_success on success and a status describing the error
763/// otherwise.
764dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v3(
765 const_dnnl_post_ops_t post_ops, int index, float *scale,
766 int32_t *zero_point, dnnl_data_type_t *data_type);
767
768/// Appends an elementwise post-op.
769///
770/// The kind of this post operation is #dnnl_eltwise.
771///
772/// In the simplest case when the elementwise is the only post operation, the
773/// computations would be:
774///
775/// dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
776///
777/// where eltwise_op is configured with the given parameters.
778///
779/// @param post_ops Post-ops.
780/// @param scale Scaling factor.
781/// @param alg_kind Elementwise algorithm for the post-op.
782/// @param alpha Alpha parameter for the elementwise algorithm.
783/// @param beta Beta parameter for the elementwise algorithm.
784/// @returns #dnnl_success on success and a status describing the error
785/// otherwise.
786dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
787 float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta);
788
789/// Returns the parameters of an elementwise post-op.
790///
791/// @param post_ops Post-ops.
792/// @param index Index of the elementwise post-op.
793/// @param scale Output scaling factor.
794/// @param alg_kind Output elementwise algorithm kind.
795/// @param alpha Output alpha parameter for the elementwise algorithm.
796/// @param beta Output beta parameter for the elementwise algorithm.
797/// @returns #dnnl_success on success and a status describing the error
798/// otherwise.
799/// @returns #dnnl_invalid_arguments if @p index does not refer to an
800/// elementwise post-op.
801dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
802 const_dnnl_post_ops_t post_ops, int index, float *scale,
803 dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);
804
/// Appends a depthwise post-op convolution.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimensions equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for primitive with fusion is one. The output spatial
/// size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param kernel_size Size of kernel of depthwise post-op
/// @param stride_size Size of stride of depthwise post-op
/// @param padding_l_size Size of left and top paddings of depthwise post-op
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
        dnnl_dim_t stride_size, dnnl_dim_t padding_l_size, dnnl_dim_t count,
        int mask, const float *scales);
842
/// Returns the parameters of a depthwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op
/// @param bias_data_type Output bias data type of depthwise post-op
/// @param dst_data_type Output destination data type of depthwise post-op
/// @param kernel_size Output size of kernel of depthwise post-op
/// @param stride_size Output size of stride of depthwise post-op
/// @param padding_l_size Output size of left and top paddings of depthwise
///     post-op
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
        dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size, dnnl_dim_t *count,
        int *mask, const float **scales);
868
/// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for the primitive remains the same as before. The
/// output size remains the same as for the original primitive due to stride=1.
///
/// The post-op can be defined as:
///
///      dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
903
/// Returns the parameters of a depthwise post-op with stride 1.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op
/// @param bias_data_type Output bias data type of depthwise post-op
/// @param dst_data_type Output destination data type of depthwise post-op
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
925
/// Appends a depthwise post-op convolution with stride 2.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for the primitive remains the same as before. The
/// output spatial size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// The post-op can be defined as:
///
///      dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
963
/// Returns the parameters of a depthwise post-op with stride 2.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op
/// @param bias_data_type Output bias data type of depthwise post-op
/// @param dst_data_type Output destination data type of depthwise post-op
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
985
986/// Appends a binary post-op.
987///
988/// The kind of this post operation is #dnnl_binary.
989///
990/// In the simplest case when the binary is the only post operation, the
991/// computations would be:
992///
993/// dst[:] <- binary_op (dst[:], another_input[:])
994///
995/// where binary_op is configured with the given parameters. binary_op supports
996/// broadcast semantics for a second operand.
997///
998/// @param post_ops Post-ops.
999/// @param alg_kind Binary algorithm for the post-op.
1000/// @param src1_desc Memory descriptor of a second operand.
1001/// @returns #dnnl_success on success and a status describing the error
1002/// otherwise.
1003dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
1004 dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc);
1005
1006/// Returns the parameters of a binary post-op.
1007///
1008/// @param post_ops Post-ops.
1009/// @param index Index of the binary post-op.
1010/// @param alg_kind Output binary algorithm kind.
1011/// @param src1_desc Output memory descriptor of a second operand.
1012/// @returns #dnnl_success on success and a status describing the error
1013/// otherwise.
1014/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
1015/// post-op.
1016dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
1017 const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
1018 const dnnl_memory_desc_t **src1_desc);
1019
1020/// Appends a prelu forward post-op.
1021///
1022/// The kind of this post-op is #dnnl::primitive::kind::prelu.
1023///
1024/// The post-op can be defined as:
1025///
1026/// dst[:] <- prelu(dst[:], weights[:])
1027/// prelu:
1028/// dst[:] <- dst[:] if dst[:] > 0
1029/// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
1030///
1031///
1032/// @note
1033/// The order of dimensions does not depend on how elements are laid
1034/// out in memory. For example:
1035/// - for a 2D CNN activations tensor the order is always (n, c)
1036/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
1037/// - for a 5D CNN weights tensor the order is always
1038/// (g, oc, ic, kh, kw)
1039///
1040/// Prelu weights tensor is passed in runtime execution phase. Prelu
1041/// weights tensor data type is implicitly assumed as f32 using plain
1042/// layout (a, ab, acb, acdb, acdeb)
1043///
1044/// @param post_ops Post-ops.
1045/// @param mask Defines the correspondence between the output tensor
1046/// dimensions and the prelu weights tensor. The set i-th bit indicates
1047/// that a dedicated weights value is used for each index along that
1048/// dimension. Set the mask to 0 to use a common weights value
1049/// for the whole output tensor.
1050/// @returns #dnnl_success on success and a status describing the error
1051/// otherwise.
1052dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
1053 dnnl_post_ops_t post_ops, int mask);
1054
1055/// Returns the parameters of a prelu post-op.
1056///
1057/// @param post_ops Post-ops.
1058/// @param index Index of the prelu post-op.
1059/// @param mask Mask of the prelu post-op.
1060/// @returns #dnnl_success on success and a status describing the error
1061/// otherwise.
1062dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
1063 const_dnnl_post_ops_t post_ops, int index, int *mask);
1064
1065/// @} dnnl_api_attributes
1066
1067/// @} dnnl_api_primitives
1068
1069/// @addtogroup dnnl_api_memory
1070/// @{
1071
1072/// Initializes a memory descriptor using dimensions and strides.
1073///
1074/// @note
1075/// As always, the logical order of dimensions corresponds to the `abc...`
1076/// format tag, and the physical meaning of the dimensions depends on both
1077/// the primitive that consumes the memory and the context of that
1078/// consumption.
1079///
1080/// @param memory_desc Output memory descriptor.
1081/// @param ndims Number of dimensions
1082/// @param dims Array of dimensions.
1083/// @param data_type Elements data type.
1084/// @param strides Strides in each dimension.
1085/// @returns #dnnl_success on success and a status describing the error
1086/// otherwise.
1087dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(
1088 dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
1089 dnnl_data_type_t data_type, const dnnl_dims_t strides);
1090
/// Initializes a memory descriptor using dimensions and memory format tag.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
///     allow a primitive to choose the final memory format. In this case the
///     format_kind field of the memory descriptor would be set to
///     #dnnl_format_kind_any.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_format_tag_t tag);
1112
/// Initializes a memory descriptor for a region inside an area
/// described by an existing memory descriptor.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or dims may
///     result in a failure to create a submemory.
///
/// @param memory_desc Output memory descriptor.
/// @param parent_memory_desc An existing memory descriptor.
/// @param dims Sizes of the region.
/// @param offsets Offsets to the region from the encompassing
///     memory object in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(
        dnnl_memory_desc_t *memory_desc,
        const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims,
        const dnnl_dims_t offsets);
1131
1132/// Initializes a memory descriptor by reshaping an existing one. The new
1133/// memory descriptor inherits the data type. This operation is valid only for
1134/// memory descriptors that have format_kind set to #dnnl_blocked or
1135/// #dnnl_format_kind_any.
1136///
1137/// The operation ensures the transformation of the physical memory format
1138/// corresponds to the transformation of the logical dimensions. If such
1139/// transformation is impossible, the function returns #dnnl_invalid_arguments.
1140///
1141/// The reshape operation can be described as a combination of the following
1142/// basic operations:
1143/// 1. Add a dimension of size `1`. This is always possible.
1144/// 2. Remove a dimension of size `1`. This is possible only if the dimension
1145/// has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
1146/// 3. Split a dimension into multiple ones. This is possible only if the size
1147/// of the dimension is exactly equal to the product of the split ones and
1148/// the dimension does not have padding (i.e.
1149/// `padded_dims[dim] = dims[dim]`).
1150/// 4. Joining multiple consecutive dimensions into a single one. As in the
1151/// cases above, this requires that the dimensions do not have padding and
1152/// that the memory format is such that in physical memory these dimensions
1153/// are dense and have the same order as their logical counterparts. This
1154/// also assumes that these dimensions are not blocked.
1155/// - Here, dense means:
1156/// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
1157/// - And same order means:
1158/// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
1159///
1160/// @warning
1161/// Some combinations of physical memory layout and/or offsets or
1162/// dimensions may result in a failure to make a reshape.
1163///
1164/// @param out_memory_desc Output memory descriptor.
1165/// @param in_memory_desc An existing memory descriptor. Must have format_kind
1166/// set to #dnnl_blocked or #dnnl_format_kind_any.
1167/// @param ndims Number of dimensions for the output memory descriptor.
1168/// @param dims Dimensions for the output memory descriptor.
1169/// @returns #dnnl_success on success and a status describing the error
1170/// otherwise.
1171dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
1172 dnnl_memory_desc_t *out_memory_desc,
1173 const dnnl_memory_desc_t *in_memory_desc, int ndims,
1174 const dnnl_dims_t dims);
1175
1176/// Initializes a memory descriptor by permuting axes in an existing one.
1177///
1178/// The physical memory layout representation is adjusted accordingly to
1179/// maintain the consistency between the logical and physical parts of the
1180/// memory descriptor.
1181///
1182/// The new memory descriptor inherits the data type. This operation is valid
1183/// only for memory descriptors that have format_kind set to #dnnl_blocked or
1184/// #dnnl_format_kind_any.
1185///
1186/// The logical axes will be permuted in the following manner:
1187/// ```
1188/// for (i: 0 .. in_memory_desc->ndims)
1189/// out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
1190/// ```
1191///
1192/// Example:
1193/// @code
1194/// dnnl_memory_desc_t in_md, out_md, expect_out_md;
1195///
1196/// const int permutation[] = {1, 0}; // swap the first and the second axes
1197///
1198/// dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
1199/// dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
1200///
1201/// dnnl_memory_desc_init_by_tag(
1202/// &in_md, 2, in_dims, data_type, in_tag);
1203/// dnnl_memory_desc_init_by_tag(
1204/// &expect_out_md, 2, out_dims, data_type, out_tag);
1205///
1206/// dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
1207/// assert(dnnl_memory_desc_equal(&out_md, &expect_out_md));
1208/// @endcode
1209///
1210/// @param out_memory_desc Output memory descriptor.
1211/// @param in_memory_desc An existing memory descriptor. Must have format_kind
1212/// set to #dnnl_blocked or #dnnl_format_kind_any.
1213/// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
1214/// @returns #dnnl_success on success and a status describing the error
1215/// otherwise.
1216dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
1217 dnnl_memory_desc_t *out_memory_desc,
1218 const dnnl_memory_desc_t *in_memory_desc, const int *permutation);
1219
1220/// Compares two memory descriptors.
1221///
/// Use this function to identify whether a reorder is required between the
/// two memories.
1224///
1225/// @param lhs Left-hand side of the comparison.
1226/// @param rhs Right-hand side of the comparison.
1227/// @returns 1 if the descriptors are the same.
1228/// @returns 0 if the descriptors are different.
1229int DNNL_API dnnl_memory_desc_equal(
1230 const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs);
1231
1232/// Returns the size of a memory descriptor.
1233///
1234/// @param memory_desc Memory descriptor.
1235/// @returns The number of bytes required for memory described by a memory
1236/// descriptor.
1237size_t DNNL_API dnnl_memory_desc_get_size(
1238 const dnnl_memory_desc_t *memory_desc);
1239
1240/// Returns the size of data type.
1241///
1242/// @param data_type Data type.
1243/// @returns The number of bytes occupied by data type.
1244size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
1245
1246/// Creates a memory object.
1247///
1248/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
1249/// object will have the underlying buffer set. In this case, the buffer will
1250/// be initialized as if dnnl_memory_set_data_handle() had been called.
1251///
1252/// @sa dnnl_memory_set_data_handle()
1253///
1254/// @param memory Output memory object.
1255/// @param memory_desc Memory descriptor.
1256/// @param engine Engine to use.
1257/// @param handle Handle of the memory buffer to use as an underlying storage.
1258/// - A pointer to the user-allocated buffer. In this case the library
1259/// doesn't own the buffer.
1260/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
1261/// allocate the buffer for the memory object. In this case the library
1262/// owns the buffer.
1263/// - DNNL_MEMORY_NONE to create dnnl_memory without an underlying buffer.
1264/// @returns #dnnl_success on success and a status describing the error
1265/// otherwise.
1266dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
1267 const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine,
1268 void *handle);
1269
1270/// Returns the memory descriptor for a memory object.
1271///
1272/// @param memory Memory object.
1273/// @param memory_desc Output memory descriptor (a copy).
1274/// @returns #dnnl_success on success and a status describing the error
1275/// otherwise.
1276dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
1277 const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc);
1278
1279/// Returns the engine of a memory object.
1280///
1281/// @param memory Memory object.
1282/// @param engine Output engine on which the memory is located.
1283/// @returns #dnnl_success on success and a status describing the error
1284/// otherwise.
1285dnnl_status_t DNNL_API dnnl_memory_get_engine(
1286 const_dnnl_memory_t memory, dnnl_engine_t *engine);
1287
1288/// Maps a memory object and returns a host-side pointer to a memory buffer
1289/// with a copy of its contents.
1290///
1291/// Mapping enables explicit direct access to memory contents for the engines
1292/// that do not support it implicitly.
1293///
1294/// Mapping is an exclusive operation - a memory object cannot be used in
1295/// other operations until this memory object is unmapped.
1296///
1297/// @note
1298/// Any primitives working with @p memory should be completed before
1299/// the memory is mapped. Use dnnl_stream_wait to synchronize the
1300/// corresponding execution stream.
1301///
1302/// @note
1303/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
1304/// mainly provided for debug and testing purposes, and their performance
1305/// may be suboptimal.
1306///
1307/// @param memory Memory object.
1308/// @param mapped_ptr Output pointer to the mapped buffer.
1309/// @returns #dnnl_success on success and a status describing the error
1310/// otherwise.
1311dnnl_status_t DNNL_API dnnl_memory_map_data(
1312 const_dnnl_memory_t memory, void **mapped_ptr);
1313
1314/// Unmaps a memory object and writes back any changes made to the previously
1315/// mapped memory buffer. The pointer to the mapped buffer must be obtained
1316/// via the dnnl_memory_map_data() call.
1317///
1318/// @note
1319/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
1320/// mainly provided for debug and testing purposes, and their performance
1321/// may be suboptimal.
1322///
1323/// @param memory Memory object.
1324/// @param mapped_ptr Pointer to the mapped buffer that must have been
1325/// obtained using the dnnl_memory_map_data() function.
1326/// @returns #dnnl_success on success and a status describing the error
1327/// otherwise.
1328dnnl_status_t DNNL_API dnnl_memory_unmap_data(
1329 const_dnnl_memory_t memory, void *mapped_ptr);
1330
1331/// Returns memory object's data handle.
1332///
1333/// @param memory Memory object.
1334/// @param handle Output data handle. For the CPU engine, the data handle is a
1335/// pointer to the actual data. For OpenCL it is a cl_mem.
1336/// @returns #dnnl_success on success and a status describing the error
1337/// otherwise.
1338dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
1339 const_dnnl_memory_t memory, void **handle);
1340
1341/// Sets the underlying memory buffer.
1342///
1343/// See the description of dnnl_memory_set_data_handle_v2() for more details.
1344///
1345/// @param memory Memory object.
1346/// @param handle Data handle. For the CPU engine, the data handle is a
1347/// pointer to the actual data. For OpenCL it is a `cl_mem`.
1348/// @returns #dnnl_success on success and a status describing the error
1349/// otherwise.
1350dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
1351 dnnl_memory_t memory, void *handle);
1352
1353/// Sets the underlying memory buffer.
1354///
1355/// @param memory Memory object.
1356/// @param handle Data handle. For the CPU engine, the data handle is a
1357/// pointer to the actual data. For OpenCL it is a `cl_mem`.
1358/// @param stream Stream to use to execute padding in.
1359/// @returns #dnnl_success on success and a status describing the error
1360/// otherwise.
1361dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
1362 dnnl_memory_t memory, void *handle, dnnl_stream_t stream);
1363
1364/// Destroys a memory object.
1365///
1366/// @param memory Memory object to destroy.
1367/// @returns #dnnl_success on success and a status describing the error
1368/// otherwise.
1369dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
1370
1371/// @} dnnl_api_memory
1372
1373/// @addtogroup dnnl_api_primitives
1374/// @{
1375
1376/// @addtogroup dnnl_api_reorder
1377/// @{
1378
1379/// Creates a primitive descriptor for a reorder primitive.
1380///
1381/// @param reorder_primitive_desc Output primitive descriptor.
1382/// @param src_desc Source memory descriptor.
1383/// @param src_engine Engine on which the source memory object will be
1384/// located.
1385/// @param dst_desc Destination memory descriptor.
1386/// @param dst_engine Engine on which the destination memory object
1387/// will be located.
1388/// @param attr Primitive attributes to use (can be NULL).
1389/// @returns #dnnl_success on success and a status describing the error
1390/// otherwise.
1391dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
1392 dnnl_primitive_desc_t *reorder_primitive_desc,
1393 const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine,
1394 const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine,
1395 const_dnnl_primitive_attr_t attr);
1396
1397/// @} dnnl_api_reorder
1398
1399/// @addtogroup dnnl_api_concat
1400/// @{
1401
/// Creates a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// @param concat_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source memory descriptors in @p src_descs.
/// @param concat_dimension Source tensors will be concatenated over
///     dimension with this index. Note that order of dimensions does
///     not depend on memory format.
/// @param src_descs Array of source memory descriptors with @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
        dnnl_primitive_desc_t *concat_primitive_desc,
        const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension,
        const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
        dnnl_engine_t engine);
1421
1422/// @} dnnl_api_concat
1423
1424/// @addtogroup dnnl_api_sum
1425/// @{
1426
/// Creates a primitive descriptor for an (out-of-place) sum primitive.
///
/// @param sum_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source memory descriptors in @p src_descs (and of
///     scales in @p scales).
/// @param scales Array of @p n scales to multiply data in each source
///     memory by.
/// @param src_descs Array of source memory descriptors having @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
        dnnl_primitive_desc_t *sum_primitive_desc,
        const dnnl_memory_desc_t *dst_desc, int n, const float *scales,
        const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
        dnnl_engine_t engine);
1444
1445/// @} dnnl_api_sum
1446
1447/// @addtogroup dnnl_api_binary
1448/// @{
1449
/// Initializes a descriptor for a binary primitive.
///
/// @note
///     Memory descriptor @p dst_desc is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @note
///     Both memory descriptors must have the same number of dimensions.
///     Element broadcasting is supported for memory descriptor @p src1_desc
///     and is applied to @p src1_desc dimensions that have size equal to 1.
///
/// @param binary_desc Output descriptor for a binary primitive.
/// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
///     #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
///     #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
///     #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
/// @param src0_desc Source 0 memory descriptor.
/// @param src1_desc Source 1 memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc,
        const dnnl_memory_desc_t *src1_desc,
        const dnnl_memory_desc_t *dst_desc);
1475
1476/// @} dnnl_api_binary
1477
1478/// @addtogroup dnnl_api_convolution
1479/// @{
1480
/// Initializes a descriptor for a convolution forward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1519
/// Initializes a descriptor for a dilated convolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1562
/// Initializes a descriptor for a convolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1595
/// Initializes a descriptor for a dilated convolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1633
/// Initializes a descriptor for a convolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1670
/// Initializes a descriptor for a dilated convolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1712
1713/// @} dnnl_api_convolution
1714
1715/// @addtogroup dnnl_api_deconvolution
1716/// @{
1717
/// Initializes a descriptor for a deconvolution forward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1755
/// Initializes a descriptor for a dilated deconvolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1797
/// Initializes a descriptor for a deconvolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1829
/// Initializes a descriptor for a dilated deconvolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1866
/// Initializes a descriptor for a deconvolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1902
/// Initializes a descriptor for a dilated deconvolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param dilates Array of dilations for spatial dimensions. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1943
1944/// @} dnnl_api_deconvolution
1945
1946/// @addtogroup dnnl_api_shuffle
1947/// @{
1948
/// Initializes a descriptor for a shuffle forward propagation primitive.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size);
1962
/// Initializes a descriptor for a shuffle backward propagation primitive.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc,
        const dnnl_memory_desc_t *diff_data_desc, int axis,
        dnnl_dim_t group_size);
1975
1976/// @} dnnl_api_shuffle
1977
1978/// @addtogroup dnnl_api_eltwise
1979/// @{
1980
/// Initializes a descriptor for an eltwise forward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Elementwise algorithm kind.
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc,
        float alpha, float beta);
1998
/// Initializes a descriptor for an eltwise backward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param alg_kind Elementwise algorithm kind.
/// @param diff_data_desc Diff source and diff destination memory descriptors.
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float alpha, float beta);
2015
2016/// @} dnnl_api_eltwise
2017
/// @addtogroup dnnl_api_softmax
/// @{

/// Initializes a descriptor for a softmax forward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(
        dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);

/// Initializes a descriptor for a softmax backward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(
        dnnl_softmax_desc_t *softmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);

/// @} dnnl_api_softmax
2048
/// @addtogroup dnnl_api_softmax_v2
/// @{

/// Initializes a descriptor for a softmax v2 forward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate or
///     #dnnl_softmax_log.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_v2_forward_desc_init(
        dnnl_softmax_v2_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, int softmax_axis);

/// Initializes a descriptor for a softmax v2 backward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate or
///     #dnnl_softmax_log.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_v2_backward_desc_init(
        dnnl_softmax_v2_desc_t *softmax_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc,
        const dnnl_memory_desc_t *dst_desc, int softmax_axis);

/// @} dnnl_api_softmax_v2
2087
/// @addtogroup dnnl_api_logsoftmax
/// @{

/// Initializes a descriptor for a logsoftmax forward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);

/// Initializes a descriptor for a logsoftmax backward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);

/// @} dnnl_api_logsoftmax
2118
/// @addtogroup dnnl_api_pooling
/// @{

/// Initializes a descriptor for a pooling forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);

/// Initializes a descriptor for a pooling backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);

/// @} dnnl_api_pooling
2183
/// @addtogroup dnnl_api_pooling_v2
/// @{

/// Initializes a descriptor for pooling v2 (pooling with dilation support)
/// forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);

/// Initializes a descriptor for pooling v2 (pooling with dilation support)
/// backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimensions.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);

/// @} dnnl_api_pooling_v2
2254
/// @addtogroup dnnl_api_prelu
/// @{

/// Initializes a descriptor for a PReLU
/// (leaky ReLU with trainable alpha parameter)
/// forward propagation primitive.
///
/// @note
///     The weights descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc);

/// Initializes a descriptor for a PReLU
/// (leaky ReLU with trainable alpha parameter)
/// backward propagation primitive.
///
/// @note
///     The weights and diff_weights descriptors are allowed
///     to be initialized with #dnnl_format_tag_any or with format_kind
///     set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param diff_weights_desc Diff alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *diff_weights_desc);

/// @} dnnl_api_prelu
2301
/// @addtogroup dnnl_api_lrn
/// @{

/// Initializes a descriptor for an LRN forward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param data_desc Source and destination memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);

/// Initializes a descriptor for an LRN backward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);

/// @} dnnl_api_lrn
2343
/// @addtogroup dnnl_api_batch_normalization
/// @{

/// Initializes a descriptor for a batch normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param bnrm_desc Output descriptor for a batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);

/// Initializes a descriptor for a batch normalization backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param bnrm_desc Output descriptor for a batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);

/// @} dnnl_api_batch_normalization
2389
/// @addtogroup dnnl_api_layer_normalization
/// @{

/// Initializes a descriptor for a layer normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p data_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(
        dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags);

/// Initializes a descriptor for a layer normalization backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p data_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(
        dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags);

/// @} dnnl_api_layer_normalization
2447
/// @addtogroup dnnl_api_layer_normalization_v2
/// @{

/// Initializes a descriptor for a layer normalization v2 forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_v2_forward_desc_init(
        dnnl_layer_normalization_v2_desc_t *lnrm_desc,
        dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_memory_desc_t *stat_desc,
        float epsilon, unsigned flags);

/// Initializes a descriptor for a layer normalization v2 backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_v2_backward_desc_init(
        dnnl_layer_normalization_v2_desc_t *lnrm_desc,
        dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc,
        const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *stat_desc,
        float epsilon, unsigned flags);

/// @} dnnl_api_layer_normalization_v2
2509
/// @addtogroup dnnl_api_inner_product
/// @{

/// Initializes a descriptor for an inner product forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(
        dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_desc);

/// Initializes a descriptor for an inner product backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(
        dnnl_inner_product_desc_t *ip_desc,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc);

/// Initializes a descriptor for an inner product weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(
        dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc);

/// @} dnnl_api_inner_product
2577
2578/// @addtogroup dnnl_api_attributes
2579/// @{
2580
2581/// Set quantization scale and shift parameters for RNN data tensors.
2582///
2583/// For performance reasons, the low-precision configuration of the RNN
2584/// primitives expects input activations to have the unsigned 8-bit integer
2585/// data type. The scale and shift parameters are used to quantize
2586/// floating-point data to unsigned integer and must be passed to the RNN
2587/// primitive using attributes.
2588///
2589/// The quantization formula is `scale * data + shift`.
2590///
2591/// @note
2592/// Quantization scale and shift are common for src_layer, src_iter,
2593/// dst_iter, and dst_layer.
2594///
2595/// Example usage:
2596/// @code
2597/// // RNN parameters
2598/// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
2599/// // Activations quantization parameters
2600/// float scale = 63.f, shift = 64.f;
2601///
2602/// dnnl_primitive_attr_t rnn_attr;
2603/// // Create default attributes
2604/// dnnl_primitive_attr_create(&rnn_attr);
2605///
2606/// // Set scale and shift for int8 quantization of activation
2607/// dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
2608///
2609/// // Create and configure rnn op_desc
2610/// dnnl_rnn_desc_t rnn_d;
2611/// dnnl_primitive_desc_t rnn_pd;
2612/// dnnl_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL);
2613/// @endcode
2614///
2615/// @param attr Primitive attributes.
2616/// @param scale The value to scale the data by.
2617/// @param shift The value to shift the data by.
2618/// @returns #dnnl_success on success and a status describing the error
2619/// otherwise.
2620dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
2621 dnnl_primitive_attr_t attr, const float scale, const float shift);
2622
2623/// Returns the quantization scale and shift parameters for RNN data tensors.
2624///
2625/// @note
2626/// Quantization scale and shift are common for src_layer, src_iter,
2627/// dst_iter, and dst_layer.
2628///
2629/// @param attr Primitive attributes.
2630/// @param scale The value to scale the data by.
2631/// @param shift The value to shift the data by.
2632/// @returns #dnnl_success on success and a status describing the error
2633/// otherwise.
2634dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
2635 const_dnnl_primitive_attr_t attr, float *scale, float *shift);
2636
/// Sets quantization scaling factors for RNN weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
/// The dimension order is always native and does not depend on the actual
/// layout used. For example, five-dimensional weights always have (l, d,
/// i, g, o) logical dimension ordering.
///
/// @note
/// Quantization scales are common for weights_layer and weights_iteration.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the weights tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated scaling
/// factor should be used for each index along that dimension. Set the
/// mask to 0 to use a common scaling factor for the whole weights
/// tensor.
/// @param scales Array of scaling factors that must contain @p count
/// values and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// Violations can only be detected when the attributes are used to create
/// a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
2669
/// Returns the quantization scaling factors for RNN weights tensors.
///
/// @note
/// NOTE(review): the returned @p scales pointer appears to reference
/// memory owned by @p attr (it is returned as `const float **` rather
/// than copied) — confirm lifetime before caching it past the lifetime
/// of the attributes object.
///
/// @param attr Primitive attributes.
/// @param count Output pointer to the number of elements in the returned
/// @p scales array.
/// @param mask Output pointer to the scaling factors correspondence mask
/// that defines the correspondence between the weights tensor
/// dimensions and the scales vector. The set i-th bit indicates that a
/// dedicated scaling factor is used for each index along that
/// dimension. A mask of 0 means a common scaling factor for the whole
/// weights tensor.
/// @param scales Output pointer to the array of scaling factors that
/// contains @p count values and for which the following equality holds:
/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);
2688
/// Sets quantization scaling factors for RNN projection weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
/// The dimension order is always native and does not depend on the actual
/// layout used. For example, five-dimensional weights always have (l, d,
/// i, g, o) logical dimension ordering.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the weights tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated scaling
/// factor should be used for each index along that dimension. Set the
/// mask to 0 to use a common scaling factor for the whole weights
/// tensor.
/// @param scales Array of scaling factors that must contain @p count
/// values and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// Violations can only be detected when the attributes are used to create
/// a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
2718
/// Returns the quantization scaling factors for RNN projection weights tensors.
///
/// @note
/// NOTE(review): the returned @p scales pointer appears to reference
/// memory owned by @p attr (it is returned as `const float **` rather
/// than copied) — confirm lifetime before caching it past the lifetime
/// of the attributes object.
///
/// @param attr Primitive attributes.
/// @param count Output pointer to the number of elements in the returned
/// @p scales array.
/// @param mask Output pointer to the scaling factors correspondence mask
/// that defines the correspondence between the weights tensor
/// dimensions and the scales vector. The set i-th bit indicates that a
/// dedicated scaling factor is used for each index along that
/// dimension. A mask of 0 means a common scaling factor for the whole
/// weights tensor.
/// @param scales Output pointer to the array of scaling factors that
/// contains @p count values and for which the following equality holds:
/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);
2737
2738/// @} dnnl_api_attributes
2739
2740/// @addtogroup dnnl_api_rnn
2741/// @{
2742
2743/// Initializes a descriptor for vanilla RNN forward propagation primitive.
2744///
2745/// The following arguments may either be @c NULL or point to a zero memory
2746/// descriptor:
2747/// - @p src_iter_desc,
2748/// - @p bias_desc,
2749/// - @p dst_iter_desc.
2750///
2751/// This would then indicate that the RNN forward propagation primitive should
2752/// not use them and should default to zero values instead.
2753///
2754/// @note
2755/// All memory descriptors can be initialized with
2756/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2757///
2758/// @param rnn_desc Output descriptor for vanilla RNN primitive.
2759/// @param prop_kind Propagation kind. Possible values are
2760/// #dnnl_forward_training and #dnnl_forward_inference.
2761/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
2762/// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
2763/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2764/// info.
2765/// @param src_layer_desc Memory descriptor for the input vector.
2766/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2767/// state vector.
2768/// @param weights_layer_desc Memory descriptor for the weights applied to the
2769/// layer input.
2770/// @param weights_iter_desc Memory descriptor for the weights applied to the
2771/// recurrent input.
2772/// @param bias_desc Bias memory descriptor.
2773/// @param dst_layer_desc Memory descriptor for the output vector.
2774/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2775/// state vector.
2776/// @param flags Unused.
2777/// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
2778/// @param beta Unused.
2779/// @returns #dnnl_success on success and a status describing the error
2780/// otherwise.
2781dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(
2782 dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
2783 const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction,
2784 const dnnl_memory_desc_t *src_layer_desc,
2785 const dnnl_memory_desc_t *src_iter_desc,
2786 const dnnl_memory_desc_t *weights_layer_desc,
2787 const dnnl_memory_desc_t *weights_iter_desc,
2788 const dnnl_memory_desc_t *bias_desc,
2789 const dnnl_memory_desc_t *dst_layer_desc,
2790 const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha,
2791 float beta);
2792
2793/// Initializes a descriptor for vanilla RNN backward propagation primitive.
2794///
2795/// The following arguments may either be @c NULL or point to a zero memory
2796/// descriptor:
2797/// - @p src_iter_desc together with @p diff_src_iter_desc,
2798/// - @p bias_desc together with @p diff_bias_desc,
2799/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2800///
2801/// This would then indicate that the RNN backward propagation primitive should
2802/// not use the respective data and should use zero values instead.
2803///
2804/// @note
2805/// All memory descriptors can be initialized with
2806/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2807///
2808/// @param rnn_desc Output descriptor for vanilla RNN primitive.
2809/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2810/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
2811/// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
2812/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2813/// info.
2814/// @param src_layer_desc Memory descriptor for the input vector.
2815/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2816/// state vector.
2817/// @param weights_layer_desc Memory descriptor for the weights applied to the
2818/// layer input.
2819/// @param weights_iter_desc Memory descriptor for the weights applied to the
2820/// recurrent input.
2821/// @param bias_desc Bias memory descriptor.
2822/// @param dst_layer_desc Memory descriptor for the output vector.
2823/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2824/// state vector.
2825/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
2826/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
2827/// hidden state vector.
2828/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
2829/// applied to the layer input.
2830/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
2831/// applied to the recurrent input.
2832/// @param diff_bias_desc Diff bias memory descriptor.
2833/// @param diff_dst_layer_desc Memory descriptor for the diff of output
2834/// vector.
2835/// @param diff_dst_iter_desc Memory descriptor for the diff of output
2836/// recurrent hidden state vector.
2837/// @param flags Unused.
2838/// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
2839/// @param beta Unused.
2840/// @returns #dnnl_success on success and a status describing the error
2841/// otherwise.
2842dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(
2843 dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
2844 const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction,
2845 const dnnl_memory_desc_t *src_layer_desc,
2846 const dnnl_memory_desc_t *src_iter_desc,
2847 const dnnl_memory_desc_t *weights_layer_desc,
2848 const dnnl_memory_desc_t *weights_iter_desc,
2849 const dnnl_memory_desc_t *bias_desc,
2850 const dnnl_memory_desc_t *dst_layer_desc,
2851 const dnnl_memory_desc_t *dst_iter_desc,
2852 const dnnl_memory_desc_t *diff_src_layer_desc,
2853 const dnnl_memory_desc_t *diff_src_iter_desc,
2854 const dnnl_memory_desc_t *diff_weights_layer_desc,
2855 const dnnl_memory_desc_t *diff_weights_iter_desc,
2856 const dnnl_memory_desc_t *diff_bias_desc,
2857 const dnnl_memory_desc_t *diff_dst_layer_desc,
2858 const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
2859 float alpha, float beta);
2860
/// Initializes a descriptor for LSTM forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_forward_desc_init_v2 to initialize forward LSTM with and
/// without peephole
/// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and
/// without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
2915
/// Initializes a descriptor for an LSTM (with or without peephole) forward
/// propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p weights_peephole_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_forward_desc_init to initialize forward LSTM without
/// peephole
/// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and
/// without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
/// the cell states (according to the Peephole LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
2973
/// Initializes a descriptor for an LSTM (with or without peephole and with
/// or without recurrent projection layer) forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p weights_peephole_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// The @p weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM does not have a
/// recurrent projection layer.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
/// the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
/// the hidden states to get the recurrent projection (according to the
/// Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *weights_projection_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
3036
/// Initializes a descriptor for an LSTM backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
/// and @p diff_src_iter_c_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
/// and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init_v2 to initialize backward LSTM with and
/// without peephole
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
/// without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
/// hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
/// recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
/// applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
/// applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
/// vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
/// recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3117
/// Initializes a descriptor for an LSTM (with or without peephole) backward
/// propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
/// and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
/// and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init to initialize backward LSTM without
/// peephole
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
/// without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
/// the cell states (according to the Peephole LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
/// hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
/// recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
/// applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
/// applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
/// applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
/// vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
/// recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3205
/// Initializes a descriptor for an LSTM (with or without peephole and with or
/// without recurrent projection layer) backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
/// and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
/// and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// The @p weights_projection_desc together with @p
/// diff_weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM does not have a
/// recurrent projection layer.
///
/// @note
/// All memory descriptors can be initialized with #dnnl_format_tag_any or
/// with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
/// state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
/// the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
/// the hidden states to get the recurrent projection (according to the
/// Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
/// state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
/// hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
/// recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
/// applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
/// applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
/// applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_weights_projection_desc Memory descriptor for the diff of
/// weights applied to the hidden states to get the recurrent projection
/// (according to the Projection LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
/// vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
/// recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *weights_projection_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_weights_projection_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3303
/// Initializes a descriptor for a GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the GRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3347
/// Initializes a descriptor for a GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3409
/// Initializes a descriptor for an LBR GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR GRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR GRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3449
/// Initializes a descriptor for an LBR GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the LBR GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR GRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3512
/// Initializes a descriptor for an AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the AUGRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for AUGRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_augru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3558
/// Initializes a descriptor for an AUGRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the AUGRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for AUGRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_attention_desc Memory descriptor for the diff of attention
///     vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_augru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_attention_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3624
/// Initializes a descriptor for an LBR AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR AUGRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR AUGRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3667
/// Initializes a descriptor for an LBR AUGRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the LBR AUGRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR AUGRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_attention_desc Memory descriptor for the diff of attention
///     vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_augru_backward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_attention_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3734
3735/// @} dnnl_api_rnn
3736
3737/// @addtogroup dnnl_api_matmul
3738/// @{
3739
/// Initializes a matrix multiplication descriptor.
///
/// @param matmul_desc Output descriptor for matmul primitive.
/// @param src_desc Source memory descriptor (matrix A).
/// @param weights_desc Weights memory descriptor (matrix B).
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor (matrix C).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_desc);
3756
3757/// @} dnnl_api_matmul
3758
3759/// @addtogroup dnnl_api_resampling Resampling
3760/// @{
3761
/// Initializes a descriptor for a resampling forward propagation primitive.
///
/// @note
///     Destination memory descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param resampling_desc Output descriptor for a resampling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Resampling algorithm kind: either #dnnl_resampling_nearest,
///     or #dnnl_resampling_linear.
/// @param factors Array of scaling factors for spatial dimension.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(
        dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const float *factors,
        const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc);
3783
/// Initializes a descriptor for a resampling backward propagation primitive.
///
/// @param resampling_desc Output descriptor for a resampling primitive.
/// @param alg_kind Resampling algorithm kind: either
///     #dnnl_resampling_nearest, or #dnnl_resampling_linear.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param factors Array of scaling factors for spatial dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(
        dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind,
        const float *factors, const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc);
3799
3800/// @} dnnl_api_resampling
3801
3802/// @addtogroup dnnl_api_reduction Reduction
3803/// @{
3804
/// Initializes a descriptor for a reduction primitive.
///
/// @note
///     Destination memory descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param desc Output descriptor for a reduction primitive.
/// @param alg_kind Reduction algorithm kind. Possible values:
///     #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
///     #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
///     #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
///     #dnnl_reduction_norm_lp_power_p_sum.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param p Algorithm specific parameter.
/// @param eps Algorithm specific parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, float p, float eps);
3828
3829/// @} dnnl_api_reduction
3830
3831/// @} dnnl_api_primitives
3832
3833/// @addtogroup dnnl_api_engine
3834/// @{
3835
/// Returns the number of engines of a particular kind.
///
/// @param kind Kind of engines to count.
/// @returns Count of the engines of the specified kind.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind);
3841
/// Creates an engine.
///
/// @param engine Output engine.
/// @param kind Engine kind.
/// @param index Engine index that should be between 0 and the count of
///     engines of the requested kind as returned by dnnl_engine_get_count().
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_engine_create(
        dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index);
3852
/// Returns the kind of an engine.
///
/// @param engine Engine to query.
/// @param kind Output engine kind.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_engine_get_kind(
        dnnl_engine_t engine, dnnl_engine_kind_t *kind);
3861
/// Destroys an engine.
///
/// @param engine Engine to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine);
3868
3869/// @} dnnl_api_engine
3870
3871/// @addtogroup dnnl_api_stream
3872/// @{
3873
/// Creates an execution stream.
///
/// @param stream Output execution stream.
/// @param engine Engine to create the execution stream on.
/// @param flags Stream behavior flags (see @ref dnnl_stream_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_stream_create(
        dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags);
3883
/// Returns the engine of a stream object.
///
/// @param stream Stream object.
/// @param engine Output engine on which the stream is created.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_stream_get_engine(
        const_dnnl_stream_t stream, dnnl_engine_t *engine);
3892
/// Waits for all primitives in the execution stream to finish computations.
///
/// @param stream Execution stream.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream);
3899
/// Destroys an execution stream.
///
/// @param stream Execution stream to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream);
3906
3907/// @} dnnl_api_stream
3908
3909/// @addtogroup dnnl_api_primitive_cache
3910/// @{
3911
/// Returns the number of primitives that can be held in the primitive cache
/// at the same time.
///
/// @param capacity Output primitive cache capacity. Concurrently
///     accessing @p capacity is safe.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
3921
/// Sets a number of primitives that can be held in the primitive cache
/// at a time.
///
/// @param capacity Primitive cache capacity to set. If a new @p capacity is
///     less than a number of primitives that the primitive cache already has
///     then the excess entries will be evicted. Setting the @p capacity to 0
///     clears the primitive cache and disables it. Concurrently modifying
///     @p capacity is safe.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
3934
3935/// @} dnnl_api_primitive_cache
3936
3937/// @addtogroup dnnl_api_mathmode Floating-point Math Mode
3938/// @{
3939
/// Returns the floating-point math mode that will be used by default
/// for all subsequently created primitives.
///
/// @param mode Output FP math mode.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_get_default_fpmath_mode(dnnl_fpmath_mode_t *mode);
3947
3948/// Sets the floating-point math mode that will be used by default
3949/// for all subsequently created primitives.
3950///
3951/// @param mode FP math mode. The possible values are:
3952/// #dnnl_fpmath_mode_strict,
3953/// #dnnl_fpmath_mode_bf16,
3954/// #dnnl_fpmath_mode_f16,
3955/// #dnnl_fpmath_mode_tf32,
3956/// #dnnl_fpmath_mode_any.
3957/// @returns #dnnl_success on success and a status describing the error
3958/// otherwise.
3959dnnl_status_t DNNL_API dnnl_set_default_fpmath_mode(dnnl_fpmath_mode_t mode);
3960
3961/// @} dnnl_api_mathmode
3962
3963/// @addtogroup dnnl_api_service
3964/// @{
3965
3966/// Configures verbose output to stdout.
3967///
3968/// @note
3969/// Enabling verbose output affects performance.
3970/// This setting overrides the ONEDNN_VERBOSE environment variable.
3971///
3972/// @param level Verbosity level:
3973/// - 0: no verbose output (default),
3974/// - 1: primitive information at execution,
3975/// - 2: primitive information at creation and execution.
3976/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3977/// @p level value is invalid, and #dnnl_success/#dnnl::status::success on
3978/// success.
3979dnnl_status_t DNNL_API dnnl_set_verbose(int level);
3980
3981/// Configures dumping of JIT-generated code.
3982///
3983/// @note
3984/// This setting overrides the DNNL_JIT_DUMP environment variable.
3985///
3986/// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
3987/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3988/// @p enable value is invalid, and #dnnl_success/#dnnl::status::success on
3989/// success.
3990dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
3991
3992/// Returns library version information.
3993/// @returns Pointer to a constant structure containing
3994/// - major: major version number,
3995/// - minor: minor version number,
3996/// - patch: patch release number,
3997/// - hash: git commit hash.
3998const dnnl_version_t DNNL_API *dnnl_version(void);
3999
4000/// Sets library profiling flags. The flags define which profilers are
4001/// supported.
4002///
4003/// @note
4004/// This setting overrides DNNL_JIT_PROFILE environment variable.
4005///
4006/// @sa @ref dev_guide_profilers
4007///
4008/// @param flags Profiling flags that can contain the following bits:
4009/// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Amplifier
4010/// (on by default)
4011/// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
4012/// jit-pid.dump output (off by default). The location of the output
4013/// is controlled via JITDUMPDIR environment variable or via
4014/// dnnl_set_jit_profiling_jitdumpdir() function.
4015/// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
4016/// perf-pid.map output (off by default). The output is always placed
4017/// into /tmp.
4018///
4019/// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
4020///
4021/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
4022/// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
4023/// success.
4024dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
4025
4026/// Sets JIT dump output path. Only applicable to Linux and is only
4027/// used when profiling flags have DNNL_JIT_PROFILE_LINUX_JITDUMP bit set.
4028///
4029/// After the first JIT kernel is generated, the jitdump output will be placed
4030/// into temporary directory created using the mkdtemp template
4031/// 'dir/.debug/jit/dnnl.XXXXXX'.
4032///
4033/// @sa @ref dev_guide_profilers
4034///
4035/// @note
4036/// This setting overrides JITDUMPDIR environment variable. If
4037/// JITDUMPDIR is not set, and this function is never called, the path
4038/// defaults to HOME. Passing NULL reverts the value to default.
4039///
4040/// @note
4041/// The directory is accessed only when the first JIT kernel is being
4042/// created. JIT profiling will be disabled in case of any errors
4043/// accessing or creating this directory.
4044///
4045/// @param dir JIT dump output path.
4046/// @returns #dnnl_success/#dnnl::status::success if the
4047/// output directory was set correctly and an error status otherwise.
4048/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
4049dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);
4050
4051/// Sets the maximal ISA the library can dispatch to on the CPU. See
4052/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
4053/// the C and C++ API functions respectively.
4054///
4055/// This function has effect only once, and returns an error on subsequent
4056/// calls. It should also be invoked before any other oneDNN API call, otherwise
4057/// it may return an error.
4058///
4059/// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
4060/// environment variable can be set to the desired maximal ISA name in upper
4061/// case and with dnnl_cpu_isa prefix removed. For example:
4062/// `DNNL_MAX_CPU_ISA=AVX2`.
4063///
4064/// @note
4065/// The ISAs are only partially ordered:
4066/// - SSE41 < AVX < AVX2,
4067/// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
4068/// < AVX512_CORE_FP16 < AVX512_CORE_AMX,
4069/// - AVX2 < AVX2_VNNI.
4070///
4071/// @sa @ref dev_guide_cpu_dispatcher_control for more details
4072///
4073/// @param isa Maximal ISA the library should dispatch to. Pass
4074/// #dnnl_cpu_isa_all/#dnnl::cpu_isa::all to remove ISA restrictions
4075/// (except for ISAs with initial support in the library).
4076/// @returns #dnnl_success/#dnnl::status::success on success and a
4077/// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
4078/// parameter is invalid or the ISA cannot be changed at this time.
4079/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
4080/// was disabled at build time (see @ref dev_guide_build_options for more
4081/// details).
4082dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
4083
4084/// Gets the maximal ISA the library can dispatch to on the CPU. See
4085/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
4086/// the C and C++ API functions respectively.
4087///
4088/// @sa @ref dev_guide_cpu_dispatcher_control for more details
4089///
4090/// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
4091/// dispatch to.
4092dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
4093
4094/// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
4095/// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
4096/// API functions respectively.
4097///
4098/// This function has effect only once, and returns an error on subsequent
4099/// calls. It should also be invoked before any other oneDNN API call, otherwise
4100/// it may return an error.
4101///
4102/// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
4103/// @sa @ref dev_guide_cpu_isa_hints for more details
4104///
4105/// @param isa_hints CPU ISA hints to be passed over to the implementation.
4106/// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
4107/// default features i.e. no hints.
4108/// @returns #dnnl_success/#dnnl::status::success on success and a
4109/// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
4110/// be specified at the current time.
4111/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
4112/// was disabled at build time (see @ref dev_guide_build_options for more
4113/// details).
4114dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
4115
4116/// Gets the ISA specific hints that library can follow. See
4117/// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
4118/// returned by the C and C++ API functions respectively.
4119///
4120/// @sa @ref dev_guide_cpu_isa_hints for more details
4121///
4122/// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the
4123/// library can follow.
4124dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
4125
4126/// @} dnnl_api_service
4127
4128/// @addtogroup dnnl_api_blas
4129/// @{
4130
4131/// Performs single-precision matrix-matrix multiply.
4132///
4133/// The operation is defined as:
4134///
4135/// `C := alpha * op( A ) * op( B ) + beta * C`
4136///
4137/// where
4138/// - `op( X ) = X` or `op( X ) = X**T`,
4139/// - `alpha` and `beta` are scalars, and
4140/// - `A`, `B`, and `C` are matrices:
4141/// - `op( A )` is an `MxK` matrix,
4142///  - `op( B )` is a `KxN` matrix,
4143/// - `C` is an `MxN` matrix.
4144///
4145/// The matrices are assumed to be stored in row-major order (the elements in
4146/// each of the matrix rows are contiguous in memory).
4147///
4148/// @note
4149/// This API does not support XERBLA. Instead, unlike the standard BLAS
4150/// functions, this one returns a dnnl_status_t value to allow error
4151/// handling.
4152///
4153/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
4154/// transposed, and 'T' or 't' means that A is transposed.
4155/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
4156/// transposed, and 'T' or 't' means that B is transposed.
4157/// @param M The M dimension.
4158/// @param N The N dimension.
4159/// @param K The K dimension.
4160/// @param alpha The alpha parameter that is used to scale the product of
4161/// matrices A and B.
4162/// @param A A pointer to the A matrix data.
4163/// @param lda The leading dimension for the matrix A.
4164/// @param B A pointer to the B matrix data.
4165/// @param ldb The leading dimension for the matrix B.
4166/// @param beta The beta parameter that is used to scale the matrix C.
4167/// @param C A pointer to the C matrix data.
4168/// @param ldc The leading dimension for the matrix C.
4169/// @returns #dnnl_success/#dnnl::status::success on success and a status
4170/// describing the error otherwise.
4171dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
4172 dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
4173 const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);
4174
4175/// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
4176/// signed matrix B, and 32-bit signed resulting matrix C.
4177///
4178/// The operation is defined as:
4179///
4180/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
4181///
4182/// where
4183/// - `op( X ) = X` or `op( X ) = X**T`,
4184/// - `alpha` and `beta` are scalars, and
4185/// - `A`, `B`, and `C` are matrices:
4186/// - `op( A )` is an `MxK` matrix,
4187///  - `op( B )` is a `KxN` matrix,
4188///  - `C` is an `MxN` matrix,
4189/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
4190/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
4191/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
4192/// - if `offsetc = F`: the `len` must be at least `1`,
4193/// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
4194/// - if `offsetc = R`: the `len` must be at least `max(1, n)`.
4195///
4196/// The matrices are assumed to be stored in row-major order (the elements in
4197/// each of the matrix rows are contiguous in memory).
4198///
4199/// @note
4200/// This API does not support XERBLA. Instead, unlike the standard BLAS
4201/// functions, this one returns a dnnl_status_t value to allow error
4202/// handling.
4203///
4204/// @warning
4205/// On some architectures saturation may happen during intermediate
4206/// computations, which would lead to unexpected results. For more
4207/// details, refer to @ref dev_guide_int8_computations.
4208///
4209/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
4210/// transposed, and 'T' or 't' means that A is transposed.
4211/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
4212/// transposed, and 'T' or 't' means that B is transposed.
4213/// @param offsetc Flag specifying how offsets should be applied to matrix C:
4214/// - 'F' means that the same offset will be applied to each element of
4215/// the matrix C,
4216/// - 'C' means that individual offset will be applied to each element
4217/// within each column,
4218/// - 'R' means that individual offset will be applied to each element
4219/// within each row.
4220/// @param M The M dimension.
4221/// @param N The N dimension.
4222/// @param K The K dimension.
4223/// @param alpha The alpha parameter that is used to scale the product of
4224/// matrices A and B.
4225/// @param A A pointer to the A matrix data.
4226/// @param lda The leading dimension for the matrix A.
4227/// @param ao The offset value for the matrix A.
4228/// @param B A pointer to the B matrix data.
4229/// @param ldb The leading dimension for the matrix B.
4230/// @param bo The offset value for the matrix B.
4231/// @param beta The beta parameter that is used to scale the matrix C.
4232/// @param C A pointer to the C matrix data.
4233/// @param ldc The leading dimension for the matrix C.
4234/// @param co An array of offset values for the matrix C. The number of
4235/// elements in the array depends on the value of @p offsetc.
4236/// @returns #dnnl_success/#dnnl::status::success on success and a status
4237/// describing the error otherwise.
4238dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
4239 dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
4240 dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
4241 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
4242
4243/// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
4244/// signed matrix B, and 32-bit signed resulting matrix C.
4245///
4246/// The operation is defined as:
4247///
4248/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
4249///
4250/// where
4251/// - `op( X ) = X` or `op( X ) = X**T`,
4252/// - `alpha` and `beta` are scalars, and
4253/// - `A`, `B`, and `C` are matrices:
4254/// - `op( A )` is an `MxK` matrix,
4255///  - `op( B )` is a `KxN` matrix,
4256///  - `C` is an `MxN` matrix,
4257/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
4258/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
4259/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
4260/// - if `offsetc = F`: the `len` must be at least `1`,
4261/// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
4262/// - if `offsetc = R`: the `len` must be at least `max(1, n)`.
4263///
4264/// The matrices are assumed to be stored in row-major order (the elements in
4265/// each of the matrix rows are contiguous in memory).
4266///
4267/// @note
4268/// This API does not support XERBLA. Instead, unlike the standard BLAS
4269/// functions, this one returns a dnnl_status_t value to allow error
4270/// handling.
4271///
4272/// @warning
4273/// On some architectures saturation may happen during intermediate
4274/// computations, which would lead to unexpected results. For more
4275/// details, refer to @ref dev_guide_int8_computations.
4276///
4277/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
4278/// transposed, and 'T' or 't' means that A is transposed.
4279/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
4280/// transposed, and 'T' or 't' means that B is transposed.
4281/// @param offsetc Flag specifying how offsets should be applied to matrix C:
4282/// - 'F' means that the same offset will be applied to each element of
4283/// the matrix C,
4284/// - 'C' means that individual offset will be applied to each element
4285/// within each column,
4286/// - 'R' means that individual offset will be applied to each element
4287/// within each row.
4288/// @param M The M dimension.
4289/// @param N The N dimension.
4290/// @param K The K dimension.
4291/// @param alpha The alpha parameter that is used to scale the product of
4292/// matrices A and B.
4293/// @param A A pointer to the A matrix data.
4294/// @param lda The leading dimension for the matrix A.
4295/// @param ao The offset value for the matrix A.
4296/// @param B A pointer to the B matrix data.
4297/// @param ldb The leading dimension for the matrix B.
4298/// @param bo The offset value for the matrix B.
4299/// @param beta The beta parameter that is used to scale the matrix C.
4300/// @param C A pointer to the C matrix data.
4301/// @param ldc The leading dimension for the matrix C.
4302/// @param co An array of offset values for the matrix C. The number of
4303/// elements in the array depends on the value of @p offsetc.
4304/// @returns #dnnl_success/#dnnl::status::success on success and a status
4305/// describing the error otherwise.
4306dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
4307 dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
4308 dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
4309 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
4310
4311/// @} dnnl_api_blas
4312
4313/// @} dnnl_api
4314
4315#ifdef __cplusplus
4316}
4317#endif
4318
4319#endif /* ONEAPI_DNNL_DNNL_H */
4320