1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @file
18/// C API
19
20#ifndef ONEAPI_DNNL_DNNL_H
21#define ONEAPI_DNNL_DNNL_H
22
23#include "oneapi/dnnl/dnnl_common.h"
24#include "oneapi/dnnl/dnnl_config.h"
25#include "oneapi/dnnl/dnnl_types.h"
26#include "oneapi/dnnl/dnnl_version.h"
27
28#ifdef __cplusplus
29extern "C" {
30#endif
31
32/// @addtogroup dnnl_api
33/// @{
34
35/// @addtogroup dnnl_api_primitives
36/// @{
37
38/// @addtogroup dnnl_api_primitives_common
39/// @{
40
41/// Changes the primitive descriptor to point to the next available
42/// implementation.
43///
44/// @param primitive_desc A primitive descriptor to change.
45/// @returns #dnnl_success on success and a status describing the error
46/// otherwise.
47/// @returns #dnnl_last_impl_reached if no more implementations available,
48/// in which case the primitive descriptor itself is kept unchanged.
49dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl(
50 dnnl_primitive_desc_t primitive_desc);
51
52/// Clones a primitive descriptor. The resulting primitive descriptor must be
53/// destroyed separately.
54///
55/// @param primitive_desc Output primitive descriptor.
56/// @param existing_primitive_desc Primitive descriptor to clone.
57/// @returns #dnnl_success on success and a status describing the error
58/// otherwise.
59dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
60 dnnl_primitive_desc_t *primitive_desc,
61 const_dnnl_primitive_desc_t existing_primitive_desc);
62
63/// Returns a constant reference to the attributes of a primitive descriptor.
64///
65/// @warning
66/// It is an error to destroy the resulting @p attr.
67///
68/// @warning
69/// The lifetime of an @p attr is the same as that of a @p
70/// primitive_desc, so it is an error to use the @p attr once the @p
71/// primitive_desc has been destroyed.
72///
73/// @param primitive_desc Primitive descriptor.
74/// @param attr Output primitive attributes.
75/// @returns #dnnl_success on success and a status describing the error
76/// otherwise.
77dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
78 const_dnnl_primitive_desc_t primitive_desc,
79 const_dnnl_primitive_attr_t *attr);
80
81/// Destroys a primitive descriptor.
82///
83/// @param primitive_desc Primitive descriptor to destroy.
84/// @returns #dnnl_success on success and a status describing the error
85/// otherwise.
86dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
87 dnnl_primitive_desc_t primitive_desc);
88
89/// Queries a primitive descriptor for various pieces of information.
90///
91/// The most common use case is to query a primitive descriptor, created with
92/// source, weights, and destination memory descriptors with format tags set
93/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
94/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
95/// #dnnl_query_dst_md respectively) so that it is possible to create memory
96/// objects and reorder primitives if necessary.
97///
98/// Another typical use case is to query a primitive descriptor for workspace
99/// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
100/// query returns #dnnl_not_required status, then workspace memory is not
101/// required.
102///
103/// @note
104/// When querying for a memory descriptor for a scratchpad, a workspace,
105/// or an optional parameter, the query will return a pointer to a zero
106/// memory descriptor if the parameter is not needed.
107///
108/// A few other use cases:
109/// - query a primitive descriptor for the implementation information string
110/// (#dnnl_query_impl_info_str)
111/// - query a primitive descriptor for the number of inputs and outputs
112/// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
113/// respectively)
114///
115/// @sa dnnl_query_t for more options
116///
117/// @param primitive_desc Primitive descriptor.
118/// @param what Parameter to query.
119/// @param index Index of the parameter to query for.
120/// @param result Output result. The type depends on the query. For example,
121/// it must be a @c dnnl_memory_desc_t* if querying for a memory
122/// descriptor.
123/// @returns #dnnl_success on success and a status describing the error
124/// otherwise.
125dnnl_status_t DNNL_API dnnl_primitive_desc_query(
126 const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
127 int index, void *result);
128
129/// Queries primitive descriptor for a memory descriptor.
130///
131/// @note
132/// This function is a convenience version of
133/// #dnnl_primitive_desc_query().
134///
135/// @param primitive_desc Primitive descriptor.
136/// @param what Kind of memory descriptor parameter to query for.
137/// @param index Index of the parameter to query.
138/// @returns A pointer to the requested memory descriptor.
139/// @returns A pointer to a zero memory descriptor if the parameter is not
140/// needed.
141/// @returns NULL in case of any error.
142///
143const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
144 const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
145 int index);
146
147/// Queries primitive descriptor for a signed 32-bit int.
148///
149/// @note
150/// This function is a convenience version of
151/// #dnnl_primitive_desc_query().
152///
153/// @param primitive_desc Primitive descriptor.
154/// @param what Kind of the value to query for.
155/// @param index Index of the parameter to query.
156/// @returns The requested value.
157/// @returns 0 in case of any error (in particular if the queried entity is
158/// not of type int32_t). Note that 0 may also be the actual returned
159/// value.
160int DNNL_API dnnl_primitive_desc_query_s32(
161 const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
162 int index);
163
164/// Creates a primitive.
165///
166/// @param primitive Output primitive.
167/// @param primitive_desc Primitive descriptor used to create the primitive.
168/// @returns #dnnl_success on success and a status describing the error
169/// otherwise.
170dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
171 const_dnnl_primitive_desc_t primitive_desc);
172
173/// Creates a primitive from a cache blob.
174///
175/// @param primitive Output primitive.
176/// @param primitive_desc Primitive descriptor used to create the primitive.
177/// @param size Size of the cache blob in bytes.
178/// @param cache_blob Cache blob of size @p size.
179/// @returns #dnnl_success on success and a status describing the error
180/// otherwise.
181dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
182 dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
183 size_t size, const uint8_t *cache_blob);
184
185/// Executes a primitive.
186///
187/// @param primitive Primitive to execute.
188/// @param stream Stream to use.
189/// @param nargs Number of arguments.
190/// @param args Array of arguments. Each argument is an
191/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
192/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
193/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
194/// descriptor as that returned by
195/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
196/// @returns #dnnl_success on success and a status describing the error
197/// otherwise.
198///
199/// @note If any argument in @p args is padded (padded_dims >
200/// dims), the primitive execution will assume properly zero-padded
201/// input arguments, and produce zero-padded output arguments.
202dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
203 dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
204
205/// Retrieves a constant reference to the primitive descriptor of a given
206/// primitive.
207///
208/// @warning
209/// It is an error to destroy the returned object. It is owned by the
210/// primitive. The @c const qualifier of the returned object prevents
211/// such attempts.
212///
213/// @param primitive Primitive to query for the primitive descriptor.
214/// @param primitive_desc Output primitive descriptor.
215/// @returns #dnnl_success on success and a status describing the error
216/// otherwise.
217dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
218 const_dnnl_primitive_t primitive,
219 const_dnnl_primitive_desc_t *primitive_desc);
220
221/// Retrieves a cache blob associated with the given primitive.
222///
223/// @param primitive Primitive to query for the cache blob.
224/// @param size Size of the cache blob in bytes.
225/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
226/// NULL then the size of the cache blob is returned in @p size.
227/// @returns #dnnl_success on success and a status describing the error
228/// otherwise.
229///
230/// @note The cache blob can be empty. It's the user's responsibility to check
231/// whether it's empty prior to passing it to
232/// #dnnl_primitive_create_from_cache_blob().
233dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
234 const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);
235
236/// Destroys a primitive.
237///
238/// @param primitive The primitive to destroy.
239/// @returns #dnnl_success on success and a status describing the error
240/// otherwise.
241dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
242
243/// @} dnnl_api_primitives_common
244
245/// @addtogroup dnnl_api_attributes
246/// @{
247
248/// Creates an empty (default) primitive attributes with all the parameters
249/// set to their default values.
250///
251/// Empty attributes are implied whenever the respective argument is NULL.
252///
253/// @param attr Output primitive attributes.
254/// @returns #dnnl_success on success and a status describing the error
255/// otherwise.
256dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);
257
258/// Clones primitive attributes.
259///
260/// @param attr Output primitive attributes.
261/// @param existing_attr Primitive attributes to clone.
262/// @returns #dnnl_success on success and a status describing the error
263/// otherwise.
264dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
265 dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);
266
267/// Destroys primitive attributes.
268///
269/// @param attr Primitive attributes to destroy.
270/// @returns #dnnl_success on success and a status describing the error
271/// otherwise.
272dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
273
274/// Returns the floating-point math mode primitive attribute.
275///
276/// @param attr Primitive attributes.
277/// @param mode Output FP math mode.
278/// @returns #dnnl_success on success and a status describing the error
279/// otherwise.
280dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
281 const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);
282
283/// Sets the floating-point math mode primitive attributes.
284///
285/// @param attr Primitive attributes.
286/// @param mode FP math mode. The possible values are:
287/// #dnnl_fpmath_mode_strict (default),
288/// #dnnl_fpmath_mode_bf16,
289/// #dnnl_fpmath_mode_f16,
290/// #dnnl_fpmath_mode_tf32,
291/// #dnnl_fpmath_mode_any.
292/// @returns #dnnl_success on success and a status describing the error
293/// otherwise.
294dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
295 dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);
296
297/// Returns the primitive attributes scratchpad mode.
298///
299/// @param attr Primitive attributes.
300/// @param mode Output scratchpad mode.
301/// @returns #dnnl_success on success and a status describing the error
302/// otherwise.
303dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
304 const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);
305
306/// Sets primitive attributes scratchpad mode.
307///
308/// @param attr Primitive attributes.
309/// @param mode Scratchpad mode. The possible values are:
310/// #dnnl_scratchpad_mode_library (default) and
311/// #dnnl_scratchpad_mode_user.
312/// @returns #dnnl_success on success and a status describing the error
313/// otherwise.
314dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
315 dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
316
317/// Sets primitive attributes scaling factors for primitive operations for a
318/// given memory argument. The scaling factors must be passed at execution time
319/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
320///
321/// @sa dnnl_primitive_attr_set_scales_mask
322///
323///
324/// @param attr Primitive attributes.
325/// @param arg Parameter argument index as passed to the
326/// dnnl_primitive_execute() call.
327/// @param mask Scaling factors correspondence mask that defines the
328/// correspondence between the tensor dimensions and the @p scales array.
329/// The set i-th bit indicates that a dedicated scaling factor is used for
330/// each index along that dimension. Set the mask to 0 to use a common
331/// scaling factor for the whole output tensor.
332/// @returns #dnnl_success on success and a status describing the error
333/// otherwise.
334dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
335 dnnl_primitive_attr_t attr, int arg, int mask);
336
337/// Sets primitive attributes zero points for primitive operations for a given
338/// memory argument. The zero points must be passed at execution time
339/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
340///
341/// @sa dnnl_primitive_attr_set_zero_points_mask
342///
343///
344/// @param attr Primitive attributes.
345/// @param arg Parameter argument index as passed to the
346/// dnnl_primitive_execute() call.
347/// @param mask Zero point correspondence mask that defines the
348/// correspondence between the tensor dimensions and the @p
349/// zero_points array. The set i-th bit indicates that a dedicated
350/// zero point is used for each index along that dimension. Set the
351/// mask to 0 to use a common zero point for the whole output tensor.
352/// @returns #dnnl_success on success and a status describing the error
353/// otherwise.
354dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
355 dnnl_primitive_attr_t attr, int arg, int mask);
356
357/// Returns primitive attributes post-ops.
358///
359/// @warning
360/// The output @p post_ops points to the internal @p attr field, so it is
361/// an error to modify or destroy them. The lifetime of @p post_ops is
362/// the same as that of the @p attr it belongs to, so it is an error to
363/// use @p post_ops after @p attr has been destroyed.
364///
365/// @param attr Primitive attributes.
366/// @param post_ops Output post-ops.
367/// @returns #dnnl_success on success and a status describing the error
368/// otherwise.
369dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
370 const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);
371
372/// Sets primitive attributes post-ops.
373///
374/// @note
375/// There is no way to check whether the post-ops would be supported by
376/// the target primitive. Any error will be reported by the
377/// dnnl_<primitive name>_[propagation kind]_primitive_desc_create() function call.
378///
379/// @param attr Primitive attributes.
380/// @param post_ops Post-ops to set.
381/// @returns #dnnl_success on success and a status describing the error
382/// otherwise.
383dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
384 dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);
385
386/// Creates empty post-ops sequence.
387///
388/// @param post_ops Output post-ops.
389/// @returns #dnnl_success on success and a status describing the error
390/// otherwise.
391dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);
392
393/// Clones post-ops primitive attribute.
394///
395/// @param post_ops Output post-ops primitive attribute.
396/// @param existing_post_ops Post-ops primitive attribute to clone.
397/// @returns #dnnl_success on success and a status describing the error
398/// otherwise.
399dnnl_status_t DNNL_API dnnl_post_ops_clone(
400 dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);
401
402/// Destroys post-ops.
403///
404/// @param post_ops Post-ops to destroy.
405/// @returns #dnnl_success on success and a status describing the error
406/// otherwise.
407dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
408
409/// Returns the length of post-ops.
410///
411/// @param post_ops Post-ops.
412/// @returns The number of post-ops entries.
413int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);
414
415/// Returns the kind of a post-op entry.
416///
417/// @param post_ops Post-ops.
418/// @param index Post-op entry index.
419/// @returns The kind of the post-op with the specified index.
420/// @returns #dnnl_undefined_primitive if there is no post-op at the specified
421/// index.
422dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
423 const_dnnl_post_ops_t post_ops, int index);
424
425/// Appends an accumulation (sum) to post-ops. Prior to accumulating the
426/// result, a zero point is subtracted from the previous value and is
427/// multiplied by the scale.
428///
429/// The kind of this post-op is #dnnl_sum.
430///
431/// This feature may improve performance for cases like dequantizing the
432/// asymmetrically quantized sum's src1 tensor to f32 domain before performing
433/// the sum operation by subtracting the @p zero_point before the scaling.
434///
435/// In the simplest case where accumulation is the only post-op, the
436/// computations will be:
437///
438/// dst[:] <- scale * (dst[:] - zero_point) + op(...)
439/// // instead of dst[:] <- op(...)
440///
441/// If @p data_type is specified, original dst tensor will be reinterpreted
442/// as a tensor with provided data type. Since it is reinterpretation,
443/// data_type and dst data type should have the same size.
444/// As a result, computations will be:
445///
446/// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
447/// // instead of dst[:] <- op(...)
448/// @note
449/// This post-op executes in-place and does not change the
450/// destination layout.
451///
452/// @param post_ops Post-ops.
453/// @param scale Accumulation scaling factor.
454/// @param zero_point Single scalar int32_t value of zero point.
455/// @param data_type Accumulation data_type.
456/// @returns #dnnl_success on success and a status describing the error
457/// otherwise.
458dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops,
459 float scale, int32_t zero_point, dnnl_data_type_t data_type);
460
461/// Returns the parameters of an accumulation (sum) post-op with
462/// zero point and data type parameter.
463///
464/// @param post_ops Post-ops.
465/// @param index Index of the sum post-op.
466/// @param scale Output accumulation scaling factor.
467/// @param zero_point Zero point.
468/// @param data_type Data type for accumulation.
469/// @returns #dnnl_success on success and a status describing the error
470/// otherwise.
471dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
472 const_dnnl_post_ops_t post_ops, int index, float *scale,
473 int32_t *zero_point, dnnl_data_type_t *data_type);
474
475/// Appends an elementwise post-op.
476///
477/// The kind of this post operation is #dnnl_eltwise.
478///
479/// In the simplest case when the elementwise is the only post operation, the
480/// computations would be:
481///
482/// dst[:] <- eltwise_op (op(...)) // instead of dst[:] <- op(...)
483///
484/// where eltwise_op is configured with the given parameters.
485///
486/// @param post_ops Post-ops.
487/// @param alg_kind Elementwise algorithm for the post-op.
488/// @param alpha Alpha parameter for the elementwise algorithm.
489/// @param beta Beta parameter for the elementwise algorithm.
490/// @returns #dnnl_success on success and a status describing the error
491/// otherwise.
492dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
493 dnnl_alg_kind_t alg_kind, float alpha, float beta);
494
495/// Returns the parameters of an elementwise post-op.
496///
497/// @param post_ops Post-ops.
498/// @param index Index of the elementwise post-op.
499/// @param alg_kind Output elementwise algorithm kind.
500/// @param alpha Output alpha parameter for the elementwise algorithm.
501/// @param beta Output beta parameter for the elementwise algorithm.
502/// @returns #dnnl_success on success and a status describing the error
503/// otherwise.
504/// @returns #dnnl_invalid_arguments if @p index does not refer to an
505/// elementwise post-op.
506dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
507 const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
508 float *alpha, float *beta);
509
510/// Appends a depthwise post-op convolution.
511///
512/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
513/// weights spatial dimensions equal to 1 i.e., kh=kw=1).
514///
515/// The kind of this post-op is #dnnl_convolution.
516///
517/// The number of outputs for primitive with fusion is one. The output spatial
518/// size can be derived as below:
519///
520/// output_height = ceil(output_height_1x1_convolution, stride)
521/// output_width = ceil(output_width_1x1_convolution, stride)
522///
523/// See @ref dev_guide_attributes_post_ops_depthwise and
524/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
525///
526/// @param post_ops Post-ops.
527/// @param weights_data_type Weights data type of depthwise post-op
528/// @param bias_data_type Bias data type of depthwise post-op
529/// @param dst_data_type Output data type of depthwise post-op
530/// @param kernel_size Size of kernel of depthwise post-op
531/// @param stride_size Size of stride of depthwise post-op
532/// @param padding_l_size Size of left and top paddings of depthwise post-op
533/// @returns #dnnl_success on success and a status describing the error
534/// otherwise
535dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
536 dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
537 dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
538 dnnl_dim_t stride_size, dnnl_dim_t padding_l_size);
539
540/// Returns the parameters of a depthwise post-op.
541///
542/// @param post_ops Post-ops.
543/// @param index Index of the depthwise post-op.
544/// @param weights_data_type Weights data type of depthwise post-op
545/// @param bias_data_type Bias data type of depthwise post-op
546/// @param dst_data_type Output data type of depthwise post-op
547/// @param kernel_size Size of kernel of depthwise post-op
548/// @param stride_size Size of stride of depthwise post-op
549/// @param padding_l_size Size of left and top paddings of depthwise post-op
550/// @returns #dnnl_success on success and a status describing the error
551/// otherwise
552dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
553 const_dnnl_post_ops_t post_ops, int index,
554 dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
555 dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
556 dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size);
557
558/// Appends a binary post-op.
559///
560/// The kind of this post operation is #dnnl_binary.
561///
562/// In the simplest case when the binary is the only post operation, the
563/// computations would be:
564///
565/// dst[:] <- binary_op (dst[:], another_input[:])
566///
567/// where binary_op is configured with the given parameters. binary_op supports
568/// broadcast semantics for a second operand.
569///
570/// @param post_ops Post-ops.
571/// @param alg_kind Binary algorithm for the post-op.
572/// @param src1_desc Memory descriptor of a second operand.
573/// @returns #dnnl_success on success and a status describing the error
574/// otherwise.
575dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
576 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);
577
578/// Returns the parameters of a binary post-op.
579///
580/// @param post_ops Post-ops.
581/// @param index Index of the binary post-op.
582/// @param alg_kind Output binary algorithm kind.
583/// @param src1_desc Output memory descriptor of a second operand.
584/// @returns #dnnl_success on success and a status describing the error
585/// otherwise.
586/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
587/// post-op.
588dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
589 const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
590 const_dnnl_memory_desc_t *src1_desc);
591
592/// Appends a prelu forward post-op.
593///
594/// The kind of this post-op is #dnnl_prelu.
595///
596/// The post-op can be defined as:
597///
598/// dst[:] <- prelu(dst[:], weights[:])
599/// prelu:
600/// dst[:] <- dst[:] if dst[:] > 0
601/// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
602///
603///
604/// @note
605/// The order of dimensions does not depend on how elements are laid
606/// out in memory. For example:
607/// - for a 2D CNN activations tensor the order is always (n, c)
608/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
609/// - for a 5D CNN weights tensor the order is always
610/// (g, oc, ic, kh, kw)
611///
612/// Prelu weights tensor is passed in runtime execution phase. Prelu
613/// weights tensor data type is implicitly assumed as f32 using plain
614/// layout (a, ab, acb, acdb, acdeb)
615///
616/// @param post_ops Post-ops.
617/// @param mask Defines the correspondence between the output tensor
618/// dimensions and the prelu weights tensor. The set i-th bit indicates
619/// that a dedicated weights value is used for each index along that
620/// dimension. Set the mask to 0 to use a common weights value
621/// for the whole output tensor.
622/// @returns #dnnl_success on success and a status describing the error
623/// otherwise.
624dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
625 dnnl_post_ops_t post_ops, int mask);
626
627/// Returns the parameters of a prelu post-op.
628///
629/// @param post_ops Post-ops.
630/// @param index Index of the prelu post-op.
631/// @param mask Mask of the prelu post-op.
632/// @returns #dnnl_success on success and a status describing the error
633/// otherwise.
634dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
635 const_dnnl_post_ops_t post_ops, int index, int *mask);
636
637/// @} dnnl_api_attributes
638
639/// @} dnnl_api_primitives
640
641/// @addtogroup dnnl_api_memory
642/// @{
643
644/// Destroys a memory descriptor.
645///
646/// @param memory_desc Memory descriptor to destroy.
647/// @returns #dnnl_success on success and a status describing the error
648/// otherwise.
649dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc);
650
651/// Clones a memory descriptor. The resulting memory descriptor must be
652/// destroyed separately.
653///
654/// @param memory_desc Output memory descriptor.
655/// @param existing_memory_desc Memory descriptor to clone.
656/// @returns #dnnl_success on success and a status describing the error
657/// otherwise.
658dnnl_status_t DNNL_API dnnl_memory_desc_clone(dnnl_memory_desc_t *memory_desc,
659 const_dnnl_memory_desc_t existing_memory_desc);
660
661/// Creates a memory descriptor using dimensions and strides.
662///
663/// @note
664/// As always, the logical order of dimensions corresponds to the `abc...`
665/// format tag, and the physical meaning of the dimensions depends on both
666/// the primitive that consumes the memory and the context of that
667/// consumption.
668///
669/// @param memory_desc Output memory descriptor.
670/// @param ndims Number of dimensions
671/// @param dims Array of dimensions.
672/// @param data_type Elements data type.
673/// @param strides Strides in each dimension.
674/// @returns #dnnl_success on success and a status describing the error
675/// otherwise.
676dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides(
677 dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
678 dnnl_data_type_t data_type, const dnnl_dims_t strides);
679
680/// Creates a memory descriptor using dimensions and memory format tag.
681///
682/// @note
683/// As always, the logical order of dimensions corresponds to the `abc...`
684/// format tag, and the physical meaning of the dimensions depends on both
685/// the primitive that consumes the memory and the context of that
686/// consumption.
687///
688/// @param memory_desc Output memory descriptor.
689/// @param ndims Number of dimensions
690/// @param dims Array of dimensions.
691/// @param data_type Elements data type.
692/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
693/// allow a primitive to choose the final memory format. In this case the
694/// format_kind field of the memory descriptor would be set to
695/// #dnnl_format_kind_any.
696/// @returns #dnnl_success on success and a status describing the error
697/// otherwise.
698dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag(
699 dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
700 dnnl_data_type_t data_type, dnnl_format_tag_t tag);
701
702/// Creates a memory descriptor for a region inside an area
703/// described by an existing memory descriptor.
704///
705/// @warning
706/// Some combinations of physical memory layout and/or offsets or dims may
707/// result in a failure to create a submemory.
708///
709/// @param memory_desc Output memory descriptor.
710/// @param parent_memory_desc An existing memory descriptor.
711/// @param dims Sizes of the region.
712/// @param offsets Offsets to the region from the encompassing
713/// memory object in each dimension.
714/// @returns #dnnl_success on success and a status describing the error
715/// otherwise.
716dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
717 dnnl_memory_desc_t *memory_desc,
718 const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
719 const dnnl_dims_t offsets);
720
721/// Creates a memory descriptor by reshaping an existing one. The new
722/// memory descriptor inherits the data type. This operation is valid only for
723/// memory descriptors that have format_kind #dnnl_blocked or
724/// #dnnl_format_kind_any.
725///
726/// The resulting memory descriptor must be destroyed separately.
727///
728/// The operation ensures the transformation of the physical memory format
729/// corresponds to the transformation of the logical dimensions. If such
730/// transformation is impossible, the function returns #dnnl_invalid_arguments.
731///
732/// The reshape operation can be described as a combination of the following
733/// basic operations:
734/// 1. Add a dimension of size `1`. This is always possible.
735/// 2. Remove a dimension of size `1`. This is possible only if the dimension
736/// has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
737/// 3. Split a dimension into multiple ones. This is possible only if the size
738/// of the dimension is exactly equal to the product of the split ones and
739/// the dimension does not have padding (i.e.
740/// `padded_dims[dim] == dims[dim]`).
741/// 4. Join multiple consecutive dimensions into a single one. As in the
742/// cases above, this requires that the dimensions do not have padding and
743/// that the memory format is such that in physical memory these dimensions
744/// are dense and have the same order as their logical counterparts. This
745/// also assumes that these dimensions are not blocked.
746/// - Here, dense means:
747/// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
748/// - And same order means:
749/// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
750///
751/// @warning
752/// Some combinations of physical memory layout and/or offsets or
753/// dimensions may result in a failure to make a reshape.
754///
755/// @param out_memory_desc Output memory descriptor.
756/// @param in_memory_desc An existing memory descriptor. Must have format_kind
757/// set to #dnnl_blocked or #dnnl_format_kind_any.
758/// @param ndims Number of dimensions for the output memory descriptor.
759/// @param dims Dimensions for the output memory descriptor.
760/// @returns #dnnl_success on success and a status describing the error
761/// otherwise.
762dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
763 dnnl_memory_desc_t *out_memory_desc,
764 const_dnnl_memory_desc_t in_memory_desc, int ndims,
765 const dnnl_dims_t dims);
766
767/// Creates a memory descriptor by permuting axes in an existing one.
768///
769/// The physical memory layout representation is adjusted accordingly to
770/// maintain the consistency between the logical and physical parts of the
771/// memory descriptor.
772///
773/// The resulting memory descriptor must be destroyed separately.
774///
775/// The new memory descriptor inherits the data type. This operation is valid
776/// only for memory descriptors that have format_kind set to #dnnl_blocked or
777/// #dnnl_format_kind_any.
778///
779/// The logical axes are permuted as follows (the upper bound is exclusive):
780/// ```
781/// for (i: 0 .. in_memory_desc->ndims)
782/// out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
783/// ```
784///
785/// Example:
786/// @code
787/// dnnl_memory_desc_t in_md, out_md, expect_out_md;
788///
789/// const int permutation[] = {1, 0}; // swap the first and the second axes
790///
791/// dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
792/// dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
793///
794/// dnnl_memory_desc_create_with_tag(
795/// &in_md, 2, in_dims, data_type, in_tag);
796/// dnnl_memory_desc_create_with_tag(
797/// &expect_out_md, 2, out_dims, data_type, out_tag);
798///
799/// dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
800/// assert(dnnl_memory_desc_equal(out_md, expect_out_md));
801///
802/// dnnl_memory_desc_destroy(in_md);
803/// dnnl_memory_desc_destroy(out_md);
804/// dnnl_memory_desc_destroy(expect_out_md);
805/// @endcode
806///
807/// @param out_memory_desc Output memory descriptor.
808/// @param in_memory_desc An existing memory descriptor. Must have format_kind
809/// set to #dnnl_blocked or #dnnl_format_kind_any.
810/// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
811/// @returns #dnnl_success on success and a status describing the error
812/// otherwise.
813dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
814 dnnl_memory_desc_t *out_memory_desc,
815 const_dnnl_memory_desc_t in_memory_desc, const int *permutation);
816
817/// Queries a memory descriptor for various pieces of information.
818///
819/// The following information can be queried:
820/// - Number of dimensions (#dnnl_query_ndims_s32)
821/// - Dimensions (#dnnl_query_dims) in the following order:
822/// - CNN data tensors: mini-batch, channel, spatial
823/// (<code>{N, C, [[D,] H,] W}</code>)
824/// - CNN weight tensors: group (optional), output channel, input channel,
825/// spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
826/// - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
827/// or layers, directions, states, mini-batch, channels
828/// (<code>{L, D, S, N, C}</code>)
829/// - RNN weight tensor: layers, directions, input channel, gates, output
830/// channels (<code>{L, D, I, G, O}</code>)
831/// - Data type of the tensor elements (#dnnl_query_data_type)
832/// - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
833/// padding in each dimension
834/// - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
835/// the padding to actual data, the top-level tensor with offsets applied
836/// must lie within the padding area.
837/// - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
838/// origin to the current block, non-zero only in a description of a memory
839/// sub-block.
840/// - Format kind (#dnnl_query_format_kind) - memory format kind
841///
842/// @note
843/// The order of dimensions does not depend on the memory format, so
844/// whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
845/// the dims for 4D CNN data tensor would be <code>{N, C, H, W}</code>.
846///
847/// The following queries are applicable only to format kind #dnnl_blocked.
848/// - Strides (#dnnl_query_strides) between the outermost blocks or in case
849/// of plain (non-blocked) formats the strides between dimensions
850/// - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g. 3 in
851/// case of `OIhw_4i16o4i`
852/// - Size of the innermost blocks (#dnnl_query_inner_blks), e.g.
853/// `{4, 16, 4}` in case of `OIhw_4i16o4i`
854/// - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
855/// in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim
856///
857/// @param memory_desc Memory descriptor.
858/// @param what Parameter to query.
859/// @param result Output result. The type depends on the query. For example,
860/// it must be a @c dnnl_dims_t** if querying for strides.
861/// @returns #dnnl_success on success and a status describing the error
862/// otherwise.
863dnnl_status_t DNNL_API dnnl_memory_desc_query(
864 const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);
865
866/// Compares two memory descriptors.
867///
868/// Use this function to identify whether a reorder is required between the
869/// two memories.
870///
871/// @param lhs Left-hand side of the comparison.
872/// @param rhs Right-hand side of the comparison.
873/// @returns 1 if the descriptors are the same.
874/// @returns 0 if the descriptors are different.
875int DNNL_API dnnl_memory_desc_equal(
876 const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs);
877
878/// Returns the size of a memory descriptor.
879///
880/// @param memory_desc Memory descriptor.
881/// @returns The number of bytes required for memory described by the memory
882/// descriptor.
883size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);
884
885/// Returns the size of a data type.
886///
887/// @param data_type Data type.
888/// @returns The number of bytes occupied by the data type.
889size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
890
891/// Creates a memory object.
892///
893/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
894/// object will have the underlying buffer set. In this case, the buffer will
895/// be initialized as if dnnl_memory_set_data_handle() had been called.
896///
897/// @sa dnnl_memory_set_data_handle()
898///
899/// @param memory Output memory object.
900/// @param memory_desc Memory descriptor.
901/// @param engine Engine to use.
902/// @param handle Handle of the memory buffer to use as an underlying storage.
903/// - A pointer to the user-allocated buffer. In this case the library
904/// doesn't own the buffer.
905/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
906/// allocate the buffer for the memory object. In this case the library
907/// owns the buffer.
908/// - DNNL_MEMORY_NONE to create a memory object without an underlying buffer.
909/// @returns #dnnl_success on success and a status describing the error
910/// otherwise.
911dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
912 const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
913 void *handle);
914
915/// Returns the memory descriptor of a memory object.
916///
917/// @param memory Memory object.
918/// @param memory_desc Output memory descriptor (a copy).
919/// @returns #dnnl_success on success and a status describing the error
920/// otherwise.
921dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
922 const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc);
923
924/// Returns the engine of a memory object.
925///
926/// @param memory Memory object.
927/// @param engine Output engine on which the memory object is located.
928/// @returns #dnnl_success on success and a status describing the error
929/// otherwise.
930dnnl_status_t DNNL_API dnnl_memory_get_engine(
931 const_dnnl_memory_t memory, dnnl_engine_t *engine);
932
933/// Maps a memory object and returns a host-side pointer to a memory buffer
934/// with a copy of its contents.
935///
936/// Mapping enables explicit direct access to memory contents for the engines
937/// that do not support it implicitly.
938///
939/// Mapping is an exclusive operation - a memory object cannot be used in
940/// other operations until this memory object is unmapped.
941///
942/// @note
943/// Any primitives working with @p memory should be completed before
944/// the memory is mapped. Use dnnl_stream_wait() to synchronize the
945/// corresponding execution stream.
946///
947/// @note
948/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
949/// mainly provided for debug and testing purposes, and their performance
950/// may be suboptimal.
951///
952/// @param memory Memory object.
953/// @param mapped_ptr Output pointer to the mapped buffer.
954/// @returns #dnnl_success on success and a status describing the error
955/// otherwise.
956dnnl_status_t DNNL_API dnnl_memory_map_data(
957 const_dnnl_memory_t memory, void **mapped_ptr);
958
959/// Unmaps a memory object and writes back any changes made to the previously
960/// mapped memory buffer. The pointer to the mapped buffer must be obtained
961/// via a dnnl_memory_map_data() call.
962///
963/// @note
964/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
965/// mainly provided for debug and testing purposes, and their performance
966/// may be suboptimal.
967///
968/// @param memory Memory object.
969/// @param mapped_ptr Pointer to the mapped buffer that must have been
970/// obtained using the dnnl_memory_map_data() function.
971/// @returns #dnnl_success on success and a status describing the error
972/// otherwise.
973dnnl_status_t DNNL_API dnnl_memory_unmap_data(
974 const_dnnl_memory_t memory, void *mapped_ptr);
975
976/// Returns a memory object's data handle.
977///
978/// @param memory Memory object.
979/// @param handle Output data handle. For the CPU engine, the data handle is a
980/// pointer to the actual data. For OpenCL it is a `cl_mem`.
981/// @returns #dnnl_success on success and a status describing the error
982/// otherwise.
983dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
984 const_dnnl_memory_t memory, void **handle);
985
986/// Sets the underlying memory buffer of a memory object.
987///
988/// @param memory Memory object.
989/// @param handle Data handle. For the CPU engine or when USM is used, the
990/// memory buffer is a pointer to the actual data. For OpenCL it is a
991/// `cl_mem`.
992/// @returns #dnnl_success on success and a status describing the error
993/// otherwise.
994dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
995 dnnl_memory_t memory, void *handle);
996
997/// Destroys a memory object.
998///
999/// @param memory Memory object to destroy.
1000/// @returns #dnnl_success on success and a status describing the error
1001/// otherwise.
1002dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
1003
1004/// @} dnnl_api_memory
1005
1006/// @addtogroup dnnl_api_primitives
1007/// @{
1008
1009/// @addtogroup dnnl_api_reorder
1010/// @{
1011
1012/// Creates a primitive descriptor for a reorder primitive.
1013///
1014/// @param reorder_primitive_desc Output primitive descriptor.
1015/// @param src_desc Source memory descriptor.
1016/// @param src_engine Engine on which the source memory object will be
1017/// located.
1018/// @param dst_desc Destination memory descriptor.
1019/// @param dst_engine Engine on which the destination memory object
1020/// will be located.
1021/// @param attr Primitive attributes to use (can be NULL).
1022/// @returns #dnnl_success on success and a status describing the error
1023/// otherwise.
1024dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
1025 dnnl_primitive_desc_t *reorder_primitive_desc,
1026 const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine,
1027 const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine,
1028 const_dnnl_primitive_attr_t attr);
1029
1030/// @} dnnl_api_reorder
1031
1032/// @addtogroup dnnl_api_concat
1033/// @{
1034
1035/// Creates a primitive descriptor for an out-of-place concatenation
1036/// primitive.
1037///
1038/// @param concat_primitive_desc Output primitive descriptor.
1039/// @param engine Engine to use.
1040/// @param dst_desc Destination memory descriptor.
1041/// @param n Number of source parameters.
1042/// @param concat_dimension Source tensors will be concatenated over
1043/// dimension with this index. Note that order of dimensions does
1044/// not depend on memory format.
1045/// @param src_descs Array of source memory descriptors with @p n elements.
1046/// @param attr Primitive attributes to use (can be NULL).
1047/// @returns #dnnl_success on success and a status describing the error
1048/// otherwise.
1049dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
1050 dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine,
1051 const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension,
1052 const_dnnl_memory_desc_t const *src_descs,
1053 const_dnnl_primitive_attr_t attr);
1054
1055/// @} dnnl_api_concat
1056
1057/// @addtogroup dnnl_api_sum
1058/// @{
1059
1060/// Creates a primitive descriptor for an (out-of-place) sum primitive.
1061///
1062/// @param sum_primitive_desc Output primitive descriptor.
1063/// @param engine Engine to use.
1064/// @param dst_desc Destination memory descriptor.
1065/// @param n Number of source parameters.
1066/// @param scales Vector of scales to multiply data in each source
1067/// memory by.
1068/// @param src_descs Array of source memory descriptors having @p n elements.
1069/// @param attr Primitive attributes to use (can be NULL).
1070/// @returns #dnnl_success on success and a status describing the error
1071/// otherwise.
1072dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
1073 dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine,
1074 const_dnnl_memory_desc_t dst_desc, int n, const float *scales,
1075 const_dnnl_memory_desc_t const *src_descs,
1076 const_dnnl_primitive_attr_t attr);
1077
1078/// @} dnnl_api_sum
1079
1080/// @addtogroup dnnl_api_binary
1081/// @{
1082
1083/// Creates a primitive descriptor for a binary primitive.
1084///
1085/// @note
1086/// Memory descriptors @p src1_desc and @p dst_desc are allowed to be
1087/// initialized with #dnnl_format_tag_any or with format_kind set to
1088/// #dnnl_format_kind_any.
1089///
1090/// @note
1091/// Both memory descriptors must have the same number of dimensions.
1092/// Element broadcasting is supported for memory descriptor @p src1_desc
1093/// and is applied to @p src1_desc dimensions that have size equal to 1.
1094///
1095/// @param primitive_desc Output primitive descriptor.
1096/// @param engine Engine to use.
1097/// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
1098/// #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
1099/// #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
1100/// #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
1101/// @param src0_desc Source 0 memory descriptor.
1102/// @param src1_desc Source 1 memory descriptor.
1103/// @param dst_desc Destination memory descriptor.
1104/// @param attr Primitive attributes (can be NULL).
1105/// @returns #dnnl_success on success and a status describing the error
1106/// otherwise.
1107dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
1108 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1109 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
1110 const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc,
1111 const_dnnl_primitive_attr_t attr);
1112
1113/// @} dnnl_api_binary
1114
1115/// @addtogroup dnnl_api_convolution
1116/// @{
1117
1118/// Creates a primitive descriptor for a convolution forward propagation
1119/// primitive.
1120///
1121/// @note
1122/// Memory descriptors can be initialized with
1123/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1124///
1125/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1126/// values for spatial dimensions only and hence must have the same number of
1127/// elements as there are spatial dimensions. The order of values is the same
1128/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1129/// and width.
1130///
1131/// @param primitive_desc Output primitive descriptor.
1132/// @param engine Engine to use.
1133/// @param prop_kind Propagation kind. Possible values are
1134/// #dnnl_forward_training and #dnnl_forward_inference.
1135/// @param alg_kind Convolution algorithm. Possible values are
1136/// #dnnl_convolution_direct, #dnnl_convolution_winograd,
1137/// and #dnnl_convolution_auto.
1138/// @param src_desc Source memory descriptor.
1139/// @param weights_desc Weights memory descriptor.
1140/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
1141/// descriptor, or a memory descriptor with format_kind set to
1142/// #dnnl_format_kind_undef disables the bias term.
1143/// @param dst_desc Destination memory descriptor.
1144/// @param strides Array of strides for each spatial dimension.
1145/// @param dilates Array of dilations for each spatial dimension. A zero value
1146/// means no dilation in the corresponding dimension.
1147/// @param padding_l Array of padding values for low indices for each spatial
1148/// dimension `([[front,] top,] left)`.
1149/// @param padding_r Array of padding values for high indices for each spatial
1150/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1151/// padding is considered to be symmetrical.
1152/// @param attr Primitive attributes (can be NULL).
1153/// @returns #dnnl_success on success and a status describing the error
1154/// otherwise.
1155dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
1156 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1157 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1158 const_dnnl_memory_desc_t src_desc,
1159 const_dnnl_memory_desc_t weights_desc,
1160 const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
1161 const dnnl_dims_t strides, const dnnl_dims_t dilates,
1162 const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
1163 const_dnnl_primitive_attr_t attr);
1164
1165/// Creates a primitive descriptor for a convolution backward propagation
1166/// primitive.
1167///
1168/// @note
1169/// Memory descriptors can be initialized with
1170/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1171///
1172/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1173/// values for spatial dimensions only and hence must have the same number of
1174/// elements as there are spatial dimensions. The order of values is the same
1175/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1176/// and width.
1177///
1178/// @param primitive_desc Output primitive descriptor.
1179/// @param engine Engine to use.
1180/// @param alg_kind Convolution algorithm. Possible values are
1181/// #dnnl_convolution_direct, #dnnl_convolution_winograd,
1182/// and #dnnl_convolution_auto.
1183/// @param diff_src_desc Diff source memory descriptor.
1184/// @param weights_desc Weights memory descriptor.
1185/// @param diff_dst_desc Diff destination memory descriptor.
1186/// @param strides Array of strides for each spatial dimension.
1187/// @param dilates Array of dilations for each spatial dimension. A zero value
1188/// means no dilation in the corresponding dimension.
1189/// @param padding_l Array of padding values for low indices for each spatial
1190/// dimension `([[front,] top,] left)`.
1191/// @param padding_r Array of padding values for high indices for each spatial
1192/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1193/// padding is considered to be symmetrical.
1194/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1195/// primitive.
1196/// @param attr Primitive attributes (can be NULL).
1197/// @returns #dnnl_success on success and a status describing the error
1198/// otherwise.
1199dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
1200 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1201 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1202 const_dnnl_memory_desc_t weights_desc,
1203 const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
1204 const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
1205 const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
1206 const_dnnl_primitive_attr_t attr);
1207
1208/// Creates a primitive descriptor for a convolution weights gradient primitive.
1209///
1210/// @note
1211/// Memory descriptors can be initialized with
1212/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1213///
1214/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1215/// values for spatial dimensions only and hence must have the same number of
1216/// elements as there are spatial dimensions. The order of values is the same
1217/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1218/// and width.
1219///
1220/// @param primitive_desc Output primitive descriptor.
1221/// @param engine Engine to use.
1222/// @param alg_kind Convolution algorithm. Possible values are
1223/// #dnnl_convolution_direct, #dnnl_convolution_winograd,
1224/// and #dnnl_convolution_auto.
1225/// @param src_desc Source memory descriptor.
1226/// @param diff_weights_desc Diff weights memory descriptor.
1227/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
1228/// memory descriptor, or a memory descriptor with format_kind set to
1229/// #dnnl_format_kind_undef disables the bias term.
1230/// @param diff_dst_desc Diff destination memory descriptor.
1231/// @param strides Array of strides for each spatial dimension.
1232/// @param dilates Array of dilations for each spatial dimension. A zero value
1233/// means no dilation in the corresponding dimension.
1234/// @param padding_l Array of padding values for low indices for each spatial
1235/// dimension `([[front,] top,] left)`.
1236/// @param padding_r Array of padding values for high indices for each spatial
1237/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1238/// padding is considered to be symmetrical.
1239/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1240/// primitive.
1241/// @param attr Primitive attributes (can be NULL).
1242/// @returns #dnnl_success on success and a status describing the error
1243/// otherwise.
1244dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
1245 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1246 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
1247 const_dnnl_memory_desc_t diff_weights_desc,
1248 const_dnnl_memory_desc_t diff_bias_desc,
1249 const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
1250 const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
1251 const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
1252 const_dnnl_primitive_attr_t attr);
1253
1254/// @} dnnl_api_convolution
1255
1256/// @addtogroup dnnl_api_deconvolution
1257/// @{
1258
1259/// Creates a primitive descriptor for a deconvolution forward propagation
1260/// primitive.
1261///
1262/// @note
1263/// Memory descriptors can be initialized with
1264/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1265///
1266/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1267/// values for spatial dimensions only and hence must have the same number of
1268/// elements as there are spatial dimensions. The order of values is the same
1269/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1270/// and width.
1271///
1272/// @param primitive_desc Output primitive descriptor.
1273/// @param engine Engine to use.
1274/// @param prop_kind Propagation kind. Possible values are
1275/// #dnnl_forward_training and #dnnl_forward_inference.
1276/// @param alg_kind Deconvolution algorithm. Possible values are
1277/// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
1278/// @param src_desc Source memory descriptor.
1279/// @param weights_desc Weights memory descriptor.
1280/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
1281/// descriptor, or a memory descriptor with format_kind set to
1282/// #dnnl_format_kind_undef disables the bias term.
1283/// @param dst_desc Destination memory descriptor.
1284/// @param strides Array of strides for each spatial dimension.
1285/// @param dilates Array of dilations for each spatial dimension. A zero value
1286/// means no dilation in the corresponding dimension.
1287/// @param padding_l Array of padding values for low indices for each spatial
1288/// dimension `([[front,] top,] left)`.
1289/// @param padding_r Array of padding values for high indices for each spatial
1290/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1291/// padding is considered to be symmetrical.
1292/// @param attr Primitive attributes (can be NULL).
1293/// @returns #dnnl_success on success and a status describing the error
1294/// otherwise.
1295dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
1296 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1297 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1298 const_dnnl_memory_desc_t src_desc,
1299 const_dnnl_memory_desc_t weights_desc,
1300 const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
1301 const dnnl_dims_t strides, const dnnl_dims_t dilates,
1302 const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
1303 const_dnnl_primitive_attr_t attr);
1304
1305/// Creates a primitive descriptor for a deconvolution backward propagation
1306/// primitive.
1307///
1308/// @note
1309/// Memory descriptors can be initialized with
1310/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1311///
1312/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1313/// values for spatial dimensions only and hence must have the same number of
1314/// elements as there are spatial dimensions. The order of values is the same
1315/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1316/// and width.
1317///
1318/// @param primitive_desc Output primitive descriptor.
1319/// @param engine Engine to use.
1320/// @param alg_kind Deconvolution algorithm. Possible values are
1321/// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
1322/// @param diff_src_desc Diff source memory descriptor.
1323/// @param weights_desc Weights memory descriptor.
1324/// @param diff_dst_desc Diff destination memory descriptor.
1325/// @param strides Array of strides for each spatial dimension.
1326/// @param dilates Array of dilations for each spatial dimension. A zero value
1327/// means no dilation in the corresponding dimension.
1328/// @param padding_l Array of padding values for low indices for each spatial
1329/// dimension `([[front,] top,] left)`.
1330/// @param padding_r Array of padding values for high indices for each spatial
1331/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1332/// padding is considered to be symmetrical.
1333/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1334/// primitive.
1335/// @param attr Primitive attributes (can be NULL).
1336/// @returns #dnnl_success on success and a status describing the error
1337/// otherwise.
1338dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
1339 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1340 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1341 const_dnnl_memory_desc_t weights_desc,
1342 const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
1343 const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
1344 const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
1345 const_dnnl_primitive_attr_t attr);
1346
1347/// Creates a primitive descriptor for a deconvolution weights gradient
1348/// primitive.
1349///
1350/// @note
1351/// Memory descriptors can be initialized with
1352/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1353///
1354/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
1355/// values for spatial dimensions only and hence must have the same number of
1356/// elements as there are spatial dimensions. The order of values is the same
1357/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
1358/// and width.
1359///
1360/// @param primitive_desc Output primitive descriptor.
1361/// @param engine Engine to use.
1362/// @param alg_kind Deconvolution algorithm. Possible values are
1363/// #dnnl_deconvolution_direct and #dnnl_deconvolution_winograd.
1364/// @param src_desc Source memory descriptor.
1365/// @param diff_weights_desc Diff weights memory descriptor.
1366/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
1367/// memory descriptor, or a memory descriptor with format_kind set to
1368/// #dnnl_format_kind_undef disables the bias term.
1369/// @param diff_dst_desc Diff destination memory descriptor.
1370/// @param strides Array of strides for each spatial dimension.
1371/// @param dilates Array of dilations for each spatial dimension. A zero value
1372/// means no dilation in the corresponding dimension.
1373/// @param padding_l Array of padding values for low indices for each spatial
1374/// dimension `([[front,] top,] left)`.
1375/// @param padding_r Array of padding values for high indices for each spatial
1376/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1377/// padding is considered to be symmetrical.
1378/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1379/// primitive.
1380/// @param attr Primitive attributes (can be NULL).
1381/// @returns #dnnl_success on success and a status describing the error
1382/// otherwise.
1383dnnl_status_t DNNL_API
1384dnnl_deconvolution_backward_weights_primitive_desc_create(
1385 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1386 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
1387 const_dnnl_memory_desc_t diff_weights_desc,
1388 const_dnnl_memory_desc_t diff_bias_desc,
1389 const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
1390 const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
1391 const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
1392 const_dnnl_primitive_attr_t attr);
1393
1394/// @} dnnl_api_deconvolution
1395
1396/// @addtogroup dnnl_api_shuffle
1397/// @{
1398
1399/// Creates a primitive descriptor for a shuffle forward propagation primitive.
1400///
1401/// @param primitive_desc Output primitive descriptor.
1402/// @param engine Engine to use.
1403/// @param prop_kind Propagation kind. Possible values are
1404/// #dnnl_forward_training and #dnnl_forward_inference.
1405/// @param src_desc Source memory descriptor.
1406/// @param dst_desc Destination memory descriptor.
1407/// @param axis The axis along which the data is shuffled.
1408/// @param group_size Shuffle group size.
1409/// @param attr Primitive attributes (can be NULL).
1410/// @returns #dnnl_success on success and a status describing the error
1411/// otherwise.
1412dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
1413 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1414 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
1415 const_dnnl_memory_desc_t dst_desc, int axis, dnnl_dim_t group_size,
1416 const_dnnl_primitive_attr_t attr);
1417
1418/// Creates a primitive descriptor for a shuffle backward propagation primitive
1419///
1420/// @param primitive_desc Output primitive descriptor.
1421/// @param engine Engine to use.
1422/// @param diff_src_desc Diff source memory descriptor.
1423/// @param diff_dst_desc Diff destination memory descriptor.
1424/// @param axis The axis along which the data is shuffled.
1425/// @param group_size Shuffle group size.
1426/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1427/// primitive.
1428/// @param attr Primitive attributes (can be NULL).
1429/// @returns #dnnl_success on success and a status describing the error
1430/// otherwise.
1431dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
1432 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1433 const_dnnl_memory_desc_t diff_src_desc,
1434 const_dnnl_memory_desc_t diff_dst_desc, int axis, dnnl_dim_t group_size,
1435 const_dnnl_primitive_desc_t hint_fwd_pd,
1436 const_dnnl_primitive_attr_t attr);
1437
/// @} dnnl_api_shuffle

/// @addtogroup dnnl_api_eltwise
/// @{
1442
1443/// Creates a primitive descriptor for an eltwise forward propagation primitive.
1444///
1445/// @param primitive_desc Output primitive descriptor.
1446/// @param engine Engine to use.
1447/// @param prop_kind Propagation kind. Possible values are
1448/// #dnnl_forward_training and #dnnl_forward_inference.
1449/// @param alg_kind Elementwise algorithm kind.
1450/// @param src_desc Source memory descriptor.
1451/// @param dst_desc Destination memory descriptor.
1452/// @param alpha The alpha parameter for the elementwise operation. Specific
1453/// meaning depends on the algorithm.
1454/// @param beta The beta parameter for the elementwise operation. Specific
1455/// meaning depends on the algorithm.
1456/// @param attr Primitive attributes (can be NULL).
1457/// @returns #dnnl_success on success and a status describing the error
1458/// otherwise.
1459dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
1460 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1461 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1462 const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
1463 float alpha, float beta, const_dnnl_primitive_attr_t attr);
1464
1465/// Creates a primitive descriptor for an eltwise backward propagation
1466/// primitive.
1467///
1468/// @param primitive_desc Output primitive descriptor.
1469/// @param engine Engine to use.
1470/// @param alg_kind Elementwise algorithm kind.
1471/// @param diff_src_desc Diff source memory descriptor.
1472/// @param diff_dst_desc Diff destination memory descriptor.
1473/// @param data_desc Destination memory descriptor if one of the
1474/// "use_dst_for_bwd" algorithms are used (such as
1475/// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor otherwise.
1476/// @param alpha The alpha parameter for the elementwise operation. Specific
1477/// meaning depends on the algorithm.
1478/// @param beta The beta parameter for the elementwise operation. Specific
1479/// meaning depends on the algorithm.
1480/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1481/// primitive.
1482/// @param attr Primitive attributes (can be NULL).
1483/// @returns #dnnl_success on success and a status describing the error
1484/// otherwise.
1485dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
1486 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1487 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1488 const_dnnl_memory_desc_t diff_dst_desc,
1489 const_dnnl_memory_desc_t data_desc, float alpha, float beta,
1490 const_dnnl_primitive_desc_t hint_fwd_pd,
1491 const_dnnl_primitive_attr_t attr);
1492
/// @} dnnl_api_eltwise

/// @addtogroup dnnl_api_softmax
/// @{
1497
1498/// Creates a primitive descriptor for a softmax forward propagation primitive.
1499///
1500/// @param primitive_desc Output primitive descriptor.
1501/// @param engine Engine to use.
1502/// @param prop_kind Propagation kind. Possible values are
1503/// #dnnl_forward_training and #dnnl_forward_inference.
1504/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
1505/// #dnnl_softmax_log.
1506/// @param src_desc Source memory descriptor.
1507/// @param dst_desc Destination memory descriptor.
1508/// @param softmax_axis Axis over which softmax is computed.
1509/// @param attr Primitive attributes (can be NULL).
1510/// @returns #dnnl_success on success and a status describing the error
1511/// otherwise.
1512dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
1513 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1514 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1515 const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
1516 int softmax_axis, const_dnnl_primitive_attr_t attr);
1517
1518/// Creates a primitive descriptor for a softmax backward propagation primitive.
1519///
1520/// @param primitive_desc Output primitive descriptor.
1521/// @param engine Engine to use.
1522/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
1523/// #dnnl_softmax_log.
1524/// @param diff_src_desc Diff source memory descriptor.
1525/// @param diff_dst_desc Diff destination memory descriptor.
1526/// @param dst_desc Destination memory descriptor.
1527/// @param softmax_axis Axis over which softmax is computed.
1528/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1529/// primitive.
1530/// @param attr Primitive attributes (can be NULL).
1531/// @returns #dnnl_success on success and a status describing the error
1532/// otherwise.
1533dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
1534 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1535 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1536 const_dnnl_memory_desc_t diff_dst_desc,
1537 const_dnnl_memory_desc_t dst_desc, int softmax_axis,
1538 const_dnnl_primitive_desc_t hint_fwd_pd,
1539 const_dnnl_primitive_attr_t attr);
1540
/// @} dnnl_api_softmax

/// @addtogroup dnnl_api_pooling
/// @{
1545
1546/// Creates a primitive descriptor for a pooling forward propagation
1547/// primitive.
1548///
1549/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
1550/// contain values for spatial dimensions only and hence must have the same
1551/// number of elements as there are spatial dimensions. The order of values
1552/// is the same as in the tensor: depth (for 3D tensors),
1553/// height (for 3D and 2D tensors), and width.
1554///
1555/// @param primitive_desc Output primitive descriptor.
1556/// @param engine Engine to use.
1557/// @param prop_kind Propagation kind. Possible values are
1558/// #dnnl_forward_training and #dnnl_forward_inference.
1559/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
1560/// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
1561/// @param src_desc Source memory descriptor.
1562/// @param dst_desc Destination memory descriptor.
1563/// @param strides Array of strides for spatial dimension.
1564/// @param kernel Array of kernel spatial dimensions.
1565/// @param dilation Array of dilations for spatial dimension.
1566/// @param padding_l Array of padding values for low indices for each spatial
1567/// dimension `([[front,] top,] left)`.
1568/// @param padding_r Array of padding values for high indices for each spatial
1569/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1570/// padding is considered to be symmetrical.
1571/// @param attr Primitive attributes (can be NULL).
1572/// @returns #dnnl_success on success and a status describing the error
1573/// otherwise.
1574dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
1575 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1576 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1577 const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
1578 const dnnl_dims_t strides, const dnnl_dims_t kernel,
1579 const dnnl_dims_t dilation, const dnnl_dims_t padding_l,
1580 const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr);
1581
1582/// Creates a primitive descriptor for a pooling backward propagation
1583/// primitive.
1584///
1585/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
1586/// contain values for spatial dimensions only and hence must have the same
1587/// number of elements as there are spatial dimensions. The order of values
1588/// is the same as in the tensor: depth (for 3D tensors),
1589/// height (for 3D and 2D tensors), and width.
1590///
1591/// @param primitive_desc Output primitive descriptor.
1592/// @param engine Engine to use.
1593/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
1594/// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
1595/// @param diff_src_desc Diff source memory descriptor.
1596/// @param diff_dst_desc Diff destination memory descriptor.
1597/// @param strides Array of strides for spatial dimension.
1598/// @param kernel Array of kernel spatial dimensions.
1599/// @param dilation Array of dilations for spatial dimension.
1600/// @param padding_l Array of padding values for low indices for each spatial
1601/// dimension `([[front,] top,] left)`.
1602/// @param padding_r Array of padding values for high indices for each spatial
1603/// dimension `([[back,] bottom,] right)`. Can be NULL in which case
1604/// padding is considered to be symmetrical.
1605/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1606/// primitive.
1607/// @param attr Primitive attributes (can be NULL).
1608/// @returns #dnnl_success on success and a status describing the error
1609/// otherwise.
1610dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
1611 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1612 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1613 const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
1614 const dnnl_dims_t kernel, const dnnl_dims_t dilation,
1615 const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
1616 const_dnnl_primitive_desc_t hint_fwd_pd,
1617 const_dnnl_primitive_attr_t attr);
1618
/// @} dnnl_api_pooling

/// @addtogroup dnnl_api_prelu
/// @{
1623
1624/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
1625/// alpha parameter) forward propagation primitive.
1626///
1627/// @note
1628/// weights descriptor is allowed to be initialized with
1629/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1630///
1631/// @param primitive_desc Output primitive descriptor.
1632/// @param engine Engine to use.
1633/// @param prop_kind Propagation kind. Possible values are
1634/// #dnnl_forward_training and #dnnl_forward_inference.
1635/// @param src_desc Source memory descriptor.
1636/// @param weights_desc Alpha parameters memory descriptor.
1637/// @param dst_desc Destination memory descriptor.
1638/// @param attr Primitive attributes (can be NULL).
1639/// @returns #dnnl_success on success and a status describing the error
1640/// otherwise.
1641dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
1642 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1643 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
1644 const_dnnl_memory_desc_t weights_desc,
1645 const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
1646
1647/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
1648/// alpha parameter) backward propagation primitive.
1649///
1650/// @note
1651/// weights descriptor and diff_weights descriptor are allowed
1652/// to be initialized with #dnnl_format_tag_any or with format_kind
1653/// set to #dnnl_format_kind_any.
1654///
1655/// @param primitive_desc Output primitive descriptor.
1656/// @param engine Engine to use.
1657/// @param src_desc Source memory descriptor.
1658/// @param weights_desc Alpha parameters memory descriptor.
1659/// @param diff_src_desc Diff source memory descriptor.
1660/// @param diff_weights_desc Diff alpha parameters memory descriptor.
1661/// @param diff_dst_desc Diff destination memory descriptor.
1662/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1663/// primitive.
1664/// @param attr Primitive attributes (can be NULL).
1665/// @returns #dnnl_success on success and a status describing the error
1666/// otherwise.
1667dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
1668 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1669 const_dnnl_memory_desc_t src_desc,
1670 const_dnnl_memory_desc_t weights_desc,
1671 const_dnnl_memory_desc_t diff_src_desc,
1672 const_dnnl_memory_desc_t diff_weights_desc,
1673 const_dnnl_memory_desc_t diff_dst_desc,
1674 const_dnnl_primitive_desc_t hint_fwd_pd,
1675 const_dnnl_primitive_attr_t attr);
1676
/// @} dnnl_api_prelu

/// @addtogroup dnnl_api_lrn
/// @{
1681
1682/// Creates a primitive descriptor for an LRN forward propagation primitive.
1683///
1684/// @param primitive_desc Output primitive_descriptor.
1685/// @param engine Engine to use.
1686/// @param prop_kind Propagation kind. Possible values are
1687/// #dnnl_forward_training and #dnnl_forward_inference.
1688/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
1689/// #dnnl_lrn_within_channel.
1690/// @param src_desc Source memory descriptor.
1691/// @param dst_desc Destination memory descriptor.
1692/// @param local_size Regularization local size.
1693/// @param alpha The alpha regularization parameter.
1694/// @param beta The beta regularization parameter.
1695/// @param k The k regularization parameter.
1696/// @param attr Primitive attributes (can be NULL).
1697/// @returns #dnnl_success on success and a status describing the error
1698/// otherwise.
1699dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
1700 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1701 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
1702 const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
1703 dnnl_dim_t local_size, float alpha, float beta, float k,
1704 const_dnnl_primitive_attr_t attr);
1705
1706/// Creates a primitive descriptor for an LRN backward propagation primitive.
1707///
1708/// @param primitive_desc Output primitive_descriptor.
1709/// @param engine Engine to use.
1710/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
1711/// #dnnl_lrn_within_channel.
1712/// @param diff_src_desc Diff source memory descriptor.
1713/// @param diff_dst_desc Diff destination memory descriptor.
1714/// @param src_desc Source memory descriptor.
1715/// @param local_size Regularization local size.
1716/// @param alpha The alpha regularization parameter.
1717/// @param beta The beta regularization parameter.
1718/// @param k The k regularization parameter.
1719/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1720/// primitive.
1721/// @param attr Primitive attributes (can be NULL).
1722/// @returns #dnnl_success on success and a status describing the error
1723/// otherwise.
1724dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
1725 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1726 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
1727 const_dnnl_memory_desc_t diff_dst_desc,
1728 const_dnnl_memory_desc_t src_desc, dnnl_dim_t local_size, float alpha,
1729 float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd,
1730 const_dnnl_primitive_attr_t attr);
1731
/// @} dnnl_api_lrn

/// @addtogroup dnnl_api_batch_normalization
/// @{
1736
1737/// Creates a primitive descriptor for a batch normalization forward propagation
1738/// primitive.
1739///
1740/// @note
1741/// In-place operation is supported: the dst can refer to the same memory
1742/// as the src.
1743///
1744/// @param primitive_desc Output primitive_descriptor.
1745/// @param engine Engine to use.
1746/// @param prop_kind Propagation kind. Possible values are
1747/// #dnnl_forward_training and #dnnl_forward_inference.
1748/// @param src_desc Source memory descriptor.
1749/// @param dst_desc Destination memory descriptor.
1750/// @param epsilon Batch normalization epsilon parameter.
1751/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
1752/// @param attr Primitive attributes (can be NULL).
1753/// @returns #dnnl_success on success and a status describing the error
1754/// otherwise.
1755dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
1756 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1757 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
1758 const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags,
1759 const_dnnl_primitive_attr_t attr);
1760
1761/// Creates a primitive descriptor for a batch normalization backward
1762/// propagation primitive.
1763///
1764/// @note
1765/// In-place operation is supported: the diff_dst can refer to the same
1766/// memory as the diff_src.
1767///
1768/// @param primitive_desc Output primitive_descriptor.
1769/// @param engine Engine to use.
1770/// @param prop_kind Propagation kind. Possible values are
1771/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
1772/// computed in this case).
1773/// @param diff_src_desc Diff source memory descriptor.
1774/// @param diff_dst_desc Diff destination memory descriptor.
1775/// @param src_desc Source memory descriptor.
1776/// @param epsilon Batch normalization epsilon parameter.
1777/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
1778/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1779/// primitive.
1780/// @param attr Primitive attributes (can be NULL).
1781/// @returns #dnnl_success on success and a status describing the error
1782/// otherwise.
1783dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
1784 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1785 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
1786 const_dnnl_memory_desc_t diff_dst_desc,
1787 const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags,
1788 const_dnnl_primitive_desc_t hint_fwd_pd,
1789 const_dnnl_primitive_attr_t attr);
1790
/// @} dnnl_api_batch_normalization

/// @addtogroup dnnl_api_layer_normalization
/// @{
1795
1796/// Creates a primitive descriptor for a layer normalization forward propagation
1797/// primitive.
1798///
1799/// @note
1800/// In-place operation is supported: the dst can refer to the same memory
1801/// as the src.
1802///
1803/// @param primitive_desc Output primitive_descriptor.
1804/// @param engine Engine to use.
1805/// @param prop_kind Propagation kind. Possible values are
1806/// #dnnl_forward_training and #dnnl_forward_inference.
1807/// @param src_desc Source memory descriptor.
1808/// @param dst_desc Destination memory descriptor.
1809/// @param stat_desc Memory descriptor for mean and variance. If this
1810/// parameter is NULL, a zero memory descriptor, or a memory descriptor
1811/// with format_kind set to #dnnl_format_kind_undef, then the memory
1812/// descriptor for stats is derived from @p src_desc by removing the last
1813/// dimension.
1814/// @param epsilon Layer normalization epsilon parameter.
1815/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
1816/// @param attr Primitive attributes (can be NULL).
1817/// @returns #dnnl_success on success and a status describing the error
1818/// otherwise.
1819dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
1820 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1821 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
1822 const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
1823 float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);
1824
1825/// Creates a primitive descriptor for a layer normalization backward
1826/// propagation primitive.
1827///
1828/// @note
1829/// In-place operation is supported: the diff_dst can refer to the same
1830/// memory as the diff_src.
1831///
1832/// @param primitive_desc Output primitive_descriptor.
1833/// @param engine Engine to use.
1834/// @param prop_kind Propagation kind. Possible values are
1835/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
1836/// computed in this case).
1837/// @param diff_src_desc Diff source memory descriptor.
1838/// @param diff_dst_desc Diff destination memory descriptor.
1839/// @param src_desc Source memory descriptor.
1840/// @param stat_desc Memory descriptor for mean and variance. If this
1841/// parameter is NULL, a zero memory descriptor, or a memory descriptor
1842/// with format_kind set to #dnnl_format_kind_undef, then the memory
1843/// descriptor for stats is derived from @p src_desc by removing the last
1844/// dimension.
1845/// @param epsilon Layer normalization epsilon parameter.
1846/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
1847/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1848/// primitive.
1849/// @param attr Primitive attributes (can be NULL).
1850/// @returns #dnnl_success on success and a status describing the error
1851/// otherwise.
1852dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
1853 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1854 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
1855 const_dnnl_memory_desc_t diff_dst_desc,
1856 const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
1857 float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
1858 const_dnnl_primitive_attr_t attr);
1859
/// @} dnnl_api_layer_normalization

/// @addtogroup dnnl_api_inner_product
/// @{
1864
1865/// Creates a primitive descriptor for an inner product forward propagation
1866/// primitive.
1867///
1868/// @note
1869/// Memory descriptors can be initialized with
1870/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1871///
1872/// @param primitive_desc Output primitive_descriptor.
1873/// @param engine Engine to use.
1874/// @param prop_kind Propagation kind. Possible values are
1875/// #dnnl_forward_training and #dnnl_forward_inference.
1876/// @param src_desc Source memory descriptor.
1877/// @param weights_desc Weights memory descriptor.
1878/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
1879/// descriptor, or a memory descriptor with format_kind set to
1880/// #dnnl_format_kind_undef disables the bias term.
1881/// @param dst_desc Destination memory descriptor.
1882/// @param attr Primitive attributes (can be NULL).
1883/// @returns #dnnl_success on success and a status describing the error
1884/// otherwise.
1885dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
1886 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1887 dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
1888 const_dnnl_memory_desc_t weights_desc,
1889 const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
1890 const_dnnl_primitive_attr_t attr);
1891
1892/// Creates a primitive descriptor for an inner product backward propagation
1893/// primitive.
1894///
1895/// @note
1896/// Memory descriptors can be initialized with
1897/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1898///
1899/// @param primitive_desc Output primitive_descriptor.
1900/// @param engine Engine to use.
1901/// @param diff_src_desc Diff source memory descriptor.
1902/// @param weights_desc Weights memory descriptor.
1903/// @param diff_dst_desc Diff destination memory descriptor.
1904/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1905/// primitive.
1906/// @param attr Primitive attributes (can be NULL).
1907/// @returns #dnnl_success on success and a status describing the error
1908/// otherwise.
1909dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
1910 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1911 const_dnnl_memory_desc_t diff_src_desc,
1912 const_dnnl_memory_desc_t weights_desc,
1913 const_dnnl_memory_desc_t diff_dst_desc,
1914 const_dnnl_primitive_desc_t hint_fwd_pd,
1915 const_dnnl_primitive_attr_t attr);
1916
1917/// Creates a primitive descriptor for an inner product weights gradient
1918/// primitive.
1919///
1920/// @note
1921/// Memory descriptors can be initialized with
1922/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1923///
1924/// @param primitive_desc Output primitive_descriptor.
1925/// @param engine Engine to use.
1926/// @param src_desc Source memory descriptor.
1927/// @param diff_weights_desc Diff weights memory descriptor.
1928/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
1929/// memory descriptor, or a memory descriptor with format_kind set to
1930/// #dnnl_format_kind_undef disables the bias term.
1931/// @param diff_dst_desc Diff destination memory descriptor.
1932/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
1933/// primitive.
1934/// @param attr Primitive attributes (can be NULL).
1935/// @returns #dnnl_success on success and a status describing the error
1936/// otherwise.
1937dnnl_status_t DNNL_API
1938dnnl_inner_product_backward_weights_primitive_desc_create(
1939 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
1940 const_dnnl_memory_desc_t src_desc,
1941 const_dnnl_memory_desc_t diff_weights_desc,
1942 const_dnnl_memory_desc_t diff_bias_desc,
1943 const_dnnl_memory_desc_t diff_dst_desc,
1944 const_dnnl_primitive_desc_t hint_fwd_pd,
1945 const_dnnl_primitive_attr_t attr);
1946
/// @} dnnl_api_inner_product

/// @addtogroup dnnl_api_attributes
/// @{
1951
1952/// Set quantization scale and shift parameters for RNN data tensors.
1953///
1954/// For performance reasons, the low-precision configuration of the RNN
1955/// primitives expects input activations to have the unsigned 8-bit integer
1956/// data type. The scale and shift parameters are used to quantize
1957/// floating-point data to unsigned integer and must be passed to the RNN
1958/// primitive using attributes.
1959///
1960/// The quantization formula is `scale * data + shift`.
1961///
1962/// @note
1963/// Quantization scale and shift are common for src_layer, src_iter,
1964/// dst_iter, and dst_layer.
1965///
1966/// Example usage:
1967/// @code
1968/// // RNN parameters
1969/// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
1970/// // Activations quantization parameters
1971/// float scale = 63.f, shift = 64.f;
1972///
1973/// dnnl_primitive_attr_t rnn_attr;
1974/// // Create default attributes
1975/// dnnl_primitive_attr_create(&rnn_attr);
1976///
1977/// // Set scale and shift for int8 quantization of activation
1978/// dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
1979///
1980/// // Create an RNN primitive descriptor.
1981/// dnnl_primitive_desc_t rnn_pd;
1982/// dnnl_vanilla_rnn_forward_primitive_desc_create(&rnn_pd,
1983/// engine, /* arguments */, attr);
1984/// @endcode
1985///
1986/// @param attr Primitive attributes.
1987/// @param scale The value to scale the data by.
1988/// @param shift The value to shift the data by.
1989/// @returns #dnnl_success on success and a status describing the error
1990/// otherwise.
1991dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
1992 dnnl_primitive_attr_t attr, const float scale, const float shift);
1993
1994/// Returns the quantization scale and shift parameters for RNN data tensors.
1995///
1996/// @note
1997/// Quantization scale and shift are common for src_layer, src_iter,
1998/// dst_iter, and dst_layer.
1999///
2000/// @param attr Primitive attributes.
2001/// @param scale The value to scale the data by.
2002/// @param shift The value to shift the data by.
2003/// @returns #dnnl_success on success and a status describing the error
2004/// otherwise.
2005dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
2006 const_dnnl_primitive_attr_t attr, float *scale, float *shift);
2007
2008/// Sets quantization scaling factors for RNN weights tensors. The
2009/// low-precision configuration of the RNN primitives expects input weights to
2010/// use the signed 8-bit integer data type. The scaling factors are used to
2011/// quantize floating-point data to signed integer and must be passed to RNN
2012/// primitives using attributes.
2013///
2014/// @note
2015/// The dimension order is always native and does not depend on the actual
2016/// layout used. For example, five-dimensional weights always have (l, d,
2017/// i, g, o) logical dimension ordering.
2018///
2019/// @note
2020/// Quantization scales are common for weights_layer and weights_iteration
2021///
2022/// @param attr Primitive attributes.
2023/// @param count Number of elements in the @p scales array.
2024/// @param mask Scaling factors correspondence mask that defines the
2025/// correspondence between the output tensor dimensions and the @p
2026/// scales vector. The set i-th bit indicates that a dedicated scaling
2027/// factor should be used for each index along that dimension. Set the
2028/// mask to 0 to use a common scaling factor for the whole output
2029/// tensor.
2030/// @param scales Array of output scaling factors that must contain @p count
2031/// values and the following equality must hold:
2032/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
2033/// Violations can only be detected when the attributes are used to create
2034/// a primitive descriptor.
2035/// @returns #dnnl_success on success and a status describing the error
2036/// otherwise.
2037dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
2038 dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
2039 const float *scales);
2040
2041/// Returns the quantization scaling factors for RNN weights tensors.
2042///
2043/// @param attr Primitive attributes.
2044/// @param count Output number of elements in the @p scales array.
2045/// @param mask Output scaling factors correspondence mask that defines the
2046/// correspondence between the weights tensor dimensions and the @p
2047/// scales vector. A set i-th bit indicates that a dedicated scaling
2048/// factor is used for each index along the i-th dimension. A mask
2049/// equal to 0 indicates a common scaling factor for the whole weights
2050/// tensor.
2051/// @param scales Output pointer to a constant array of scaling factors that
2052/// contains @p count values, for which the following equality must hold:
2053/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
2054/// @returns #dnnl_success on success and a status describing the error
2055/// otherwise.
2056dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
2057 const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
2058 const float **scales);
2059
2060/// Sets quantization scaling factors for RNN projection weights tensors. The
2061/// low-precision configuration of the RNN primitives expects input weights to
2062/// use the signed 8-bit integer data type. The scaling factors are used to
2063/// quantize floating-point data to signed integers and must be passed to RNN
2064/// primitives using attributes.
2065///
2066/// @note
2067/// The dimension order is always native and does not depend on the actual
2068/// layout used. For example, five-dimensional weights always have (l, d,
2069/// i, g, o) logical dimension ordering.
2070///
2071/// @param attr Primitive attributes.
2072/// @param count Number of elements in the @p scales array.
2073/// @param mask Scaling factors correspondence mask that defines the
2074/// correspondence between the weights tensor dimensions and the @p
2075/// scales vector. A set i-th bit indicates that a dedicated scaling
2076/// factor is used for each index along the i-th dimension. Set the
2077/// mask to 0 to use a common scaling factor for the whole weights
2078/// tensor.
2079/// @param scales Array of weights scaling factors. It must contain @p count
2080/// values, and the following equality must hold:
2081/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
2082/// Violations can only be detected when the attributes are used to create
2083/// a primitive descriptor.
2084/// @returns #dnnl_success on success and a status describing the error
2085/// otherwise.
2086dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
2087 dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
2088 const float *scales);
2089
2090/// Returns the quantization scaling factors for RNN projection weights tensors.
2091///
2092/// @param attr Primitive attributes.
2093/// @param count Output number of elements in the @p scales array.
2094/// @param mask Output scaling factors correspondence mask that defines the
2095/// correspondence between the weights tensor dimensions and the @p
2096/// scales vector. A set i-th bit indicates that a dedicated scaling
2097/// factor is used for each index along the i-th dimension. A mask
2098/// equal to 0 indicates a common scaling factor for the whole weights
2099/// tensor.
2100/// @param scales Output pointer to a constant array of scaling factors that
2101/// contains @p count values, for which the following equality must hold:
2102/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
2103/// @returns #dnnl_success on success and a status describing the error
2104/// otherwise.
2105dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
2106 const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
2107 const float **scales);
2108
2109/// @} dnnl_api_attributes
2110
2111/// @addtogroup dnnl_api_rnn
2112/// @{
2113
2114/// Creates a primitive descriptor for a vanilla RNN forward propagation
2115/// primitive.
2116///
2117/// The following arguments may either be @c NULL or point to a zero memory
2118/// descriptor:
2119/// - @p src_iter_desc,
2120/// - @p bias_desc,
2121/// - @p dst_iter_desc.
2122///
2123/// This indicates that the RNN forward propagation primitive should
2124/// not use them and should default to zero values instead.
2125///
2126/// @note
2127/// All memory descriptors can be initialized with
2128/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2129///
2130/// @param primitive_desc Output primitive descriptor.
2131/// @param engine Engine to use.
2132/// @param prop_kind Propagation kind. Possible values are
2133/// #dnnl_forward_training and #dnnl_forward_inference.
2134/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
2135/// #dnnl_eltwise_tanh, or #dnnl_eltwise_logistic.
2136/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2137/// info.
2138/// @param src_layer_desc Memory descriptor for the input vector.
2139/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2140/// state vector.
2141/// @param weights_layer_desc Memory descriptor for the weights applied to the
2142/// layer input.
2143/// @param weights_iter_desc Memory descriptor for the weights applied to the
2144/// recurrent input.
2145/// @param bias_desc Bias memory descriptor.
2146/// @param dst_layer_desc Memory descriptor for the output vector.
2147/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2148/// state vector.
2149/// @param flags Unused.
2150/// @param alpha Negative slope if @p activation is #dnnl_eltwise_relu.
2151/// @param beta Unused.
2152/// @param attr Primitive attributes (can be NULL).
2153/// @returns #dnnl_success on success and a status describing the error
2154/// otherwise.
2155dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
2156 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2157 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
2158 const dnnl_rnn_direction_t direction,
2159 const_dnnl_memory_desc_t src_layer_desc,
2160 const_dnnl_memory_desc_t src_iter_desc,
2161 const_dnnl_memory_desc_t weights_layer_desc,
2162 const_dnnl_memory_desc_t weights_iter_desc,
2163 const_dnnl_memory_desc_t bias_desc,
2164 const_dnnl_memory_desc_t dst_layer_desc,
2165 const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha,
2166 float beta, const_dnnl_primitive_attr_t attr);
2167
2168/// Creates a primitive descriptor for a vanilla RNN backward propagation
2169/// primitive.
2170///
2171/// The following arguments may either be @c NULL or point to a zero memory
2172/// descriptor:
2173/// - @p src_iter_desc together with @p diff_src_iter_desc,
2174/// - @p bias_desc together with @p diff_bias_desc,
2175/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2176///
2177/// This indicates that the RNN backward propagation primitive should
2178/// not use the respective data and should use zero values instead.
2179///
2180/// @note
2181/// All memory descriptors can be initialized with
2182/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2183///
2184/// @param primitive_desc Output primitive descriptor.
2185/// @param engine Engine to use.
2186/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2187/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
2188/// #dnnl_eltwise_tanh, or #dnnl_eltwise_logistic.
2189/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2190/// info.
2191/// @param src_layer_desc Memory descriptor for the input vector.
2192/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2193/// state vector.
2194/// @param weights_layer_desc Memory descriptor for the weights applied to the
2195/// layer input.
2196/// @param weights_iter_desc Memory descriptor for the weights applied to the
2197/// recurrent input.
2198/// @param bias_desc Bias memory descriptor.
2199/// @param dst_layer_desc Memory descriptor for the output vector.
2200/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2201/// state vector.
2202/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2203/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2204/// recurrent hidden state vector.
2205/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2206/// applied to the layer input.
2207/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2208/// applied to the recurrent input.
2209/// @param diff_bias_desc Diff bias memory descriptor.
2210/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2211/// vector.
2212/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2213/// recurrent hidden state vector.
2214/// @param flags Unused.
2215/// @param alpha Negative slope if @p activation is #dnnl_eltwise_relu.
2216/// @param beta Unused.
2217/// @param hint_fwd_pd Primitive descriptor for the respective forward
2218/// propagation primitive.
2219/// @param attr Primitive attributes (can be NULL).
2220/// @returns #dnnl_success on success and a status describing the error
2221/// otherwise.
2222dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
2223 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2224 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
2225 const dnnl_rnn_direction_t direction,
2226 const_dnnl_memory_desc_t src_layer_desc,
2227 const_dnnl_memory_desc_t src_iter_desc,
2228 const_dnnl_memory_desc_t weights_layer_desc,
2229 const_dnnl_memory_desc_t weights_iter_desc,
2230 const_dnnl_memory_desc_t bias_desc,
2231 const_dnnl_memory_desc_t dst_layer_desc,
2232 const_dnnl_memory_desc_t dst_iter_desc,
2233 const_dnnl_memory_desc_t diff_src_layer_desc,
2234 const_dnnl_memory_desc_t diff_src_iter_desc,
2235 const_dnnl_memory_desc_t diff_weights_layer_desc,
2236 const_dnnl_memory_desc_t diff_weights_iter_desc,
2237 const_dnnl_memory_desc_t diff_bias_desc,
2238 const_dnnl_memory_desc_t diff_dst_layer_desc,
2239 const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
2240 float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd,
2241 const_dnnl_primitive_attr_t attr);
2242
2243/// Creates a primitive descriptor for an LSTM forward propagation primitive.
2244///
2245/// The following arguments may either be @c NULL or point to a zero memory
2246/// descriptor:
2247/// - @p src_iter_desc together with @p src_iter_c_desc,
2248/// - @p weights_peephole_desc,
2249/// - @p bias_desc,
2250/// - @p dst_iter_desc together with @p dst_iter_c_desc.
2251///
2252/// This indicates that the LSTM forward propagation primitive should
2253/// not use them and should default to zero values instead.
2254///
2255/// The @p weights_projection_desc may either be @c NULL or point to a zero
2256/// memory descriptor. This indicates that the LSTM does not have a
2257/// recurrent projection layer.
2258///
2259/// @note
2260/// All memory descriptors can be initialized with #dnnl_format_tag_any or
2261/// with format_kind set to #dnnl_format_kind_any.
2262///
2263/// @param primitive_desc Output primitive descriptor.
2264/// @param engine Engine to use.
2265/// @param prop_kind Propagation kind. Possible values are
2266/// #dnnl_forward_training and #dnnl_forward_inference.
2267/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2268/// info.
2269/// @param src_layer_desc Memory descriptor for the input vector.
2270/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2271/// state vector.
2272/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
2273/// state vector.
2274/// @param weights_layer_desc Memory descriptor for the weights applied to the
2275/// layer input.
2276/// @param weights_iter_desc Memory descriptor for the weights applied to the
2277/// recurrent input.
2278/// @param weights_peephole_desc Memory descriptor for the weights applied to
2279/// the cell states (according to the Peephole LSTM formula).
2280/// @param weights_projection_desc Memory descriptor for the weights applied to
2281/// the hidden states to get the recurrent projection (according to the
2282/// Projection LSTM formula).
2283/// @param bias_desc Bias memory descriptor.
2284/// @param dst_layer_desc Memory descriptor for the output vector.
2285/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2286/// state vector.
2287/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
2288/// state vector.
2289/// @param flags Unused.
2290/// @param attr Primitive attributes (can be NULL).
2291/// @returns #dnnl_success on success and a status describing the error
2292/// otherwise.
2293dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
2294 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2295 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2296 const_dnnl_memory_desc_t src_layer_desc,
2297 const_dnnl_memory_desc_t src_iter_desc,
2298 const_dnnl_memory_desc_t src_iter_c_desc,
2299 const_dnnl_memory_desc_t weights_layer_desc,
2300 const_dnnl_memory_desc_t weights_iter_desc,
2301 const_dnnl_memory_desc_t weights_peephole_desc,
2302 const_dnnl_memory_desc_t weights_projection_desc,
2303 const_dnnl_memory_desc_t bias_desc,
2304 const_dnnl_memory_desc_t dst_layer_desc,
2305 const_dnnl_memory_desc_t dst_iter_desc,
2306 const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
2307 const_dnnl_primitive_attr_t attr);
2308
2309/// Creates a primitive descriptor for an LSTM backward propagation primitive.
2310///
2311/// The following arguments may either be @c NULL or point to a zero memory
2312/// descriptor:
2313/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
2314/// and @p diff_src_iter_c_desc,
2315/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
2316/// - @p bias_desc together with @p diff_bias_desc,
2317/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
2318/// and @p diff_dst_iter_c_desc.
2319///
2320/// This indicates that the LSTM backward propagation primitive
2321/// should not use them and should default to zero values instead.
2322///
2323/// The @p weights_projection_desc together with @p
2324/// diff_weights_projection_desc may either be @c NULL or point to a zero
2325/// memory descriptor. This indicates that the LSTM does not have a
2326/// recurrent projection layer.
2327///
2328/// @note
2329/// All memory descriptors can be initialized with #dnnl_format_tag_any or
2330/// with format_kind set to #dnnl_format_kind_any.
2331///
2332/// @param primitive_desc Output primitive descriptor.
2333/// @param engine Engine to use.
2334/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2335/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2336/// info.
2337/// @param src_layer_desc Memory descriptor for the input vector.
2338/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2339/// state vector.
2340/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
2341/// state vector.
2342/// @param weights_layer_desc Memory descriptor for the weights applied to the
2343/// layer input.
2344/// @param weights_iter_desc Memory descriptor for the weights applied to the
2345/// recurrent input.
2346/// @param weights_peephole_desc Memory descriptor for the weights applied to
2347/// the cell states (according to the Peephole LSTM formula).
2348/// @param weights_projection_desc Memory descriptor for the weights applied to
2349/// the hidden states to get the recurrent projection (according to the
2350/// Projection LSTM formula).
2351/// @param bias_desc Bias memory descriptor.
2352/// @param dst_layer_desc Memory descriptor for the output vector.
2353/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2354/// state vector.
2355/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
2356/// state vector.
2357/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2358/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2359/// recurrent hidden state vector.
2360/// @param diff_src_iter_c_desc Memory descriptor for the diff of the input
2361/// recurrent cell state vector.
2362/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2363/// applied to the layer input.
2364/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2365/// applied to the recurrent input.
2366/// @param diff_weights_peephole_desc Memory descriptor for the diff of the weights
2367/// applied to the cell states (according to the Peephole LSTM formula).
2368/// @param diff_weights_projection_desc Memory descriptor for the diff of the
2369/// weights applied to the hidden states to get the recurrent projection
2370/// (according to the Projection LSTM formula).
2371/// @param diff_bias_desc Diff bias memory descriptor.
2372/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2373/// vector.
2374/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2375/// recurrent hidden state vector.
2376/// @param diff_dst_iter_c_desc Memory descriptor for the diff of the output
2377/// recurrent cell state vector.
2378/// @param flags Unused.
2379/// @param hint_fwd_pd Primitive descriptor for the respective forward
2380/// propagation primitive.
2381/// @param attr Primitive attributes (can be NULL).
2382/// @returns #dnnl_success on success and a status describing the error
2383/// otherwise.
2384dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
2385 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2386 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2387 const_dnnl_memory_desc_t src_layer_desc,
2388 const_dnnl_memory_desc_t src_iter_desc,
2389 const_dnnl_memory_desc_t src_iter_c_desc,
2390 const_dnnl_memory_desc_t weights_layer_desc,
2391 const_dnnl_memory_desc_t weights_iter_desc,
2392 const_dnnl_memory_desc_t weights_peephole_desc,
2393 const_dnnl_memory_desc_t weights_projection_desc,
2394 const_dnnl_memory_desc_t bias_desc,
2395 const_dnnl_memory_desc_t dst_layer_desc,
2396 const_dnnl_memory_desc_t dst_iter_desc,
2397 const_dnnl_memory_desc_t dst_iter_c_desc,
2398 const_dnnl_memory_desc_t diff_src_layer_desc,
2399 const_dnnl_memory_desc_t diff_src_iter_desc,
2400 const_dnnl_memory_desc_t diff_src_iter_c_desc,
2401 const_dnnl_memory_desc_t diff_weights_layer_desc,
2402 const_dnnl_memory_desc_t diff_weights_iter_desc,
2403 const_dnnl_memory_desc_t diff_weights_peephole_desc,
2404 const_dnnl_memory_desc_t diff_weights_projection_desc,
2405 const_dnnl_memory_desc_t diff_bias_desc,
2406 const_dnnl_memory_desc_t diff_dst_layer_desc,
2407 const_dnnl_memory_desc_t diff_dst_iter_desc,
2408 const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
2409 const_dnnl_primitive_desc_t hint_fwd_pd,
2410 const_dnnl_primitive_attr_t attr);
2411
2412/// Creates a primitive descriptor for a GRU forward propagation primitive.
2413///
2414/// The following arguments may either be @c NULL or point to a zero memory
2415/// descriptor:
2416/// - @p src_iter_desc,
2417/// - @p bias_desc,
2418/// - @p dst_iter_desc.
2419///
2420/// This indicates that the GRU forward propagation primitive should
2421/// not use them and should default to zero values instead.
2422///
2423/// @note
2424/// All memory descriptors can be initialized with
2425/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2426///
2427/// @param primitive_desc Output primitive descriptor.
2428/// @param engine Engine to use.
2429/// @param prop_kind Propagation kind. Possible values are
2430/// #dnnl_forward_training and #dnnl_forward_inference.
2431/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2432/// info.
2433/// @param src_layer_desc Memory descriptor for the input vector.
2434/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2435/// state vector.
2436/// @param weights_layer_desc Memory descriptor for the weights applied to the
2437/// layer input.
2438/// @param weights_iter_desc Memory descriptor for the weights applied to the
2439/// recurrent input.
2440/// @param bias_desc Bias memory descriptor.
2441/// @param dst_layer_desc Memory descriptor for the output vector.
2442/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2443/// state vector.
2444/// @param flags Unused.
2445/// @param attr Primitive attributes (can be NULL).
2446/// @returns #dnnl_success on success and a status describing the error
2447/// otherwise.
2448dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
2449 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2450 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2451 const_dnnl_memory_desc_t src_layer_desc,
2452 const_dnnl_memory_desc_t src_iter_desc,
2453 const_dnnl_memory_desc_t weights_layer_desc,
2454 const_dnnl_memory_desc_t weights_iter_desc,
2455 const_dnnl_memory_desc_t bias_desc,
2456 const_dnnl_memory_desc_t dst_layer_desc,
2457 const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
2458 const_dnnl_primitive_attr_t attr);
2459
2460/// Creates a primitive descriptor for a GRU backward propagation primitive.
2461///
2462/// The following arguments may either be @c NULL or point to a zero memory
2463/// descriptor:
2464/// - @p src_iter_desc together with @p diff_src_iter_desc,
2465/// - @p bias_desc together with @p diff_bias_desc,
2466/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2467///
2468/// This indicates that the GRU backward propagation primitive
2469/// should not use them and should default to zero values instead.
2470///
2471/// @note
2472/// All memory descriptors can be initialized with
2473/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2474///
2475/// @param primitive_desc Output primitive descriptor.
2476/// @param engine Engine to use.
2477/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2478/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2479/// info.
2480/// @param src_layer_desc Memory descriptor for the input vector.
2481/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2482/// state vector.
2483/// @param weights_layer_desc Memory descriptor for the weights applied to the
2484/// layer input.
2485/// @param weights_iter_desc Memory descriptor for the weights applied to the
2486/// recurrent input.
2487/// @param bias_desc Bias memory descriptor.
2488/// @param dst_layer_desc Memory descriptor for the output vector.
2489/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2490/// state vector.
2491/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2492/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2493/// recurrent hidden state vector.
2494/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2495/// applied to the layer input.
2496/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2497/// applied to the recurrent input.
2498/// @param diff_bias_desc Diff bias memory descriptor.
2499/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2500/// vector.
2501/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2502/// recurrent hidden state vector.
2503/// @param flags Unused.
2504/// @param hint_fwd_pd Primitive descriptor for the respective forward
2505/// propagation primitive.
2506/// @param attr Primitive attributes (can be NULL).
2507/// @returns #dnnl_success on success and a status describing the error
2508/// otherwise.
2509dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
2510 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2511 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2512 const_dnnl_memory_desc_t src_layer_desc,
2513 const_dnnl_memory_desc_t src_iter_desc,
2514 const_dnnl_memory_desc_t weights_layer_desc,
2515 const_dnnl_memory_desc_t weights_iter_desc,
2516 const_dnnl_memory_desc_t bias_desc,
2517 const_dnnl_memory_desc_t dst_layer_desc,
2518 const_dnnl_memory_desc_t dst_iter_desc,
2519 const_dnnl_memory_desc_t diff_src_layer_desc,
2520 const_dnnl_memory_desc_t diff_src_iter_desc,
2521 const_dnnl_memory_desc_t diff_weights_layer_desc,
2522 const_dnnl_memory_desc_t diff_weights_iter_desc,
2523 const_dnnl_memory_desc_t diff_bias_desc,
2524 const_dnnl_memory_desc_t diff_dst_layer_desc,
2525 const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
2526 const_dnnl_primitive_desc_t hint_fwd_pd,
2527 const_dnnl_primitive_attr_t attr);
2528
2529/// Creates a primitive descriptor for an LBR GRU forward propagation primitive.
2530///
2531/// The following arguments may either be @c NULL or point to a zero memory
2532/// descriptor:
2533/// - @p src_iter_desc,
2534/// - @p bias_desc,
2535/// - @p dst_iter_desc.
2536///
2537/// This indicates that the LBR GRU forward propagation primitive
2538/// should not use them and should default to zero values instead.
2539///
2540/// @param primitive_desc Output primitive descriptor.
2541/// @param engine Engine to use.
2542/// @param prop_kind Propagation kind. Possible values are
2543/// #dnnl_forward_training and #dnnl_forward_inference.
2544/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2545/// info.
2546/// @param src_layer_desc Memory descriptor for the input vector.
2547/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2548/// state vector.
2549/// @param weights_layer_desc Memory descriptor for the weights applied to the
2550/// layer input.
2551/// @param weights_iter_desc Memory descriptor for the weights applied to the
2552/// recurrent input.
2553/// @param bias_desc Bias memory descriptor.
2554/// @param dst_layer_desc Memory descriptor for the output vector.
2555/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2556/// state vector.
2557/// @param flags Unused.
2558/// @param attr Primitive attributes (can be NULL).
2559/// @returns #dnnl_success on success and a status describing the error
2560/// otherwise.
2561dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
2562 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2563 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2564 const_dnnl_memory_desc_t src_layer_desc,
2565 const_dnnl_memory_desc_t src_iter_desc,
2566 const_dnnl_memory_desc_t weights_layer_desc,
2567 const_dnnl_memory_desc_t weights_iter_desc,
2568 const_dnnl_memory_desc_t bias_desc,
2569 const_dnnl_memory_desc_t dst_layer_desc,
2570 const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
2571 const_dnnl_primitive_attr_t attr);
2572
2573/// Creates a primitive descriptor for an LBR GRU backward propagation primitive.
2574///
2575/// The following arguments may either be @c NULL or point to a zero memory
2576/// descriptor:
2577/// - @p src_iter_desc together with @p diff_src_iter_desc,
2578/// - @p bias_desc together with @p diff_bias_desc,
2579/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2580///
2581/// This indicates that the LBR GRU backward propagation primitive
2582/// should not use them and should default to zero values instead.
2583///
2584/// @note
2585/// All memory descriptors can be initialized with
2586/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2587///
2588/// @param primitive_desc Output primitive descriptor.
2589/// @param engine Engine to use.
2590/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2591/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2592/// info.
2593/// @param src_layer_desc Memory descriptor for the input vector.
2594/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2595/// state vector.
2596/// @param weights_layer_desc Memory descriptor for the weights applied to the
2597/// layer input.
2598/// @param weights_iter_desc Memory descriptor for the weights applied to the
2599/// recurrent input.
2600/// @param bias_desc Bias memory descriptor.
2601/// @param dst_layer_desc Memory descriptor for the output vector.
2602/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2603/// state vector.
2604/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2605/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2606/// recurrent hidden state vector.
2607/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2608/// applied to the layer input.
2609/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2610/// applied to the recurrent input.
2611/// @param diff_bias_desc Diff bias memory descriptor.
2612/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2613/// vector.
2614/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2615/// recurrent hidden state vector.
2616/// @param flags Unused.
2617/// @param hint_fwd_pd Primitive descriptor for the respective forward
2618/// propagation primitive.
2619/// @param attr Primitive attributes (can be NULL).
2620/// @returns #dnnl_success on success and a status describing the error
2621/// otherwise.
2622dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
2623 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2624 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2625 const_dnnl_memory_desc_t src_layer_desc,
2626 const_dnnl_memory_desc_t src_iter_desc,
2627 const_dnnl_memory_desc_t weights_layer_desc,
2628 const_dnnl_memory_desc_t weights_iter_desc,
2629 const_dnnl_memory_desc_t bias_desc,
2630 const_dnnl_memory_desc_t dst_layer_desc,
2631 const_dnnl_memory_desc_t dst_iter_desc,
2632 const_dnnl_memory_desc_t diff_src_layer_desc,
2633 const_dnnl_memory_desc_t diff_src_iter_desc,
2634 const_dnnl_memory_desc_t diff_weights_layer_desc,
2635 const_dnnl_memory_desc_t diff_weights_iter_desc,
2636 const_dnnl_memory_desc_t diff_bias_desc,
2637 const_dnnl_memory_desc_t diff_dst_layer_desc,
2638 const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
2639 const_dnnl_primitive_desc_t hint_fwd_pd,
2640 const_dnnl_primitive_attr_t attr);
2641
2642/// Creates a primitive descriptor for an AUGRU forward propagation primitive.
2643///
2644/// The following arguments may either be @c NULL or point to a zero memory
2645/// descriptor:
2646/// - @p src_iter_desc,
2647/// - @p bias_desc,
2648/// - @p dst_iter_desc.
2649///
2650/// This indicates that the AUGRU forward propagation primitive should
2651/// not use them and should default to zero values instead.
2652///
2653/// @note
2654/// All memory descriptors can be initialized with
2655/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2656///
2657/// @param primitive_desc Output primitive descriptor.
2658/// @param engine Engine to use.
2659/// @param prop_kind Propagation kind. Possible values are
2660/// #dnnl_forward_training and #dnnl_forward_inference.
2661/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2662/// info.
2663/// @param src_layer_desc Memory descriptor for the input vector.
2664/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2665/// state vector.
2666/// @param attention_desc Memory descriptor for the attention vector.
2667/// @param weights_layer_desc Memory descriptor for the weights applied to the
2668/// layer input.
2669/// @param weights_iter_desc Memory descriptor for the weights applied to the
2670/// recurrent input.
2671/// @param bias_desc Bias memory descriptor.
2672/// @param dst_layer_desc Memory descriptor for the output vector.
2673/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2674/// state vector.
2675/// @param flags Unused.
2676/// @param attr Primitive attributes (can be NULL).
2677/// @returns #dnnl_success on success and a status describing the error
2678/// otherwise.
2679dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
2680 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2681 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2682 const_dnnl_memory_desc_t src_layer_desc,
2683 const_dnnl_memory_desc_t src_iter_desc,
2684 const_dnnl_memory_desc_t attention_desc,
2685 const_dnnl_memory_desc_t weights_layer_desc,
2686 const_dnnl_memory_desc_t weights_iter_desc,
2687 const_dnnl_memory_desc_t bias_desc,
2688 const_dnnl_memory_desc_t dst_layer_desc,
2689 const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
2690 const_dnnl_primitive_attr_t attr);
2691
2692/// Creates a primitive descriptor for an AUGRU backward propagation primitive.
2693///
2694/// The following arguments may either be @c NULL or point to a zero memory
2695/// descriptor:
2696/// - @p src_iter_desc together with @p diff_src_iter_desc,
2697/// - @p bias_desc together with @p diff_bias_desc,
2698/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2699///
2700/// This would then indicate that the AUGRU backward propagation primitive
2701/// should not use them and should default to zero values instead.
2702///
2703/// @note
2704/// All memory descriptors can be initialized with
2705/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2706///
2707/// @param primitive_desc Output primitive descriptor.
2708/// @param engine Engine to use.
2709/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2710/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2711/// info.
2712/// @param src_layer_desc Memory descriptor for the input vector.
2713/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2714/// state vector.
2715/// @param attention_desc Memory descriptor for the attention vector.
2716/// @param weights_layer_desc Memory descriptor for the weights applied to the
2717/// layer input.
2718/// @param weights_iter_desc Memory descriptor for the weights applied to the
2719/// recurrent input.
2720/// @param bias_desc Bias memory descriptor.
2721/// @param dst_layer_desc Memory descriptor for the output vector.
2722/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2723/// state vector.
2724/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2725/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2726/// recurrent hidden state vector.
2727/// @param diff_attention_desc Memory descriptor for the diff of the attention vector.
2728/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2729/// applied to the layer input.
2730/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2731/// applied to the recurrent input.
2732/// @param diff_bias_desc Diff bias memory descriptor.
2733/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2734/// vector.
2735/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2736/// recurrent hidden state vector.
2737/// @param flags Unused.
2738/// @param hint_fwd_pd Primitive descriptor for the respective forward
2739/// propagation primitive.
2740/// @param attr Primitive attributes (can be NULL).
2741/// @returns #dnnl_success on success and a status describing the error
2742/// otherwise.
2743dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
2744 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2745 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2746 const_dnnl_memory_desc_t src_layer_desc,
2747 const_dnnl_memory_desc_t src_iter_desc,
2748 const_dnnl_memory_desc_t attention_desc,
2749 const_dnnl_memory_desc_t weights_layer_desc,
2750 const_dnnl_memory_desc_t weights_iter_desc,
2751 const_dnnl_memory_desc_t bias_desc,
2752 const_dnnl_memory_desc_t dst_layer_desc,
2753 const_dnnl_memory_desc_t dst_iter_desc,
2754 const_dnnl_memory_desc_t diff_src_layer_desc,
2755 const_dnnl_memory_desc_t diff_src_iter_desc,
2756 const_dnnl_memory_desc_t diff_attention_desc,
2757 const_dnnl_memory_desc_t diff_weights_layer_desc,
2758 const_dnnl_memory_desc_t diff_weights_iter_desc,
2759 const_dnnl_memory_desc_t diff_bias_desc,
2760 const_dnnl_memory_desc_t diff_dst_layer_desc,
2761 const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
2762 const_dnnl_primitive_desc_t hint_fwd_pd,
2763 const_dnnl_primitive_attr_t attr);
2764
2765/// Creates a primitive descriptor for an LBR AUGRU forward propagation primitive.
2766///
2767/// The following arguments may either be @c NULL or point to a zero memory
2768/// descriptor:
2769/// - @p src_iter_desc,
2770/// - @p bias_desc,
2771/// - @p dst_iter_desc.
2772///
2773/// This would then indicate that the LBR AUGRU forward propagation primitive
2774/// should not use them and should default to zero values instead.
2775///
2776/// @param primitive_desc Output primitive descriptor.
2777/// @param engine Engine to use.
2778/// @param prop_kind Propagation kind. Possible values are
2779/// #dnnl_forward_training and #dnnl_forward_inference.
2780/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2781/// info.
2782/// @param src_layer_desc Memory descriptor for the input vector.
2783/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2784/// state vector.
2785/// @param attention_desc Memory descriptor for the attention vector.
2786/// @param weights_layer_desc Memory descriptor for the weights applied to the
2787/// layer input.
2788/// @param weights_iter_desc Memory descriptor for the weights applied to the
2789/// recurrent input.
2790/// @param bias_desc Bias memory descriptor.
2791/// @param dst_layer_desc Memory descriptor for the output vector.
2792/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2793/// state vector.
2794/// @param flags Unused.
2795/// @param attr Primitive attributes (can be NULL).
2796/// @returns #dnnl_success on success and a status describing the error
2797/// otherwise.
2798dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
2799 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2800 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2801 const_dnnl_memory_desc_t src_layer_desc,
2802 const_dnnl_memory_desc_t src_iter_desc,
2803 const_dnnl_memory_desc_t attention_desc,
2804 const_dnnl_memory_desc_t weights_layer_desc,
2805 const_dnnl_memory_desc_t weights_iter_desc,
2806 const_dnnl_memory_desc_t bias_desc,
2807 const_dnnl_memory_desc_t dst_layer_desc,
2808 const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
2809 const_dnnl_primitive_attr_t attr);
2810
2811/// Creates a primitive descriptor for an LBR AUGRU backward propagation primitive.
2812///
2813/// The following arguments may either be @c NULL or point to a zero memory
2814/// descriptor:
2815/// - @p src_iter_desc together with @p diff_src_iter_desc,
2816/// - @p bias_desc together with @p diff_bias_desc,
2817/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2818///
2819/// This would then indicate that the LBR AUGRU backward propagation primitive
2820/// should not use them and should default to zero values instead.
2821///
2822/// @note
2823/// All memory descriptors can be initialized with
2824/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2825///
2826/// @param primitive_desc Output primitive descriptor.
2827/// @param engine Engine to use.
2828/// @param prop_kind Propagation kind. Must be #dnnl_backward.
2829/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2830/// info.
2831/// @param src_layer_desc Memory descriptor for the input vector.
2832/// @param src_iter_desc Memory descriptor for the input recurrent hidden
2833/// state vector.
2834/// @param attention_desc Memory descriptor for the attention vector.
2835/// @param weights_layer_desc Memory descriptor for the weights applied to the
2836/// layer input.
2837/// @param weights_iter_desc Memory descriptor for the weights applied to the
2838/// recurrent input.
2839/// @param bias_desc Bias memory descriptor.
2840/// @param dst_layer_desc Memory descriptor for the output vector.
2841/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2842/// state vector.
2843/// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
2844/// @param diff_src_iter_desc Memory descriptor for the diff of the input
2845/// recurrent hidden state vector.
2846/// @param diff_attention_desc Memory descriptor for the diff of the attention vector.
2847/// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
2848/// applied to the layer input.
2849/// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
2850/// applied to the recurrent input.
2851/// @param diff_bias_desc Diff bias memory descriptor.
2852/// @param diff_dst_layer_desc Memory descriptor for the diff of the output
2853/// vector.
2854/// @param diff_dst_iter_desc Memory descriptor for the diff of the output
2855/// recurrent hidden state vector.
2856/// @param flags Unused.
2857/// @param hint_fwd_pd Primitive descriptor for the respective forward
2858/// propagation primitive.
2859/// @param attr Primitive attributes (can be NULL).
2860/// @returns #dnnl_success on success and a status describing the error
2861/// otherwise.
2862dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
2863 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2864 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
2865 const_dnnl_memory_desc_t src_layer_desc,
2866 const_dnnl_memory_desc_t src_iter_desc,
2867 const_dnnl_memory_desc_t attention_desc,
2868 const_dnnl_memory_desc_t weights_layer_desc,
2869 const_dnnl_memory_desc_t weights_iter_desc,
2870 const_dnnl_memory_desc_t bias_desc,
2871 const_dnnl_memory_desc_t dst_layer_desc,
2872 const_dnnl_memory_desc_t dst_iter_desc,
2873 const_dnnl_memory_desc_t diff_src_layer_desc,
2874 const_dnnl_memory_desc_t diff_src_iter_desc,
2875 const_dnnl_memory_desc_t diff_attention_desc,
2876 const_dnnl_memory_desc_t diff_weights_layer_desc,
2877 const_dnnl_memory_desc_t diff_weights_iter_desc,
2878 const_dnnl_memory_desc_t diff_bias_desc,
2879 const_dnnl_memory_desc_t diff_dst_layer_desc,
2880 const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
2881 const_dnnl_primitive_desc_t hint_fwd_pd,
2882 const_dnnl_primitive_attr_t attr);
2883
2884/// @} dnnl_api_rnn
2885
2886/// @addtogroup dnnl_api_matmul
2887/// @{
2888
2889/// Creates a primitive descriptor for a matrix multiplication primitive.
2890///
2891/// @param primitive_desc Output primitive descriptor.
2892/// @param engine Engine to use.
2893/// @param src_desc Source memory descriptor (matrix A).
2894/// @param weights_desc Weights memory descriptor (matrix B).
2895/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
2896/// descriptor, or a memory descriptor with format_kind set to
2897/// #dnnl_format_kind_undef disables the bias term.
2898/// @param dst_desc Destination memory descriptor (matrix C).
2899/// @param attr Primitive attributes (can be NULL).
2900/// @returns #dnnl_success on success and a status describing the error
2901/// otherwise.
2902dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
2903 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2904 const_dnnl_memory_desc_t src_desc,
2905 const_dnnl_memory_desc_t weights_desc,
2906 const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
2907 const_dnnl_primitive_attr_t attr);
2908
2909/// @} dnnl_api_matmul
2910
2911/// @addtogroup dnnl_api_resampling Resampling
2912/// @{
2913
2914/// Creates a primitive descriptor for a resampling forward propagation
2915/// primitive.
2916///
2917/// @note
2918/// Destination memory descriptor is allowed to be initialized with
2919/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2920///
2921/// @param primitive_desc Output primitive descriptor.
2922/// @param engine Engine to use.
2923/// @param prop_kind Propagation kind. Possible values are
2924/// #dnnl_forward_training and #dnnl_forward_inference.
2925/// @param alg_kind Resampling algorithm kind: either #dnnl_resampling_nearest
2926/// or #dnnl_resampling_linear.
2927/// @param factors Array of scaling factors for spatial dimension.
2928/// @param src_desc Source memory descriptor.
2929/// @param dst_desc Destination memory descriptor.
2930/// @param attr Primitive attributes (can be NULL).
2931/// @returns #dnnl_success on success and a status describing the error
2932/// otherwise.
2933dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
2934 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2935 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
2936 const float *factors, const_dnnl_memory_desc_t src_desc,
2937 const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
2938
2939/// Creates a primitive descriptor for a resampling backward propagation
2940/// primitive.
2941///
2942/// @param primitive_desc Output primitive descriptor.
2943/// @param engine Engine to use.
2944/// @param alg_kind Resampling algorithm kind: either
2945/// #dnnl_resampling_nearest or #dnnl_resampling_linear.
2946/// @param factors Array of scaling factors for spatial dimension.
2947/// @param diff_src_desc Diff source memory descriptor.
2948/// @param diff_dst_desc Diff destination memory descriptor.
2949/// @param hint_fwd_pd Primitive descriptor for the respective forward
2950/// propagation primitive.
2951/// @param attr Primitive attributes (can be NULL).
2952/// @returns #dnnl_success on success and a status describing the error
2953/// otherwise.
2954///
2955dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
2956 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2957 dnnl_alg_kind_t alg_kind, const float *factors,
2958 const_dnnl_memory_desc_t diff_src_desc,
2959 const_dnnl_memory_desc_t diff_dst_desc,
2960 const_dnnl_primitive_desc_t hint_fwd_pd,
2961 const_dnnl_primitive_attr_t attr);
2962
2963/// @} dnnl_api_resampling
2964
2965/// @addtogroup dnnl_api_reduction Reduction
2966/// @{
2967
2968/// Creates a primitive descriptor for a reduction primitive.
2969///
2970/// @note
2971/// Destination memory descriptor is allowed to be initialized with
2972/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2973///
2974/// @param primitive_desc Output primitive descriptor.
2975/// @param engine Engine to use.
2976/// @param alg_kind Reduction algorithm kind. Possible values:
2977/// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
2978/// #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
2979/// #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
2980/// #dnnl_reduction_norm_lp_power_p_sum.
2981/// @param src_desc Source memory descriptor.
2982/// @param dst_desc Destination memory descriptor.
2983/// @param p Algorithm specific parameter.
2984/// @param eps Algorithm specific parameter.
2985/// @param attr Primitive attributes (can be NULL).
2986/// @returns #dnnl_success on success and a status describing the error
2987/// otherwise.
2988dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
2989 dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
2990 dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
2991 const_dnnl_memory_desc_t dst_desc, float p, float eps,
2992 const_dnnl_primitive_attr_t attr);
2993
2994/// @} dnnl_api_reduction
2995
2996/// @} dnnl_api_primitives
2997
2998/// @addtogroup dnnl_api_primitive_cache
2999/// @{
3000
3001/// Returns the number of primitives that can be held in the primitive cache
3002/// at the same time.
3003///
3004/// @param capacity Output pointer that receives the primitive cache capacity.
3005/// Concurrently accessing @p capacity is safe.
3006/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3007/// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
3008/// success.
3009dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
3010
3011/// Sets the number of primitives that can be held in the primitive cache
3012/// at the same time.
3013///
3014/// @param capacity Primitive cache capacity to set. If the new @p capacity is
3015/// less than the number of primitives that the primitive cache already holds,
3016/// then the excess entries will be evicted. Setting the @p capacity to 0
3017/// clears the primitive cache and disables it. Concurrently modifying
3018/// @p capacity is safe.
3019/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3020/// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
3021/// success.
3022dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
3023
3024/// @} dnnl_api_primitive_cache
3025
3026/// @addtogroup dnnl_api_service
3027/// @{
3028
3029/// Configures dumping of JIT-generated code.
3030///
3031/// @note
3032/// This setting overrides the DNNL_JIT_DUMP environment variable.
3033///
3034/// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
3035/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3036/// @p enable value is invalid, and #dnnl_success/#dnnl::status::success on
3037/// success.
3038dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
3039
3040/// Sets library profiling flags. The flags define which profilers are
3041/// supported.
3042///
3043/// @note
3044/// This setting overrides the DNNL_JIT_PROFILE environment variable.
3045///
3046/// @sa @ref dev_guide_profilers
3047///
3048/// @param flags Profiling flags that can contain the following bits:
3049/// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Amplifier
3050/// (on by default)
3051/// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
3052/// jit-pid.dump output (off by default). The location of the output
3053/// is controlled via the JITDUMPDIR environment variable or via the
3054/// dnnl_set_jit_profiling_jitdumpdir() function.
3055/// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
3056/// perf-pid.map output (off by default). The output is always placed
3057/// into /tmp.
3058///
3059/// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
3060///
3061/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3062/// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
3063/// success.
3064dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
3065
3066/// Sets the JIT dump output path. Only applicable to Linux and is only
3067/// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set.
3068///
3069/// After the first JIT kernel is generated, the jitdump output will be placed
3070/// into a temporary directory created using the mkdtemp template
3071/// 'dir/.debug/jit/dnnl.XXXXXX'.
3072///
3073/// @sa @ref dev_guide_profilers
3074///
3075/// @note
3076/// This setting overrides the JITDUMPDIR environment variable. If
3077/// JITDUMPDIR is not set, and this function is never called, the path
3078/// defaults to HOME. Passing NULL reverts the value to default.
3079///
3080/// @note
3081/// The directory is accessed only when the first JIT kernel is being
3082/// created. JIT profiling will be disabled in case of any errors
3083/// accessing or creating this directory.
3084///
3085/// @param dir JIT dump output path.
3086/// @returns #dnnl_success/#dnnl::status::success if the
3087/// output directory was set correctly and an error status otherwise.
3088/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
3089dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);
3090
3091/// Sets the maximal ISA the library can dispatch to on the CPU. See
3092/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
3093/// the C and C++ API functions respectively.
3094///
3095/// This function takes effect only once, and returns an error on subsequent
3096/// calls. It should also be invoked before any other oneDNN API call, otherwise
3097/// it may return an error.
3098///
3099/// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
3100/// environment variable can be set to the desired maximal ISA name in upper
3101/// case and with the dnnl_cpu_isa prefix removed. For example:
3102/// `DNNL_MAX_CPU_ISA=AVX2`.
3103///
3104/// @note
3105/// The ISAs are only partially ordered:
3106/// - SSE41 < AVX < AVX2 < AVX2_VNNI < AVX2_VNNI_2,
3107/// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
3108/// < AVX512_CORE_FP16 < AVX512_CORE_AMX < AVX512_CORE_AMX_FP16,
3109/// - AVX2_VNNI < AVX512_CORE_FP16.
3110///
3111/// @sa @ref dev_guide_cpu_dispatcher_control for more details.
3112///
3113/// @param isa Maximal ISA the library should dispatch to. Pass
3114/// #dnnl_cpu_isa_default/#dnnl::cpu_isa::isa_default to remove ISA restrictions
3115/// (except for ISAs with initial support in the library).
3116/// @returns #dnnl_success/#dnnl::status::success on success and
3117/// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
3118/// parameter is invalid or the ISA cannot be changed at this time.
3119/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
3120/// was disabled at build time (see @ref dev_guide_build_options for more
3121/// details).
3122dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
3123
3124/// Gets the maximal ISA the library can dispatch to on the CPU. See
3125/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
3126/// the C and C++ API functions respectively.
3127///
3128/// @sa @ref dev_guide_cpu_dispatcher_control for more details.
3129///
3130/// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
3131/// dispatch to.
3132dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
3133
3134/// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
3135/// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
3136/// API functions respectively.
3137///
3138/// This function takes effect only once, and returns an error on subsequent
3139/// calls. It should also be invoked before any other oneDNN API call, otherwise
3140/// it may return an error.
3141///
3142/// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
3143/// @sa @ref dev_guide_cpu_isa_hints for more details.
3144///
3145/// @param isa_hints CPU ISA hints to be passed over to the implementation.
3146/// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
3147/// default features, i.e., no hints.
3148/// @returns #dnnl_success/#dnnl::status::success on success and
3149/// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
3150/// be specified at the current time.
3151/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
3152/// was disabled at build time (see @ref dev_guide_build_options for more
3153/// details).
3154dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
3155
3156/// Gets the ISA specific hints that the library can follow. See
3157/// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
3158/// returned by the C and C++ API functions respectively.
3159///
3160/// @sa @ref dev_guide_cpu_isa_hints for more details.
3161///
3162/// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the
3163/// library can follow.
3164dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
3165
3166/// @} dnnl_api_service
3167
3168/// @addtogroup dnnl_api_blas
3169/// @{
3170
3171/// Performs single-precision matrix-matrix multiply.
3172///
3173/// The operation is defined as:
3174///
3175/// `C := alpha * op( A ) * op( B ) + beta * C`
3176///
3177/// where
3178/// - `op( X ) = X` or `op( X ) = X**T`,
3179/// - `alpha` and `beta` are scalars, and
3180/// - `A`, `B`, and `C` are matrices:
3181/// - `op( A )` is an `MxK` matrix,
3182/// - `op( B )` is a `KxN` matrix,
3183/// - `C` is an `MxN` matrix.
3184///
3185/// The matrices are assumed to be stored in row-major order (the elements in
3186/// each of the matrix rows are contiguous in memory).
3187///
3188/// @note
3189/// This API does not support XERBLA. Instead, unlike the standard BLAS
3190/// functions, this one returns a dnnl_status_t value to allow error
3191/// handling.
3192///
3193/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3194/// transposed, and 'T' or 't' means that A is transposed.
3195/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3196/// transposed, and 'T' or 't' means that B is transposed.
3197/// @param M The M dimension.
3198/// @param N The N dimension.
3199/// @param K The K dimension.
3200/// @param alpha The alpha parameter that is used to scale the product of
3201/// matrices A and B.
3202/// @param A A pointer to the A matrix data.
3203/// @param lda The leading dimension for the matrix A.
3204/// @param B A pointer to the B matrix data.
3205/// @param ldb The leading dimension for the matrix B.
3206/// @param beta The beta parameter that is used to scale the matrix C.
3207/// @param C A pointer to the C matrix data.
3208/// @param ldc The leading dimension for the matrix C.
3209/// @returns #dnnl_success/#dnnl::status::success on success and a status
3210/// describing the error otherwise.
3211dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
3212 dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
3213 const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);
3214
3215/// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
3216/// signed matrix B, and 32-bit signed resulting matrix C.
3217///
3218/// The operation is defined as:
3219///
3220/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
3221///
3222/// where
3223/// - `op( X ) = X` or `op( X ) = X**T`,
3224/// - `alpha` and `beta` are scalars, and
3225/// - `A`, `B`, and `C` are matrices:
3226/// - `op( A )` is an `MxK` matrix,
3227/// - `op( B )` is a `KxN` matrix,
3228/// - `C` is an `MxN` matrix,
3229/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
3230/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
3231/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
3232/// - if `offsetc = F`: the `len` must be at least `1`,
3233/// - if `offsetc = C`: the `len` must be at least `max(1, M)`,
3234/// - if `offsetc = R`: the `len` must be at least `max(1, N)`.
3235///
3236/// The matrices are assumed to be stored in row-major order (the elements in
3237/// each of the matrix rows are contiguous in memory).
3238///
3239/// @note
3240/// This API does not support XERBLA. Instead, unlike the standard BLAS
3241/// functions, this one returns a dnnl_status_t value to allow error
3242/// handling.
3243///
3244/// @warning
3245/// On some architectures saturation may happen during intermediate
3246/// computations, which would lead to unexpected results. For more
3247/// details, refer to @ref dev_guide_int8_computations.
3248///
3249/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3250/// transposed, and 'T' or 't' means that A is transposed.
3251/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3252/// transposed, and 'T' or 't' means that B is transposed.
3253/// @param offsetc Flag specifying how offsets should be applied to matrix C:
3254/// - 'F' means that the same offset will be applied to each element of
3255/// the matrix C,
3256/// - 'C' means that individual offset will be applied to each element
3257/// within each column,
3258/// - 'R' means that individual offset will be applied to each element
3259/// within each row.
3260/// @param M The M dimension.
3261/// @param N The N dimension.
3262/// @param K The K dimension.
3263/// @param alpha The alpha parameter that is used to scale the product of
3264/// matrices A and B.
3265/// @param A A pointer to the A matrix data.
3266/// @param lda The leading dimension for the matrix A.
3267/// @param ao The offset value for the matrix A.
3268/// @param B A pointer to the B matrix data.
3269/// @param ldb The leading dimension for the matrix B.
3270/// @param bo The offset value for the matrix B.
3271/// @param beta The beta parameter that is used to scale the matrix C.
3272/// @param C A pointer to the C matrix data.
3273/// @param ldc The leading dimension for the matrix C.
3274/// @param co An array of offset values for the matrix C. The number of
3275/// elements in the array depends on the value of @p offsetc.
3276/// @returns #dnnl_success/#dnnl::status::success on success and a status
3277/// describing the error otherwise.
3278dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
3279 dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
3280 dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
3281 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
3282
3283/// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
3284/// signed matrix B, and 32-bit signed resulting matrix C.
3285///
3286/// The operation is defined as:
3287///
3288/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
3289///
3290/// where
3291/// - `op( X ) = X` or `op( X ) = X**T`,
3292/// - `alpha` and `beta` are scalars, and
3293/// - `A`, `B`, and `C` are matrices:
3294/// - `op( A )` is an `MxK` matrix,
3295/// - `op( B )` is a `KxN` matrix,
3296/// - `C` is an `MxN` matrix,
3297/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
3298/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
3299/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
3300/// - if `offsetc = F`: the `len` must be at least `1`,
3301/// - if `offsetc = C`: the `len` must be at least `max(1, M)`,
3302/// - if `offsetc = R`: the `len` must be at least `max(1, N)`.
3303///
3304/// The matrices are assumed to be stored in row-major order (the elements in
3305/// each of the matrix rows are contiguous in memory).
3306///
3307/// @note
3308/// This API does not support XERBLA. Instead, unlike the standard BLAS
3309/// functions, this one returns a dnnl_status_t value to allow error
3310/// handling.
3311///
3312/// @warning
3313/// On some architectures saturation may happen during intermediate
3314/// computations, which would lead to unexpected results. For more
3315/// details, refer to @ref dev_guide_int8_computations.
3316///
3317/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3318/// transposed, and 'T' or 't' means that A is transposed.
3319/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3320/// transposed, and 'T' or 't' means that B is transposed.
3321/// @param offsetc Flag specifying how offsets should be applied to matrix C:
3322/// - 'F' means that the same offset will be applied to each element of
3323/// the matrix C,
3324/// - 'C' means that individual offset will be applied to each element
3325/// within each column,
3326/// - 'R' means that individual offset will be applied to each element
3327/// within each row.
3328/// @param M The M dimension.
3329/// @param N The N dimension.
3330/// @param K The K dimension.
3331/// @param alpha The alpha parameter that is used to scale the product of
3332/// matrices A and B.
3333/// @param A A pointer to the A matrix data.
3334/// @param lda The leading dimension for the matrix A.
3335/// @param ao The offset value for the matrix A.
3336/// @param B A pointer to the B matrix data.
3337/// @param ldb The leading dimension for the matrix B.
3338/// @param bo The offset value for the matrix B.
3339/// @param beta The beta parameter that is used to scale the matrix C.
3340/// @param C A pointer to the C matrix data.
3341/// @param ldc The leading dimension for the matrix C.
3342/// @param co An array of offset values for the matrix C. The number of
3343/// elements in the array depends on the value of @p offsetc.
3344/// @returns #dnnl_success/#dnnl::status::success on success and a status
3345/// describing the error otherwise.
3346dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
3347 dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
3348 dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
3349 float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
3350
3351/// @} dnnl_api_blas
3352
3353/// @} dnnl_api
3354
3355#ifdef __cplusplus
3356}
3357#endif
3358
3359#endif /* ONEAPI_DNNL_DNNL_H */
3360