1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @file |
18 | /// C API |
19 | |
20 | #ifndef ONEAPI_DNNL_DNNL_H |
21 | #define ONEAPI_DNNL_DNNL_H |
22 | |
23 | #include "oneapi/dnnl/dnnl_config.h" |
24 | #include "oneapi/dnnl/dnnl_types.h" |
25 | #include "oneapi/dnnl/dnnl_version.h" |
26 | |
27 | #ifdef __cplusplus |
28 | extern "C" { |
29 | #endif |
30 | |
31 | /// @addtogroup dnnl_api |
32 | /// @{ |
33 | |
34 | /// @addtogroup dnnl_api_primitives |
35 | /// @{ |
36 | |
37 | /// @addtogroup dnnl_api_primitives_common |
38 | /// @{ |
39 | |
/// Creates a primitive descriptor iterator.
///
/// The resulting iterator must eventually be destroyed via
/// dnnl_primitive_desc_iterator_destroy().
///
/// @param iterator Output primitive descriptor iterator.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(
        dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc,
        const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
        const_dnnl_primitive_desc_t hint_forward_primitive_desc);

/// Advances the primitive descriptor iterator to point to the next available
/// implementation.
///
/// @param iterator A primitive descriptor iterator to advance.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @returns #dnnl_iterator_ends if no more implementations available.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(
        dnnl_primitive_desc_iterator_t iterator);

/// Fetches the current primitive descriptor from a primitive descriptor
/// iterator.
///
/// @note
/// The user is responsible for deleting the resulting primitive
/// descriptor using dnnl_primitive_desc_destroy().
///
/// @param iterator A primitive descriptor iterator.
/// @returns A primitive descriptor.
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(
        const_dnnl_primitive_desc_iterator_t iterator);

/// Destroys a primitive descriptor iterator.
///
/// @param iterator Primitive descriptor iterator to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(
        dnnl_primitive_desc_iterator_t iterator);
85 | |
/// Creates a primitive descriptor. This function is equivalent to a sequence
/// of #dnnl_primitive_desc_iterator_create() and
/// #dnnl_primitive_desc_iterator_fetch(). In other words, the library will
/// pick the first suitable implementation.
///
/// The resulting primitive descriptor must eventually be destroyed via
/// dnnl_primitive_desc_destroy().
///
/// @param primitive_desc Output primitive descriptor.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, const_dnnl_op_desc_t op_desc,
        const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
        const_dnnl_primitive_desc_t hint_forward_primitive_desc);

/// Clones a primitive descriptor. The resulting primitive descriptor must be
/// destroyed separately.
///
/// @param primitive_desc Output primitive descriptor.
/// @param existing_primitive_desc Primitive descriptor to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
        dnnl_primitive_desc_t *primitive_desc,
        const_dnnl_primitive_desc_t existing_primitive_desc);

/// Returns a constant reference to the attributes of a primitive descriptor.
///
/// @warning
/// It is an error to destroy the resulting @p attr.
///
/// @warning
/// The lifetime of an @p attr is the same as that of a @p
/// primitive_desc, so it is an error to use the @p attr once the @p
/// primitive_desc has been destroyed.
///
/// @param primitive_desc Primitive descriptor.
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
        const_dnnl_primitive_desc_t primitive_desc,
        const_dnnl_primitive_attr_t *attr);

/// Destroys a primitive descriptor.
///
/// @param primitive_desc Primitive descriptor to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
        dnnl_primitive_desc_t primitive_desc);
141 | |
/// Queries a primitive descriptor for various pieces of information.
///
/// The most common use case is to query a primitive descriptor, created with
/// source, weights, and destination memory descriptors with format tags set
/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
/// #dnnl_query_dst_md respectively) so that it is possible to create memory
/// objects and reorder primitives if necessary.
///
/// Another typical use case is to query a primitive descriptor for a
/// workspace memory descriptor (with @p what set to #dnnl_query_workspace_md).
/// If this query returns #dnnl_not_required status, then workspace memory is
/// not required.
///
/// @note
/// When querying for a memory descriptor for a scratchpad, a workspace,
/// or an optional parameter, the query will return a pointer to a zero
/// memory descriptor if the parameter is not needed.
///
/// A few other use cases:
/// - query a primitive descriptor for the underlying operation descriptor
/// (#dnnl_query_convolution_d, #dnnl_query_eltwise_d, #dnnl_query_rnn_d,
/// etc.)
/// - query a primitive descriptor for the implementation information string
/// (#dnnl_query_impl_info_str)
/// - query a primitive descriptor for the number of inputs and outputs
/// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
/// respectively)
///
/// @sa dnnl_query_t for more options
///
/// @param primitive_desc Primitive descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for.
/// @param result Output result. The type depends on the query. For example,
/// it must be a @c dnnl_memory_desc_t* if querying for a memory
/// descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_query(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index, void *result);
184 | |
/// Queries primitive descriptor for a memory descriptor.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of memory descriptor parameter to query for.
/// @param index Index of the parameter to query.
/// @returns A pointer to the requested memory descriptor.
/// @returns A pointer to a zero memory descriptor if the parameter is not
/// needed.
/// @returns NULL in case of any error.
const dnnl_memory_desc_t DNNL_API *dnnl_primitive_desc_query_md(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);

/// Queries primitive descriptor for a signed 32-bit int.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of the value to query for.
/// @param index Index of the parameter to query.
/// @returns The requested value.
/// @returns 0 in case of any error (in particular if the queried entity is
/// not of type int32_t). Note that 0 may also be the actual returned
/// value.
int DNNL_API dnnl_primitive_desc_query_s32(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);
219 | |
/// Creates a primitive.
///
/// The resulting primitive must eventually be destroyed via
/// dnnl_primitive_destroy().
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
        const_dnnl_primitive_desc_t primitive_desc);

/// Creates a primitive from a cache blob.
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
        dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
        size_t size, const uint8_t *cache_blob);
240 | |
/// Executes a primitive.
///
/// @note
/// If any argument in @p args is padded (padded_dims > dims), the
/// primitive execution will assume properly zero-padded input
/// arguments, and produce zero-padded output arguments.
///
/// @param primitive Primitive to execute.
/// @param stream Stream to use.
/// @param nargs Number of arguments.
/// @param args Array of arguments. Each argument is an
/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
/// descriptor as that returned by
/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
        dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
260 | |
/// Retrieves a constant reference to the primitive descriptor of a given
/// primitive.
///
/// @warning
/// It is an error to destroy the returned object. It is owned by the
/// primitive. The @c const qualifier of the returned object prevents
/// such attempts.
///
/// @param primitive Primitive to query for the primitive descriptor.
/// @param primitive_desc Output primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
        const_dnnl_primitive_t primitive,
        const_dnnl_primitive_desc_t *primitive_desc);

/// Retrieves a cache blob associated with the given primitive.
///
/// @param primitive Primitive to query for the cache blob.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
/// NULL then the size of the cache blob is returned in @p size.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
///
/// @note The cache blob can be empty. It's the user's responsibility to check
/// whether it's empty prior to passing it to
/// #dnnl_primitive_create_from_cache_blob().
dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
        const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);

/// Destroys a primitive.
///
/// @param primitive The primitive to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
298 | |
299 | /// @} dnnl_api_primitives_common |
300 | |
301 | /// @addtogroup dnnl_api_attributes |
302 | /// @{ |
303 | |
/// Creates an empty (default) primitive attributes with all the parameters
/// set to their default values.
///
/// Empty attributes are implied whenever the respective argument is NULL.
///
/// The resulting attributes must eventually be destroyed via
/// dnnl_primitive_attr_destroy().
///
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);

/// Clones primitive attributes.
///
/// @param attr Output primitive attributes.
/// @param existing_attr Primitive attributes to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
        dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);

/// Destroys primitive attributes.
///
/// @param attr Primitive attributes to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
329 | |
/// Returns the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output FP math mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
        const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);

/// Sets the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode FP math mode. The possible values are:
/// #dnnl_fpmath_mode_strict (default),
/// #dnnl_fpmath_mode_bf16,
/// #dnnl_fpmath_mode_f16,
/// #dnnl_fpmath_mode_tf32,
/// #dnnl_fpmath_mode_any.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
        dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);

/// Returns the primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Output scratchpad mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
        const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);

/// Sets primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Scratchpad mode. The possible values are:
/// #dnnl_scratchpad_mode_library (default) and
/// #dnnl_scratchpad_mode_user.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
        dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
372 | |
/// Returns primitive attributes output scaling factors correspondence mask
/// and values.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
/// @warning
/// The @p scales array is an internal part of the primitive attributes
/// @p attr, so it is an error to modify or destroy the @p scales array.
///
/// @warning
/// The lifetime of @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use
/// @p scales after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// vector. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);

/// Sets output scaling factors correspondence mask and values.
///
/// @note
/// The order of dimensions does not depend on how elements are laid
/// out in memory. For example:
/// - for a 2D CNN activations tensor the order is always (n, c)
/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
/// - for a 5D CNN weights tensor the order is always
/// (g, oc, ic, kh, kw)
///
/// Example usage:
/// @code
/// int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
/// float scales[oc] = { ... }; // unique output scales per output channel
/// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
///
/// dnnl_convolution_desc_t conv_d; // create a convolution descriptor
///
/// dnnl_primitive_attr_t attr;
/// dnnl_primitive_attr_create(&attr); // create primitive attributes
/// dnnl_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
///
/// dnnl_primitive_desc_t conv_pd;
/// dnnl_primitive_desc_create(&conv_pd, &conv_d, attr, engine, NULL);
/// @endcode
///
/// @param attr Primitive attributes.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Array of output scaling factors. If the output scaling
/// factors are known at the time of this call, this array must contain @p
/// count values and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// Violations can only be detected when the attributes are used to create
/// a primitive descriptor.
/// If the output scaling factors are not known at the time of the call,
/// this array must contain a single #DNNL_RUNTIME_F32_VAL value and the
/// output scaling factors must be passed at execution time as an argument
/// with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
447 | |
// NOTE(review): dnnl_primitive_attr_get_scales() takes a non-const
// dnnl_primitive_attr_t, unlike the other getters here (e.g.
// dnnl_primitive_attr_get_zero_points()) — confirm whether const
// qualification was intended before changing the signature.

/// Returns primitive attributes scaling factors correspondence mask and values
/// for a given memory argument.
///
/// @warning
/// The output @p scales array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// scales array.
///
/// @warning
/// The lifetime of the @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use @p
/// scales after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of 0
/// implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
        const float **scales);

/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales array.
/// The set i-th bit indicates that a dedicated scaling factor is used for
/// each index along that dimension. Set the mask to 0 to use a common
/// scaling factor for the whole output tensor.
/// @param scales Constant array of float scaling factors. This array must
/// contain @p count scales and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
        const float *scales);
501 | |
/// Returns @p count, correspondence zero point @p mask, and a pointer to a
/// constant int32_t array of @p zero_points for given @p attr and memory
/// argument (index), previously set by dnnl_primitive_attr_set_zero_points.
///
/// @warning
/// The output @p zero_points array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// zero_points array.
///
/// @warning
/// The lifetime of @p zero_points array is the same as that of the
/// primitive attributes @p attr to which it belongs, so it is an error
/// to use @p zero_points after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of zero points @p zero_points.
/// @param mask Output zero points correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated output
/// zero point is used for each index along that dimension. The mask
/// value of 0 implies a common zero point for the whole output tensor.
/// @param zero_points Output pointer to a constant array of int32_t zero
/// points.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(
        const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
        const int32_t **zero_points);

/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of zero points @p zero_points.
/// @param mask Zero point correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated
/// zero point is used for each index along that dimension. Set the
/// mask to 0 to use a common zero point for the whole output tensor.
/// @param zero_points Constant array of int32_t zero points. If the zero
/// points are known at the time of this call, this array must contain @p
/// count zero points and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// If the zero points are not known at the time of the call, this array
/// must contain a single #DNNL_RUNTIME_S32_VAL and the zero points must
/// be passed at execution time as an argument with index
/// #DNNL_ARG_ATTR_ZERO_POINTS.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
        const int32_t *zero_points);
562 | |
/// Returns primitive attributes post-ops.
///
/// @warning
/// The output @p post_ops points to the internal @p attr field, so it is
/// an error to modify or destroy them. The lifetime of @p post_ops is
/// the same as that of the @p attr it belongs to, so it is an error to
/// use @p post_ops after @p attr has been destroyed.
///
/// @param attr Primitive attributes.
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
        const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);

/// Sets primitive attributes post-ops.
///
/// @note
/// There is no way to check whether the post-ops would be supported by
/// the target primitive. Any error will be reported by the
/// dnnl_primitive_desc_create() function call.
///
/// @param attr Primitive attributes.
/// @param post_ops Post-ops to set.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
        dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);

/// Creates an empty post-ops sequence.
///
/// The resulting post-ops must eventually be destroyed via
/// dnnl_post_ops_destroy().
///
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);

/// Clones post-ops primitive attribute.
///
/// @param post_ops Output post-ops primitive attribute.
/// @param existing_post_ops Post-ops primitive attribute to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_clone(
        dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);

/// Destroys post-ops.
///
/// @param post_ops Post-ops to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
614 | |
/// Returns the length of post-ops.
///
/// @param post_ops Post-ops.
/// @returns The number of post-ops entries.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);

/// Returns the kind of a post-op entry.
///
/// @param post_ops Post-ops.
/// @param index Index of the post-op entry to query.
/// @returns The kind of the post-op with the specified index.
/// @returns #dnnl_undefined_primitive if there is no post-op at the specified
/// index.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
        const_dnnl_post_ops_t post_ops, int index);

/// Appends an accumulation (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(
        dnnl_post_ops_t post_ops, float scale);
657 | |
/// Appends an accumulation v2 (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
/// dst[:] <- scale * as_data_type(dst[:]) + op(...)
/// // instead of dst[:] <- op(...)
///
/// @note
///     This post-op executes in-place and does not change the
///     destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(
        dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type);
692 | |
/// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the
/// result, a zero point is subtracted from the previous value and is
/// multiplied by the scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like dequantize the
/// asymmetrically quantized sum's src1 tensor to f32 domain before performing
/// the sum operation by subtracting the @p zero_point before the scaling.
///
/// In the simplest case where accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * (dst[:] - zero_point) + op(...)
/// // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
/// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
/// // instead of dst[:] <- op(...)
///
/// @note
///     This post-op executes in-place and does not change the
///     destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param zero_point Single scalar int32_t value of zero point.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v3(dnnl_post_ops_t post_ops,
        float scale, int32_t zero_point, dnnl_data_type_t data_type);
728 | |
/// Returns the parameters of an accumulation (sum) post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a sum
///     post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
        const_dnnl_post_ops_t post_ops, int index, float *scale);

/// Returns the parameters of an accumulation (sum) post-op with
/// a data type parameter.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @param data_type Output data type for accumulation.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        dnnl_data_type_t *data_type);

/// Returns the parameters of an accumulation (sum) post-op with
/// zero point and data type parameter.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @param zero_point Output zero point.
/// @param data_type Output data type for accumulation.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v3(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        int32_t *zero_point, dnnl_data_type_t *data_type);
767 | |
/// Appends an elementwise post-op.
///
/// The kind of this post operation is #dnnl_eltwise.
///
/// In the simplest case when the elementwise is the only post operation, the
/// computations would be:
///
/// dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
///
/// where eltwise_op is configured with the given parameters.
///
/// @param post_ops Post-ops.
/// @param scale Scaling factor applied to the result of the elementwise
///     operation.
/// @param alg_kind Elementwise algorithm for the post-op.
/// @param alpha Alpha parameter for the elementwise algorithm.
/// @param beta Beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
        float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta);

/// Returns the parameters of an elementwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the elementwise post-op.
/// @param scale Output scaling factor.
/// @param alg_kind Output elementwise algorithm kind.
/// @param alpha Output alpha parameter for the elementwise algorithm.
/// @param beta Output beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to an
///     elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);
804 | |
/// Appends a depthwise post-op convolution.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimensions equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for primitive with fusion is one. The output spatial
/// size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op.
/// @param bias_data_type Bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param kernel_size Size of kernel of depthwise post-op.
/// @param stride_size Size of stride of depthwise post-op.
/// @param padding_l_size Size of left and top paddings of depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
        dnnl_dim_t stride_size, dnnl_dim_t padding_l_size, dnnl_dim_t count,
        int mask, const float *scales);
842 | |
/// Returns the parameters of a depthwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op.
/// @param bias_data_type Output bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param kernel_size Output size of kernel of depthwise post-op.
/// @param stride_size Output size of stride of depthwise post-op.
/// @param padding_l_size Output size of left and top paddings of depthwise
///     post-op.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
        dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size, dnnl_dim_t *count,
        int *mask, const float **scales);
868 | |
/// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for primitive remain same as before. The output size
/// remain same as the original primitive due to stride=1.
///
/// The Post-op can be defined as:
///
/// dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op.
/// @param bias_data_type Bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
903 | |
/// Returns the parameters of a depthwise post-op with stride 1.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op.
/// @param bias_data_type Output bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
925 | |
/// Appends a depthwise post-op convolution with stride 2.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for primitive remain same as before. The output
/// spatial size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// The Post-op can be defined as:
///
/// dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op.
/// @param bias_data_type Bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
963 | |
/// Returns the parameters of a depthwise post-op with stride 2.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of depthwise post-op.
/// @param bias_data_type Output bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
985 | |
/// Appends a binary post-op.
///
/// The kind of this post operation is #dnnl_binary.
///
/// In the simplest case when the binary is the only post operation, the
/// computations would be:
///
/// dst[:] <- binary_op (dst[:], another_input[:])
///
/// where binary_op is configured with the given parameters. binary_op supports
/// broadcast semantics for a second operand.
///
/// @param post_ops Post-ops.
/// @param alg_kind Binary algorithm for the post-op.
/// @param src1_desc Memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc);

/// Returns the parameters of a binary post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the binary post-op.
/// @param alg_kind Output binary algorithm kind.
/// @param src1_desc Output memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
///     post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
        const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
        const dnnl_memory_desc_t **src1_desc);
1019 | |
/// Appends a prelu forward post-op.
///
/// The kind of this post-op is #dnnl_prelu. (This is the C API; the C++
/// equivalent is dnnl::primitive::kind::prelu.)
///
/// The post-op can be defined as:
///
/// dst[:] <- prelu(dst[:], weights[:])
/// prelu:
/// dst[:] <- dst[:] if dst[:] > 0
/// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
///
///
/// @note
///     The order of dimensions does not depend on how elements are laid
///     out in memory. For example:
///     - for a 2D CNN activations tensor the order is always (n, c)
///     - for a 4D CNN activations tensor the order is always (n, c, h, w)
///     - for a 5D CNN weights tensor the order is always
///        (g, oc, ic, kh, kw)
///
///    Prelu weights tensor is passed in runtime execution phase. Prelu
///    weights tensor data type is implicitly assumed as f32 using plain
///    layout (a, ab, acb, acdb, acdeb)
///
/// @param post_ops Post-ops.
/// @param mask Defines the correspondence between the output tensor
///     dimensions and the prelu weights tensor. The set i-th bit indicates
///     that a dedicated weights value is used for each index along that
///     dimension. Set the mask to 0 to use a common weights value
///     for the whole output tensor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
        dnnl_post_ops_t post_ops, int mask);

/// Returns the parameters of a prelu post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the prelu post-op.
/// @param mask Output mask of the prelu post-op.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
        const_dnnl_post_ops_t post_ops, int index, int *mask);
1064 | |
1065 | /// @} dnnl_api_attributes |
1066 | |
1067 | /// @} dnnl_api_primitives |
1068 | |
1069 | /// @addtogroup dnnl_api_memory |
1070 | /// @{ |
1071 | |
/// Initializes a memory descriptor using dimensions and strides.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param strides Strides in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, const dnnl_dims_t strides);
1090 | |
/// Initializes a memory descriptor using dimensions and memory format tag.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
///     allow a primitive to choose the final memory format. In this case the
///     format_kind field of the memory descriptor would be set to
///     #dnnl_format_kind_any.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_format_tag_t tag);
1112 | |
/// Initializes a memory descriptor for a region inside an area
/// described by an existing memory descriptor.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or dims may
///     result in a failure to create a submemory.
///
/// @param memory_desc Output memory descriptor.
/// @param parent_memory_desc An existing memory descriptor.
/// @param dims Sizes of the region.
/// @param offsets Offsets to the region from the encompassing
///     memory object in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(
        dnnl_memory_desc_t *memory_desc,
        const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims,
        const dnnl_dims_t offsets);
1131 | |
/// Initializes a memory descriptor by reshaping an existing one. The new
/// memory descriptor inherits the data type. This operation is valid only for
/// memory descriptors that have format_kind set to #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The operation ensures the transformation of the physical memory format
/// corresponds to the transformation of the logical dimensions. If such
/// transformation is impossible, the function returns #dnnl_invalid_arguments.
///
/// The reshape operation can be described as a combination of the following
/// basic operations:
/// 1. Add a dimension of size `1`. This is always possible.
/// 2. Remove a dimension of size `1`. This is possible only if the dimension
///    has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
/// 3. Split a dimension into multiple ones. This is possible only if the size
///    of the dimension is exactly equal to the product of the split ones and
///    the dimension does not have padding (i.e.
///    `padded_dims[dim] == dims[dim]`).
/// 4. Joining multiple consecutive dimensions into a single one. As in the
///    cases above, this requires that the dimensions do not have padding and
///    that the memory format is such that in physical memory these dimensions
///    are dense and have the same order as their logical counterparts. This
///    also assumes that these dimensions are not blocked.
///    - Here, dense means:
///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
///    - And same order means:
///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or
///     dimensions may result in a failure to make a reshape.
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param ndims Number of dimensions for the output memory descriptor.
/// @param dims Dimensions for the output memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
        dnnl_memory_desc_t *out_memory_desc,
        const dnnl_memory_desc_t *in_memory_desc, int ndims,
        const dnnl_dims_t dims);
1175 | |
/// Initializes a memory descriptor by permuting axes in an existing one.
///
/// The physical memory layout representation is adjusted accordingly to
/// maintain the consistency between the logical and physical parts of the
/// memory descriptor.
///
/// The new memory descriptor inherits the data type. This operation is valid
/// only for memory descriptors that have format_kind set to #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The logical axes will be permuted in the following manner:
/// ```
/// for (i: 0 .. in_memory_desc->ndims)
///     out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
/// ```
///
/// Example:
/// @code
///     dnnl_memory_desc_t in_md, out_md, expect_out_md;
///
///     const int permutation[] = {1, 0}; // swap the first and the second axes
///
///     dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
///     dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
///
///     dnnl_memory_desc_init_by_tag(
///             &in_md, 2, in_dims, data_type, in_tag);
///     dnnl_memory_desc_init_by_tag(
///             &expect_out_md, 2, out_dims, data_type, out_tag);
///
///     dnnl_memory_desc_permute_axes(&out_md, &in_md, permutation);
///     assert(dnnl_memory_desc_equal(&out_md, &expect_out_md));
/// @endcode
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
        dnnl_memory_desc_t *out_memory_desc,
        const dnnl_memory_desc_t *in_memory_desc, const int *permutation);
1219 | |
/// Compares two memory descriptors.
///
/// Use this function to identify whether a reorder is required between the
/// two memories.
///
/// @param lhs Left-hand side of the comparison.
/// @param rhs Right-hand side of the comparison.
/// @returns 1 if the descriptors are the same.
/// @returns 0 if the descriptors are different.
int DNNL_API dnnl_memory_desc_equal(
        const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs);

/// Returns the size of a memory descriptor.
///
/// @param memory_desc Memory descriptor.
/// @returns The number of bytes required for memory described by a memory
///     descriptor.
size_t DNNL_API dnnl_memory_desc_get_size(
        const dnnl_memory_desc_t *memory_desc);

/// Returns the size of data type.
///
/// @param data_type Data type.
/// @returns The number of bytes occupied by data type.
size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
1245 | |
/// Creates a memory object.
///
/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
/// object will have the underlying buffer set. In this case, the buffer will
/// be initialized as if dnnl_memory_set_data_handle() had been called.
///
/// @sa dnnl_memory_set_data_handle()
///
/// @param memory Output memory object.
/// @param memory_desc Memory descriptor.
/// @param engine Engine to use.
/// @param handle Handle of the memory buffer to use as an underlying storage.
///     - A pointer to the user-allocated buffer. In this case the library
///       doesn't own the buffer.
///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
///       allocate the buffer for the memory object. In this case the library
///       owns the buffer.
///     - DNNL_MEMORY_NONE to create a memory object without an underlying
///       buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
        const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine,
        void *handle);
1269 | |
/// Returns the memory descriptor for a memory object.
///
/// @param memory Memory object.
/// @param memory_desc Output memory descriptor (a copy).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
        const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc);
1278 | |
/// Returns the engine of a memory object.
///
/// @param memory Memory object.
/// @param engine Output engine on which the memory is located.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_engine(
        const_dnnl_memory_t memory, dnnl_engine_t *engine);
1287 | |
/// Maps a memory object and returns a host-side pointer to a memory buffer
/// with a copy of its contents.
///
/// Mapping enables explicit direct access to memory contents for the engines
/// that do not support it implicitly.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until this memory object is unmapped.
///
/// @note
///     Any primitives working with @p memory should be completed before
///     the memory is mapped. Use dnnl_stream_wait() to synchronize the
///     corresponding execution stream.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Output pointer to the mapped buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_map_data(
        const_dnnl_memory_t memory, void **mapped_ptr);
1313 | |
/// Unmaps a memory object and writes back any changes made to the previously
/// mapped memory buffer. The pointer to the mapped buffer must be obtained
/// via the dnnl_memory_map_data() call.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Pointer to the mapped buffer that must have been
///     obtained using the dnnl_memory_map_data() function.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(
        const_dnnl_memory_t memory, void *mapped_ptr);
1330 | |
/// Returns memory object's data handle.
///
/// @param memory Memory object.
/// @param handle Output data handle. For the CPU engine, the data handle is a
///     pointer to the actual data. For OpenCL it is a `cl_mem`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
        const_dnnl_memory_t memory, void **handle);
1340 | |
/// Sets the underlying memory buffer.
///
/// See the description of dnnl_memory_set_data_handle_v2() for more details.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine, the data handle is a
///     pointer to the actual data. For OpenCL it is a `cl_mem`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
        dnnl_memory_t memory, void *handle);
1352 | |
/// Sets the underlying memory buffer.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine, the data handle is a
///     pointer to the actual data. For OpenCL it is a `cl_mem`.
/// @param stream Stream to use to execute padding in.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
        dnnl_memory_t memory, void *handle, dnnl_stream_t stream);
1363 | |
/// Destroys a memory object.
///
/// @param memory Memory object to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
1370 | |
1371 | /// @} dnnl_api_memory |
1372 | |
1373 | /// @addtogroup dnnl_api_primitives |
1374 | /// @{ |
1375 | |
1376 | /// @addtogroup dnnl_api_reorder |
1377 | /// @{ |
1378 | |
/// Creates a primitive descriptor for a reorder primitive.
///
/// @param reorder_primitive_desc Output primitive descriptor.
/// @param src_desc Source memory descriptor.
/// @param src_engine Engine on which the source memory object will be
///     located.
/// @param dst_desc Destination memory descriptor.
/// @param dst_engine Engine on which the destination memory object
///     will be located.
/// @param attr Primitive attributes to use (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
        dnnl_primitive_desc_t *reorder_primitive_desc,
        const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine,
        const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine,
        const_dnnl_primitive_attr_t attr);
1396 | |
1397 | /// @} dnnl_api_reorder |
1398 | |
1399 | /// @addtogroup dnnl_api_concat |
1400 | /// @{ |
1401 | |
/// Creates a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// @param concat_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source parameters.
/// @param concat_dimension Source tensors will be concatenated over
///     dimension with this index. Note that the order of dimensions does
///     not depend on memory format.
/// @param src_descs Array of source memory descriptors with @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
        dnnl_primitive_desc_t *concat_primitive_desc,
        const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension,
        const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
        dnnl_engine_t engine);
1421 | |
1422 | /// @} dnnl_api_concat |
1423 | |
1424 | /// @addtogroup dnnl_api_sum |
1425 | /// @{ |
1426 | |
/// Creates a primitive descriptor for an (out-of-place) sum primitive.
///
/// @param sum_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source parameters.
/// @param scales Vector of scales to multiply data in each source
///     memory object by.
/// @param src_descs Array of source memory descriptors having @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
        dnnl_primitive_desc_t *sum_primitive_desc,
        const dnnl_memory_desc_t *dst_desc, int n, const float *scales,
        const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
        dnnl_engine_t engine);
1444 | |
1445 | /// @} dnnl_api_sum |
1446 | |
1447 | /// @addtogroup dnnl_api_binary |
1448 | /// @{ |
1449 | |
/// Initializes a descriptor for a binary primitive.
///
/// @note
///     Memory descriptor @p dst_desc is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @note
///     Both memory descriptors must have the same number of dimensions.
///     Element broadcasting is supported for memory descriptor @p src1_desc
///     and is applied to @p src1_desc dimensions that have size equal to 1.
///
/// @param binary_desc Output descriptor for a binary primitive.
/// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
///     #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
///     #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
///     #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
/// @param src0_desc Source 0 memory descriptor.
/// @param src1_desc Source 1 memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc,
        const dnnl_memory_desc_t *src1_desc,
        const dnnl_memory_desc_t *dst_desc);
1475 | |
1476 | /// @} dnnl_api_binary |
1477 | |
1478 | /// @addtogroup dnnl_api_convolution |
1479 | /// @{ |
1480 | |
/// Initializes a descriptor for a convolution forward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1519 | |
/// Initializes a descriptor for a dilated convolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1562 | |
/// Initializes a descriptor for a convolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1595 | |
/// Initializes a descriptor for a dilated convolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1633 | |
/// Initializes a descriptor for a convolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1670 | |
/// Initializes a descriptor for a dilated convolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1712 | |
1713 | /// @} dnnl_api_convolution |
1714 | |
1715 | /// @addtogroup dnnl_api_deconvolution |
1716 | /// @{ |
1717 | |
/// Initializes a descriptor for a deconvolution forward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1755 | |
/// Initializes a descriptor for a dilated deconvolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1797 | |
/// Initializes a descriptor for a deconvolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1829 | |
/// Initializes a descriptor for a dilated deconvolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1866 | |
/// Initializes a descriptor for a deconvolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1902 | |
/// Initializes a descriptor for a dilated deconvolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1943 | |
1944 | /// @} dnnl_api_deconvolution |
1945 | |
1946 | /// @addtogroup dnnl_api_shuffle |
1947 | /// @{ |
1948 | |
/// Initializes a descriptor for a shuffle forward propagation primitive.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size);
1962 | |
/// Initializes a descriptor for a shuffle backward propagation primitive.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc,
        const dnnl_memory_desc_t *diff_data_desc, int axis,
        dnnl_dim_t group_size);
1975 | |
1976 | /// @} dnnl_api_shuffle |
1977 | |
1978 | /// @addtogroup dnnl_api_eltwise |
1979 | /// @{ |
1980 | |
/// Initializes a descriptor for an eltwise forward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Elementwise algorithm kind.
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc,
        float alpha, float beta);
1998 | |
/// Initializes a descriptor for an eltwise backward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param alg_kind Elementwise algorithm kind.
/// @param diff_data_desc Diff source and diff destination memory descriptors.
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float alpha, float beta);
2015 | |
2016 | /// @} dnnl_api_eltwise |
2017 | |
2018 | /// @addtogroup dnnl_api_softmax |
2019 | /// @{ |
2020 | |
/// Initializes a descriptor for a softmax forward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(
        dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);
2033 | |
/// Initializes a descriptor for a softmax backward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptors.
/// @param data_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(
        dnnl_softmax_desc_t *softmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);
2046 | |
2047 | /// @} dnnl_api_softmax |
2048 | |
2049 | /// @addtogroup dnnl_api_softmax_v2 |
2050 | /// @{ |
2051 | |
/// Initializes a descriptor for a softmax v2 forward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
///     #dnnl_softmax_log.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_v2_forward_desc_init(
        dnnl_softmax_v2_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, int softmax_axis);
2068 | |
/// Initializes a descriptor for a softmax v2 backward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
///     #dnnl_softmax_log.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_v2_backward_desc_init(
        dnnl_softmax_v2_desc_t *softmax_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc,
        const dnnl_memory_desc_t *dst_desc, int softmax_axis);
2085 | |
2086 | /// @} dnnl_api_softmax_v2 |
2087 | |
2088 | /// @addtogroup dnnl_api_logsoftmax |
2089 | /// @{ |
2090 | |
/// Initializes a descriptor for a logsoftmax forward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);
2103 | |
/// Initializes a descriptor for a logsoftmax backward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptors.
/// @param data_desc Destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);
2116 | |
2117 | /// @} dnnl_api_logsoftmax |
2118 | |
2119 | /// @addtogroup dnnl_api_pooling |
2120 | /// @{ |
2121 | |
/// Initializes a descriptor for a pooling forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
2152 | |
/// Initializes a descriptor for a pooling backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
2181 | |
2182 | /// @} dnnl_api_pooling |
2183 | |
2184 | /// @addtogroup dnnl_api_pooling_v2 |
2185 | /// @{ |
2186 | |
/// Initializes a descriptor for a pooling v2 (pooling with dilation support)
/// forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for each spatial dimension.
///     NOTE(review): presumably follows the same convention as deconvolution,
///     where a zero value means no dilation -- confirm.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
2220 | |
/// Initializes a descriptor for a pooling v2 (pooling with dilation support)
/// backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for each spatial dimension.
///     NOTE(review): presumably follows the same convention as deconvolution,
///     where a zero value means no dilation -- confirm.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
2252 | |
2253 | /// @} dnnl_api_pooling_v2 |
2254 | |
2255 | /// @addtogroup dnnl_api_prelu |
2256 | /// @{ |
2257 | |
/// Initializes a descriptor for PReLU
/// (leaky ReLU with trainable alpha parameter)
/// forward propagation primitive.
///
/// @note
///     The weights descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc);
2277 | |
/// Initializes a descriptor for PReLU
/// (leaky ReLU with trainable alpha parameter)
/// backward propagation primitive.
///
/// @note
///     The weights and diff_weights descriptors are allowed
///     to be initialized with #dnnl_format_tag_any or with format_kind
///     set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @param diff_data_desc Diff source and destination memory descriptor.
/// @param diff_weights_desc Diff alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *diff_weights_desc);
2299 | |
2300 | /// @} dnnl_api_prelu |
2301 | |
2302 | /// @addtogroup dnnl_api_lrn |
2303 | /// @{ |
2304 | |
/// Initializes a descriptor for an LRN forward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param data_desc Source and destination memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);
2323 | |
/// Initializes a descriptor for an LRN backward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);
2341 | |
2342 | /// @} dnnl_api_lrn |
2343 | |
2344 | /// @addtogroup dnnl_api_batch_normalization |
2345 | /// @{ |
2346 | |
/// Initializes a descriptor for a batch normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param bnrm_desc Output descriptor for batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param epsilon Batch normalization epsilon parameter (small constant
///     added to the variance for numerical stability).
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);
2365 | |
/// Initializes a descriptor for a batch normalization backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param bnrm_desc Output descriptor for batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param epsilon Batch normalization epsilon parameter (small constant
///     added to the variance for numerical stability).
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);
2387 | |
2388 | /// @} dnnl_api_batch_normalization |
2389 | |
2390 | /// @addtogroup dnnl_api_layer_normalization |
2391 | /// @{ |
2392 | |
/// Initializes a descriptor for a layer normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param lnrm_desc Output descriptor for layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p data_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter (small constant
///     added to the variance for numerical stability).
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(
        dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags);
2417 | |
2418 | /// Initializes a descriptor for a layer normalization backward propagation |
2419 | /// primitive. |
2420 | /// |
2421 | /// @note |
2422 | /// In-place operation is supported: the diff_dst can refer to the same |
2423 | /// memory as the diff_src. |
2424 | /// |
2425 | /// @param lnrm_desc Output descriptor for layer normalization primitive. |
2426 | /// @param prop_kind Propagation kind. Possible values are |
2427 | /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are |
2428 | /// computed in this case). |
2429 | /// @param diff_data_desc Diff source and diff destination memory descriptor. |
2430 | /// @param data_desc Source memory descriptor. |
2431 | /// @param stat_desc Memory descriptor for mean and variance. If this |
2432 | /// parameter is NULL, a zero memory descriptor, or a memory descriptor |
2433 | /// with format_kind set to #dnnl_format_kind_undef, then the memory |
2434 | /// descriptor for stats is derived from @p data_desc by removing the last |
2435 | /// dimension. |
2436 | /// @param epsilon Layer normalization epsilon parameter. |
2437 | /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). |
2438 | /// @returns #dnnl_success on success and a status describing the error |
2439 | /// otherwise. |
2440 | dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init( |
2441 | dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, |
2442 | const dnnl_memory_desc_t *diff_data_desc, |
2443 | const dnnl_memory_desc_t *data_desc, |
2444 | const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags); |
2445 | |
2446 | /// @} dnnl_api_layer_normalization |
2447 | |
2448 | /// @addtogroup dnnl_api_layer_normalization_v2 |
2449 | /// @{ |
2450 | |
2451 | /// Initializes a descriptor for layer normalization v2 forward propagation |
2452 | /// primitive. |
2453 | /// |
2454 | /// @note |
2455 | /// In-place operation is supported: the dst can refer to the same memory |
2456 | /// as the src. |
2457 | /// |
2458 | /// @param lnrm_desc Output descriptor for layer normalization primitive. |
2459 | /// @param prop_kind Propagation kind. Possible values are |
2460 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2461 | /// @param src_desc Source memory descriptor. |
2462 | /// @param dst_desc Destination memory descriptor. |
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
2468 | /// @param epsilon Layer normalization epsilon parameter. |
2469 | /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). |
2470 | /// @returns #dnnl_success on success and a status describing the error |
2471 | /// otherwise. |
2472 | dnnl_status_t DNNL_API dnnl_layer_normalization_v2_forward_desc_init( |
2473 | dnnl_layer_normalization_v2_desc_t *lnrm_desc, |
2474 | dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, |
2475 | const dnnl_memory_desc_t *dst_desc, const dnnl_memory_desc_t *stat_desc, |
2476 | float epsilon, unsigned flags); |
2477 | |
2478 | /// Initializes a descriptor for a layer normalization v2 backward propagation |
2479 | /// primitive. |
2480 | /// |
2481 | /// @note |
2482 | /// In-place operation is supported: the diff_dst can refer to the same |
2483 | /// memory as the diff_src. |
2484 | /// |
2485 | /// @param lnrm_desc Output descriptor for layer normalization primitive. |
2486 | /// @param prop_kind Propagation kind. Possible values are |
2487 | /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are |
2488 | /// computed in this case). |
2489 | /// @param diff_src_desc Diff source memory descriptor. |
2490 | /// @param diff_dst_desc Diff destination memory descriptor. |
2491 | /// @param src_desc Source memory descriptor. |
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
2497 | /// @param epsilon Layer normalization epsilon parameter. |
2498 | /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). |
2499 | /// @returns #dnnl_success on success and a status describing the error |
2500 | /// otherwise. |
2501 | dnnl_status_t DNNL_API dnnl_layer_normalization_v2_backward_desc_init( |
2502 | dnnl_layer_normalization_v2_desc_t *lnrm_desc, |
2503 | dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_src_desc, |
2504 | const dnnl_memory_desc_t *diff_dst_desc, |
2505 | const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *stat_desc, |
2506 | float epsilon, unsigned flags); |
2507 | |
2508 | /// @} dnnl_api_layer_normalization_v2 |
2509 | |
2510 | /// @addtogroup dnnl_api_inner_product |
2511 | /// @{ |
2512 | |
2513 | /// Initializes descriptor for inner product forward propagation. |
2514 | /// |
2515 | /// @note |
2516 | /// Memory descriptors can be initialized with |
2517 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2518 | /// |
2519 | /// @param ip_desc Output descriptor for inner product primitive. |
2520 | /// @param prop_kind Propagation kind. Possible values are |
2521 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2522 | /// @param src_desc Source memory descriptor. |
2523 | /// @param weights_desc Weights memory descriptor. |
2524 | /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory |
2525 | /// descriptor, or a memory descriptor with format_kind set to |
2526 | /// #dnnl_format_kind_undef disables the bias term. |
2527 | /// @param dst_desc Destination memory descriptor. |
2528 | /// @returns #dnnl_success on success and a status describing the error |
2529 | /// otherwise. |
2530 | dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init( |
2531 | dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, |
2532 | const dnnl_memory_desc_t *src_desc, |
2533 | const dnnl_memory_desc_t *weights_desc, |
2534 | const dnnl_memory_desc_t *bias_desc, |
2535 | const dnnl_memory_desc_t *dst_desc); |
2536 | |
2537 | /// Initializes descriptor for inner product backward propagation. |
2538 | /// |
2539 | /// @note |
2540 | /// Memory descriptors can be initialized with |
2541 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2542 | /// |
2543 | /// @param ip_desc Output descriptor for inner product primitive. |
2544 | /// @param diff_src_desc Diff source memory descriptor. |
2545 | /// @param weights_desc Weights memory descriptor. |
2546 | /// @param diff_dst_desc Diff destination memory descriptor. |
2547 | /// @returns #dnnl_success on success and a status describing the error |
2548 | /// otherwise. |
2549 | dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init( |
2550 | dnnl_inner_product_desc_t *ip_desc, |
2551 | const dnnl_memory_desc_t *diff_src_desc, |
2552 | const dnnl_memory_desc_t *weights_desc, |
2553 | const dnnl_memory_desc_t *diff_dst_desc); |
2554 | |
2555 | /// Initializes descriptor for inner product weights gradient primitive. |
2556 | /// |
2557 | /// @note |
2558 | /// Memory descriptors can be initialized with |
2559 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2560 | /// |
2561 | /// @param ip_desc Output descriptor for inner product primitive. |
2562 | /// @param src_desc Source memory descriptor. |
2563 | /// @param diff_weights_desc Diff weights memory descriptor. |
2564 | /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero |
2565 | /// memory descriptor, or a memory descriptor with format_kind set to |
2566 | /// #dnnl_format_kind_undef disables the bias term. |
2567 | /// @param diff_dst_desc Diff destination memory descriptor. |
2568 | /// @returns #dnnl_success on success and a status describing the error |
2569 | /// otherwise. |
2570 | dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init( |
2571 | dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, |
2572 | const dnnl_memory_desc_t *diff_weights_desc, |
2573 | const dnnl_memory_desc_t *diff_bias_desc, |
2574 | const dnnl_memory_desc_t *diff_dst_desc); |
2575 | |
2576 | /// @} dnnl_api_inner_product |
2577 | |
2578 | /// @addtogroup dnnl_api_attributes |
2579 | /// @{ |
2580 | |
2581 | /// Set quantization scale and shift parameters for RNN data tensors. |
2582 | /// |
2583 | /// For performance reasons, the low-precision configuration of the RNN |
2584 | /// primitives expects input activations to have the unsigned 8-bit integer |
2585 | /// data type. The scale and shift parameters are used to quantize |
2586 | /// floating-point data to unsigned integer and must be passed to the RNN |
2587 | /// primitive using attributes. |
2588 | /// |
2589 | /// The quantization formula is `scale * data + shift`. |
2590 | /// |
2591 | /// @note |
2592 | /// Quantization scale and shift are common for src_layer, src_iter, |
2593 | /// dst_iter, and dst_layer. |
2594 | /// |
2595 | /// Example usage: |
2596 | /// @code |
2597 | /// // RNN parameters |
2598 | /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; |
2599 | /// // Activations quantization parameters |
2600 | /// float scale = 63.f, shift = 64.f; |
2601 | /// |
2602 | /// dnnl_primitive_attr_t rnn_attr; |
2603 | /// // Create default attributes |
2604 | /// dnnl_primitive_attr_create(&rnn_attr); |
2605 | /// |
2606 | /// // Set scale and shift for int8 quantization of activation |
2607 | /// dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); |
2608 | /// |
2609 | /// // Create and configure rnn op_desc |
2610 | /// dnnl_rnn_desc_t rnn_d; |
2611 | /// dnnl_primitive_desc_t rnn_pd; |
///     dnnl_primitive_desc_create(&rnn_pd, &rnn_d, rnn_attr, engine, NULL);
2613 | /// @endcode |
2614 | /// |
2615 | /// @param attr Primitive attributes. |
2616 | /// @param scale The value to scale the data by. |
2617 | /// @param shift The value to shift the data by. |
2618 | /// @returns #dnnl_success on success and a status describing the error |
2619 | /// otherwise. |
2620 | dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams( |
2621 | dnnl_primitive_attr_t attr, const float scale, const float shift); |
2622 | |
2623 | /// Returns the quantization scale and shift parameters for RNN data tensors. |
2624 | /// |
2625 | /// @note |
2626 | /// Quantization scale and shift are common for src_layer, src_iter, |
2627 | /// dst_iter, and dst_layer. |
2628 | /// |
/// @param attr Primitive attributes.
/// @param scale Output scale value that the data is scaled by.
/// @param shift Output shift value that the data is shifted by.
2632 | /// @returns #dnnl_success on success and a status describing the error |
2633 | /// otherwise. |
2634 | dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams( |
2635 | const_dnnl_primitive_attr_t attr, float *scale, float *shift); |
2636 | |
2637 | /// Sets quantization scaling factors for RNN weights tensors. The |
2638 | /// low-precision configuration of the RNN primitives expects input weights to |
2639 | /// use the signed 8-bit integer data type. The scaling factors are used to |
2640 | /// quantize floating-point data to signed integer and must be passed to RNN |
2641 | /// primitives using attributes. |
2642 | /// |
2643 | /// @note |
2644 | /// The dimension order is always native and does not depend on the actual |
2645 | /// layout used. For example, five-dimensional weights always have (l, d, |
2646 | /// i, g, o) logical dimension ordering. |
2647 | /// |
2648 | /// @note |
///     Quantization scales are common for weights_layer and weights_iteration.
2650 | /// |
2651 | /// @param attr Primitive attributes. |
2652 | /// @param count Number of elements in the @p scales array. |
2653 | /// @param mask Scaling factors correspondence mask that defines the |
2654 | /// correspondence between the output tensor dimensions and the @p |
2655 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
2656 | /// factor should be used for each index along that dimension. Set the |
2657 | /// mask to 0 to use a common scaling factor for the whole output |
2658 | /// tensor. |
2659 | /// @param scales Array of output scaling factors that must contain @p count |
2660 | /// values and the following equality must hold: |
2661 | /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] |
2662 | /// Violations can only be detected when the attributes are used to create |
2663 | /// a primitive descriptor. |
2664 | /// @returns #dnnl_success on success and a status describing the error |
2665 | /// otherwise. |
2666 | dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams( |
2667 | dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, |
2668 | const float *scales); |
2669 | |
2670 | /// Returns the quantization scaling factors for RNN weights tensors. |
2671 | /// |
2672 | /// @param attr Primitive attributes. |
2673 | /// @param count Number of elements in the @p scales array. |
2674 | /// @param mask Scaling factors correspondence mask that defines the |
2675 | /// correspondence between the output tensor dimensions and the @p |
2676 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
2677 | /// factor should be used for each index along that dimension. Set the |
2678 | /// mask to 0 to use a common scaling factor for the whole output |
2679 | /// tensor. |
2680 | /// @param scales Array of output scaling factors that contain @p count |
2681 | /// values and the following equality must hold: |
2682 | /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] |
2683 | /// @returns #dnnl_success on success and a status describing the error |
2684 | /// otherwise. |
2685 | dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams( |
2686 | const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, |
2687 | const float **scales); |
2688 | |
2689 | /// Sets quantization scaling factors for RNN projection weights tensors. The |
2690 | /// low-precision configuration of the RNN primitives expects input weights to |
2691 | /// use the signed 8-bit integer data type. The scaling factors are used to |
2692 | /// quantize floating-point data to signed integer and must be passed to RNN |
2693 | /// primitives using attributes. |
2694 | /// |
2695 | /// @note |
2696 | /// The dimension order is always native and does not depend on the actual |
2697 | /// layout used. For example, five-dimensional weights always have (l, d, |
2698 | /// i, g, o) logical dimension ordering. |
2699 | /// |
2700 | /// @param attr Primitive attributes. |
2701 | /// @param count Number of elements in the @p scales array. |
2702 | /// @param mask Scaling factors correspondence mask that defines the |
2703 | /// correspondence between the output tensor dimensions and the @p |
2704 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
2705 | /// factor should be used for each index along that dimension. Set the |
2706 | /// mask to 0 to use a common scaling factor for the whole output |
2707 | /// tensor. |
2708 | /// @param scales Array of output scaling factors that must contain @p count |
2709 | /// values and the following equality must hold: |
2710 | /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] |
2711 | /// Violations can only be detected when the attributes are used to create |
2712 | /// a primitive descriptor. |
2713 | /// @returns #dnnl_success on success and a status describing the error |
2714 | /// otherwise. |
2715 | dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams( |
2716 | dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, |
2717 | const float *scales); |
2718 | |
2719 | /// Returns the quantization scaling factors for RNN projection weights tensors. |
2720 | /// |
2721 | /// @param attr Primitive attributes. |
2722 | /// @param count Number of elements in the @p scales array. |
2723 | /// @param mask Scaling factors correspondence mask that defines the |
2724 | /// correspondence between the output tensor dimensions and the @p |
2725 | /// scales vector. The set i-th bit indicates that a dedicated scaling |
2726 | /// factor should be used for each index along that dimension. Set the |
2727 | /// mask to 0 to use a common scaling factor for the whole output |
2728 | /// tensor. |
2729 | /// @param scales Array of output scaling factors that contain @p count |
2730 | /// values and the following equality must hold: |
2731 | /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] |
2732 | /// @returns #dnnl_success on success and a status describing the error |
2733 | /// otherwise. |
2734 | dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams( |
2735 | const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, |
2736 | const float **scales); |
2737 | |
2738 | /// @} dnnl_api_attributes |
2739 | |
2740 | /// @addtogroup dnnl_api_rnn |
2741 | /// @{ |
2742 | |
2743 | /// Initializes a descriptor for vanilla RNN forward propagation primitive. |
2744 | /// |
2745 | /// The following arguments may either be @c NULL or point to a zero memory |
2746 | /// descriptor: |
2747 | /// - @p src_iter_desc, |
2748 | /// - @p bias_desc, |
2749 | /// - @p dst_iter_desc. |
2750 | /// |
2751 | /// This would then indicate that the RNN forward propagation primitive should |
2752 | /// not use them and should default to zero values instead. |
2753 | /// |
2754 | /// @note |
2755 | /// All memory descriptors can be initialized with |
2756 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2757 | /// |
2758 | /// @param rnn_desc Output descriptor for vanilla RNN primitive. |
2759 | /// @param prop_kind Propagation kind. Possible values are |
2760 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2761 | /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu, |
2762 | /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic. |
2763 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
2764 | /// info. |
2765 | /// @param src_layer_desc Memory descriptor for the input vector. |
2766 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
2767 | /// state vector. |
2768 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
2769 | /// layer input. |
2770 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
2771 | /// recurrent input. |
2772 | /// @param bias_desc Bias memory descriptor. |
2773 | /// @param dst_layer_desc Memory descriptor for the output vector. |
2774 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
2775 | /// state vector. |
2776 | /// @param flags Unused. |
2777 | /// @param alpha Negative slope if activation is #dnnl_eltwise_relu. |
2778 | /// @param beta Unused. |
2779 | /// @returns #dnnl_success on success and a status describing the error |
2780 | /// otherwise. |
2781 | dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init( |
2782 | dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, |
2783 | const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, |
2784 | const dnnl_memory_desc_t *src_layer_desc, |
2785 | const dnnl_memory_desc_t *src_iter_desc, |
2786 | const dnnl_memory_desc_t *weights_layer_desc, |
2787 | const dnnl_memory_desc_t *weights_iter_desc, |
2788 | const dnnl_memory_desc_t *bias_desc, |
2789 | const dnnl_memory_desc_t *dst_layer_desc, |
2790 | const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, |
2791 | float beta); |
2792 | |
2793 | /// Initializes a descriptor for vanilla RNN backward propagation primitive. |
2794 | /// |
2795 | /// The following arguments may either be @c NULL or point to a zero memory |
2796 | /// descriptor: |
2797 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
2798 | /// - @p bias_desc together with @p diff_bias_desc, |
2799 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
2800 | /// |
2801 | /// This would then indicate that the RNN backward propagation primitive should |
2802 | /// not use the respective data and should use zero values instead. |
2803 | /// |
2804 | /// @note |
2805 | /// All memory descriptors can be initialized with |
2806 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2807 | /// |
2808 | /// @param rnn_desc Output descriptor for vanilla RNN primitive. |
2809 | /// @param prop_kind Propagation kind. Must be #dnnl_backward. |
2810 | /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu, |
2811 | /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic. |
2812 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
2813 | /// info. |
2814 | /// @param src_layer_desc Memory descriptor for the input vector. |
2815 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
2816 | /// state vector. |
2817 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
2818 | /// layer input. |
2819 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
2820 | /// recurrent input. |
2821 | /// @param bias_desc Bias memory descriptor. |
2822 | /// @param dst_layer_desc Memory descriptor for the output vector. |
2823 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
2824 | /// state vector. |
2825 | /// @param diff_src_layer_desc Memory descriptor for the diff of input vector. |
2826 | /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent |
2827 | /// hidden state vector. |
2828 | /// @param diff_weights_layer_desc Memory descriptor for the diff of weights |
2829 | /// applied to the layer input. |
2830 | /// @param diff_weights_iter_desc Memory descriptor for the diff of weights |
2831 | /// applied to the recurrent input. |
2832 | /// @param diff_bias_desc Diff bias memory descriptor. |
2833 | /// @param diff_dst_layer_desc Memory descriptor for the diff of output |
2834 | /// vector. |
2835 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
2836 | /// recurrent hidden state vector. |
2837 | /// @param flags Unused. |
2838 | /// @param alpha Negative slope if activation is #dnnl_eltwise_relu. |
2839 | /// @param beta Unused. |
2840 | /// @returns #dnnl_success on success and a status describing the error |
2841 | /// otherwise. |
2842 | dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init( |
2843 | dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, |
2844 | const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, |
2845 | const dnnl_memory_desc_t *src_layer_desc, |
2846 | const dnnl_memory_desc_t *src_iter_desc, |
2847 | const dnnl_memory_desc_t *weights_layer_desc, |
2848 | const dnnl_memory_desc_t *weights_iter_desc, |
2849 | const dnnl_memory_desc_t *bias_desc, |
2850 | const dnnl_memory_desc_t *dst_layer_desc, |
2851 | const dnnl_memory_desc_t *dst_iter_desc, |
2852 | const dnnl_memory_desc_t *diff_src_layer_desc, |
2853 | const dnnl_memory_desc_t *diff_src_iter_desc, |
2854 | const dnnl_memory_desc_t *diff_weights_layer_desc, |
2855 | const dnnl_memory_desc_t *diff_weights_iter_desc, |
2856 | const dnnl_memory_desc_t *diff_bias_desc, |
2857 | const dnnl_memory_desc_t *diff_dst_layer_desc, |
2858 | const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, |
2859 | float alpha, float beta); |
2860 | |
2861 | /// Initializes a descriptor for LSTM forward propagation primitive. |
2862 | /// |
2863 | /// The following arguments may either be @c NULL or point to a zero memory |
2864 | /// descriptor: |
2865 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
2866 | /// - @p bias_desc, |
2867 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
2868 | /// |
2869 | /// This would then indicate that the LSTM forward propagation primitive should |
2870 | /// not use them and should default to zero values instead. |
2871 | /// |
2872 | /// @note |
2873 | /// All memory descriptors can be initialized with |
2874 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
2875 | /// |
2876 | /// @sa dnnl_lstm_forward_desc_init_v2 to initialize forward LSTM with and |
2877 | /// without peephole |
2878 | /// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and |
2879 | /// without peephole / recurrent projection layer |
2880 | /// |
2881 | /// @param rnn_desc Output descriptor for LSTM primitive. |
2882 | /// @param prop_kind Propagation kind. Possible values are |
2883 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2884 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
2885 | /// info. |
2886 | /// @param src_layer_desc Memory descriptor for the input vector. |
2887 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
2888 | /// state vector. |
2889 | /// @param src_iter_c_desc Memory descriptor for the input recurrent cell |
2890 | /// state vector. |
2891 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
2892 | /// layer input. |
2893 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
2894 | /// recurrent input. |
2895 | /// @param bias_desc Bias memory descriptor. |
2896 | /// @param dst_layer_desc Memory descriptor for the output vector. |
2897 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
2898 | /// state vector. |
2899 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell |
2900 | /// state vector. |
2901 | /// @param flags Unused. |
2902 | /// @returns #dnnl_success on success and a status describing the error |
2903 | /// otherwise. |
2904 | dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, |
2905 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
2906 | const dnnl_memory_desc_t *src_layer_desc, |
2907 | const dnnl_memory_desc_t *src_iter_desc, |
2908 | const dnnl_memory_desc_t *src_iter_c_desc, |
2909 | const dnnl_memory_desc_t *weights_layer_desc, |
2910 | const dnnl_memory_desc_t *weights_iter_desc, |
2911 | const dnnl_memory_desc_t *bias_desc, |
2912 | const dnnl_memory_desc_t *dst_layer_desc, |
2913 | const dnnl_memory_desc_t *dst_iter_desc, |
2914 | const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags); |
2915 | |
2916 | /// Initializes a descriptor for an LSTM (with or without peephole) forward |
2917 | /// propagation primitive. |
2918 | /// |
2919 | /// The following arguments may either be @c NULL or point to a zero memory |
2920 | /// descriptor: |
2921 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
2922 | /// - @p weights_peephole_desc, |
2923 | /// - @p bias_desc, |
2924 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
2925 | /// |
2926 | /// This would then indicate that the LSTM forward propagation primitive should |
2927 | /// not use them and should default to zero values instead. |
2928 | /// |
2929 | /// @note |
2930 | /// All memory descriptors can be initialized with #dnnl_format_tag_any or |
2931 | /// with format_kind set to #dnnl_format_kind_any. |
2932 | /// |
2933 | /// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and |
2934 | /// without peephole / recurrent projection layer |
2935 | /// |
2936 | /// @param rnn_desc Output descriptor for LSTM primitive. |
2937 | /// @param prop_kind Propagation kind. Possible values are |
2938 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2939 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
2940 | /// info. |
2941 | /// @param src_layer_desc Memory descriptor for the input vector. |
2942 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
2943 | /// state vector. |
2944 | /// @param src_iter_c_desc Memory descriptor for the input recurrent cell |
2945 | /// state vector. |
2946 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
2947 | /// layer input. |
2948 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
2949 | /// recurrent input. |
2950 | /// @param weights_peephole_desc Memory descriptor for the weights applied to |
2951 | /// the cell states (according to the Peephole LSTM formula). |
2952 | /// @param bias_desc Bias memory descriptor. |
2953 | /// @param dst_layer_desc Memory descriptor for the output vector. |
2954 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
2955 | /// state vector. |
2956 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell |
2957 | /// state vector. |
2958 | /// @param flags Unused. |
2959 | /// @returns #dnnl_success on success and a status describing the error |
2960 | /// otherwise. |
2961 | dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, |
2962 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
2963 | const dnnl_memory_desc_t *src_layer_desc, |
2964 | const dnnl_memory_desc_t *src_iter_desc, |
2965 | const dnnl_memory_desc_t *src_iter_c_desc, |
2966 | const dnnl_memory_desc_t *weights_layer_desc, |
2967 | const dnnl_memory_desc_t *weights_iter_desc, |
2968 | const dnnl_memory_desc_t *weights_peephole_desc, |
2969 | const dnnl_memory_desc_t *bias_desc, |
2970 | const dnnl_memory_desc_t *dst_layer_desc, |
2971 | const dnnl_memory_desc_t *dst_iter_desc, |
2972 | const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags); |
2973 | |
2974 | /// Initializes a descriptor for an LSTM (with or without peephole and with |
2975 | /// or without recurrent projection layer) forward propagation primitive. |
2976 | /// |
2977 | /// The following arguments may either be @c NULL or point to a zero memory |
2978 | /// descriptor: |
2979 | /// - @p src_iter_desc together with @p src_iter_c_desc, |
2980 | /// - @p weights_peephole_desc, |
2981 | /// - @p bias_desc, |
2982 | /// - @p dst_iter_desc together with @p dst_iter_c_desc. |
2983 | /// |
2984 | /// This would then indicate that the LSTM forward propagation primitive should |
2985 | /// not use them and should default to zero values instead. |
2986 | /// |
/// The @p weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM does not have a
/// recurrent projection layer.
2990 | /// |
2991 | /// @note |
2992 | /// All memory descriptors can be initialized with #dnnl_format_tag_any or |
2993 | /// with format_kind set to #dnnl_format_kind_any. |
2994 | /// |
2995 | /// @param rnn_desc Output descriptor for LSTM primitive. |
2996 | /// @param prop_kind Propagation kind. Possible values are |
2997 | /// #dnnl_forward_training and #dnnl_forward_inference. |
2998 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
2999 | /// info. |
3000 | /// @param src_layer_desc Memory descriptor for the input vector. |
3001 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
3002 | /// state vector. |
3003 | /// @param src_iter_c_desc Memory descriptor for the input recurrent cell |
3004 | /// state vector. |
3005 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
3006 | /// layer input. |
3007 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
3008 | /// recurrent input. |
3009 | /// @param weights_peephole_desc Memory descriptor for the weights applied to |
3010 | /// the cell states (according to the Peephole LSTM formula). |
3011 | /// @param weights_projection_desc Memory descriptor for the weights applied to |
3012 | /// the hidden states to get the recurrent projection (according to the |
3013 | /// Projection LSTM formula). |
3014 | /// @param bias_desc Bias memory descriptor. |
3015 | /// @param dst_layer_desc Memory descriptor for the output vector. |
3016 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
3017 | /// state vector. |
3018 | /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell |
3019 | /// state vector. |
3020 | /// @param flags Unused. |
3021 | /// @returns #dnnl_success on success and a status describing the error |
3022 | /// otherwise. |
3023 | dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, |
3024 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
3025 | const dnnl_memory_desc_t *src_layer_desc, |
3026 | const dnnl_memory_desc_t *src_iter_desc, |
3027 | const dnnl_memory_desc_t *src_iter_c_desc, |
3028 | const dnnl_memory_desc_t *weights_layer_desc, |
3029 | const dnnl_memory_desc_t *weights_iter_desc, |
3030 | const dnnl_memory_desc_t *weights_peephole_desc, |
3031 | const dnnl_memory_desc_t *weights_projection_desc, |
3032 | const dnnl_memory_desc_t *bias_desc, |
3033 | const dnnl_memory_desc_t *dst_layer_desc, |
3034 | const dnnl_memory_desc_t *dst_iter_desc, |
3035 | const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags); |
3036 | |
/// Initializes a descriptor for an LSTM backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init_v2 to initialize backward LSTM with and
///     without peephole
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
///     without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3117 | |
/// Initializes a descriptor for an LSTM (with or without peephole) backward
/// propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init to initialize backward LSTM without
///     peephole / recurrent projection layer
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
///     without peephole / recurrent projection layer
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
///     applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3205 | |
/// Initializes a descriptor for an LSTM (with or without peephole and with or
/// without recurrent projection layer) backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// The @p weights_projection_desc together with @p
/// diff_weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM doesn't have
/// recurrent projection layer.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
///     the hidden states to get the recurrent projection (according to the
///     Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
///     applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_weights_projection_desc Memory descriptor for the diff of
///     weights applied to the hidden states to get the recurrent projection
///     (according to the Projection LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *weights_projection_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_weights_projection_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
3303 | |
/// Initializes a descriptor for a GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the GRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3347 | |
/// Initializes a descriptor for a GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3409 | |
/// Initializes a descriptor for an LBR GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR GRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR GRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3449 | |
/// Initializes a descriptor for an LBR GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the LBR GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR GRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3512 | |
/// Initializes a descriptor for an AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the AUGRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for AUGRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_augru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3558 | |
/// Initializes a descriptor for an AUGRU backward propagation primitive.
3560 | /// |
3561 | /// The following arguments may either be @c NULL or point to a zero memory |
3562 | /// descriptor: |
3563 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
3564 | /// - @p bias_desc together with @p diff_bias_desc, |
3565 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
3566 | /// |
3567 | /// This would then indicate that the AUGRU backward propagation primitive |
3568 | /// should not use them and should default to zero values instead. |
3569 | /// |
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
3573 | /// |
3574 | /// @param rnn_desc Output descriptor for AUGRU primitive. |
3575 | /// @param prop_kind Propagation kind. Must be #dnnl_backward. |
3576 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
3577 | /// info. |
3578 | /// @param src_layer_desc Memory descriptor for the input vector. |
3579 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
3580 | /// state vector. |
3581 | /// @param attention_desc Memory descriptor for the attention vector. |
3582 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
3583 | /// layer input. |
3584 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
3585 | /// recurrent input. |
3586 | /// @param bias_desc Bias memory descriptor. |
3587 | /// @param dst_layer_desc Memory descriptor for the output vector. |
3588 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
3589 | /// state vector. |
3590 | /// @param diff_src_layer_desc Memory descriptor for the diff of input vector. |
3591 | /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent |
3592 | /// hidden state vector. |
3593 | /// @param diff_attention_desc Memory descriptor for the diff of attention vector. |
3594 | /// @param diff_weights_layer_desc Memory descriptor for the diff of weights |
3595 | /// applied to the layer input. |
3596 | /// @param diff_weights_iter_desc Memory descriptor for the diff of weights |
3597 | /// applied to the recurrent input. |
3598 | /// @param diff_bias_desc Diff bias memory descriptor. |
3599 | /// @param diff_dst_layer_desc Memory descriptor for the diff of output |
3600 | /// vector. |
3601 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
3602 | /// recurrent hidden state vector. |
3603 | /// @param flags Unused. |
3604 | /// @returns #dnnl_success on success and a status describing the error |
3605 | /// otherwise. |
3606 | dnnl_status_t DNNL_API dnnl_augru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, |
3607 | dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, |
3608 | const dnnl_memory_desc_t *src_layer_desc, |
3609 | const dnnl_memory_desc_t *src_iter_desc, |
3610 | const dnnl_memory_desc_t *attention_desc, |
3611 | const dnnl_memory_desc_t *weights_layer_desc, |
3612 | const dnnl_memory_desc_t *weights_iter_desc, |
3613 | const dnnl_memory_desc_t *bias_desc, |
3614 | const dnnl_memory_desc_t *dst_layer_desc, |
3615 | const dnnl_memory_desc_t *dst_iter_desc, |
3616 | const dnnl_memory_desc_t *diff_src_layer_desc, |
3617 | const dnnl_memory_desc_t *diff_src_iter_desc, |
3618 | const dnnl_memory_desc_t *diff_attention_desc, |
3619 | const dnnl_memory_desc_t *diff_weights_layer_desc, |
3620 | const dnnl_memory_desc_t *diff_weights_iter_desc, |
3621 | const dnnl_memory_desc_t *diff_bias_desc, |
3622 | const dnnl_memory_desc_t *diff_dst_layer_desc, |
3623 | const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags); |
3624 | |
/// Initializes a descriptor for LBR AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR AUGRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
/// All memory descriptors can be initialized with
/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LBR AUGRU primitive.
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
/// info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
/// state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
/// layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
/// recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
/// state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *attention_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3667 | |
3668 | /// Initializes a descriptor for LBR AUGRU backward propagation primitive. |
3669 | /// |
3670 | /// The following arguments may either be @c NULL or point to a zero memory |
3671 | /// descriptor: |
3672 | /// - @p src_iter_desc together with @p diff_src_iter_desc, |
3673 | /// - @p bias_desc together with @p diff_bias_desc, |
3674 | /// - @p dst_iter_desc together with @p diff_dst_iter_desc. |
3675 | /// |
3676 | /// This would then indicate that the LBR AUGRU backward propagation primitive |
3677 | /// should not use them and should default to zero values instead. |
3678 | /// |
3679 | /// @note |
3680 | /// All memory descriptors can be initialized with |
3681 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
3682 | /// |
3683 | /// @param rnn_desc Output descriptor for LBR AUGRU primitive. |
3684 | /// @param prop_kind Propagation kind. Must be #dnnl_backward. |
3685 | /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more |
3686 | /// info. |
3687 | /// @param src_layer_desc Memory descriptor for the input vector. |
3688 | /// @param src_iter_desc Memory descriptor for the input recurrent hidden |
3689 | /// state vector. |
3690 | /// @param attention_desc Memory descriptor for the attention vector. |
3691 | /// @param weights_layer_desc Memory descriptor for the weights applied to the |
3692 | /// layer input. |
3693 | /// @param weights_iter_desc Memory descriptor for the weights applied to the |
3694 | /// recurrent input. |
3695 | /// @param bias_desc Bias memory descriptor. |
3696 | /// @param dst_layer_desc Memory descriptor for the output vector. |
3697 | /// @param dst_iter_desc Memory descriptor for the output recurrent hidden |
3698 | /// state vector. |
3699 | /// @param diff_src_layer_desc Memory descriptor for the diff of input vector. |
3700 | /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent |
3701 | /// hidden state vector. |
3702 | /// @param diff_attention_desc Memory descriptor for the diff of attention vector. |
3703 | /// @param diff_weights_layer_desc Memory descriptor for the diff of weights |
3704 | /// applied to the layer input. |
3705 | /// @param diff_weights_iter_desc Memory descriptor for the diff of weights |
3706 | /// applied to the recurrent input. |
3707 | /// @param diff_bias_desc Diff bias memory descriptor. |
3708 | /// @param diff_dst_layer_desc Memory descriptor for the diff of output |
3709 | /// vector. |
3710 | /// @param diff_dst_iter_desc Memory descriptor for the diff of output |
3711 | /// recurrent hidden state vector. |
3712 | /// @param flags Unused. |
3713 | /// @returns #dnnl_success on success and a status describing the error |
3714 | /// otherwise. |
3715 | dnnl_status_t DNNL_API dnnl_lbr_augru_backward_desc_init( |
3716 | dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, |
3717 | dnnl_rnn_direction_t direction, |
3718 | const dnnl_memory_desc_t *src_layer_desc, |
3719 | const dnnl_memory_desc_t *src_iter_desc, |
3720 | const dnnl_memory_desc_t *attention_desc, |
3721 | const dnnl_memory_desc_t *weights_layer_desc, |
3722 | const dnnl_memory_desc_t *weights_iter_desc, |
3723 | const dnnl_memory_desc_t *bias_desc, |
3724 | const dnnl_memory_desc_t *dst_layer_desc, |
3725 | const dnnl_memory_desc_t *dst_iter_desc, |
3726 | const dnnl_memory_desc_t *diff_src_layer_desc, |
3727 | const dnnl_memory_desc_t *diff_src_iter_desc, |
3728 | const dnnl_memory_desc_t *diff_attention_desc, |
3729 | const dnnl_memory_desc_t *diff_weights_layer_desc, |
3730 | const dnnl_memory_desc_t *diff_weights_iter_desc, |
3731 | const dnnl_memory_desc_t *diff_bias_desc, |
3732 | const dnnl_memory_desc_t *diff_dst_layer_desc, |
3733 | const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags); |
3734 | |
3735 | /// @} dnnl_api_rnn |
3736 | |
3737 | /// @addtogroup dnnl_api_matmul |
3738 | /// @{ |
3739 | |
/// Initializes a matrix multiplication descriptor.
///
/// @param matmul_desc Output descriptor for matmul primitive.
/// @param src_desc Source memory descriptor (matrix A).
/// @param weights_desc Weights memory descriptor (matrix B).
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
/// descriptor, or a memory descriptor with format_kind set to
/// #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor (matrix C).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_desc);
3756 | |
3757 | /// @} dnnl_api_matmul |
3758 | |
3759 | /// @addtogroup dnnl_api_resampling Resampling |
3760 | /// @{ |
3761 | |
3762 | /// Initializes a descriptor for a resampling forward propagation primitive. |
3763 | /// |
3764 | /// @note |
3765 | /// Destination memory descriptor is allowed to be initialized with |
3766 | /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. |
3767 | /// |
3768 | /// |
3769 | /// @param resampling_desc Output descriptor for a resampling primitive. |
3770 | /// @param prop_kind Propagation kind. Possible values are |
3771 | /// #dnnl_forward_training and #dnnl_forward_inference. |
3772 | /// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest, |
3773 | /// or #dnnl_resampling_linear. |
3774 | /// @param factors Array of scaling factors for spatial dimension. |
3775 | /// @param src_desc Source memory descriptor. |
3776 | /// @param dst_desc Destination memory descriptor. |
3777 | /// @returns #dnnl_success on success and a status describing the error |
3778 | /// otherwise. |
3779 | dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init( |
3780 | dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind, |
3781 | dnnl_alg_kind_t alg_kind, const float *factors, |
3782 | const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc); |
3783 | |
/// Initializes a descriptor for resampling backward propagation primitive.
///
/// @param resampling_desc Output descriptor for a resampling primitive.
/// @param alg_kind resampling algorithm kind: either
/// #dnnl_resampling_nearest, or #dnnl_resampling_linear.
/// @param factors Array of scaling factors for spatial dimension.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
///
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(
        dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind,
        const float *factors, const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc);
3799 | |
3800 | /// @} dnnl_api_resampling |
3801 | |
3802 | /// @addtogroup dnnl_api_reduction Reduction |
3803 | /// @{ |
3804 | |
/// Initializes a descriptor for a reduction primitive.
///
/// @note
/// Destination memory descriptor is allowed to be initialized with
/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
///
/// @param desc Output descriptor for a reduction primitive.
/// @param alg_kind reduction algorithm kind. Possible values:
/// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
/// #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
/// #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
/// #dnnl_reduction_norm_lp_power_p_sum.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param p Algorithm specific parameter.
/// @param eps Algorithm specific parameter.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
///
dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, float p, float eps);
3828 | |
3829 | /// @} dnnl_api_reduction |
3830 | |
3831 | /// @} dnnl_api_primitives |
3832 | |
3833 | /// @addtogroup dnnl_api_engine |
3834 | /// @{ |
3835 | |
3836 | /// Returns the number of engines of a particular kind. |
3837 | /// |
3838 | /// @param kind Kind of engines to count. |
3839 | /// @returns Count of the engines. |
3840 | size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind); |
3841 | |
3842 | /// Creates an engine. |
3843 | /// |
3844 | /// @param engine Output engine. |
3845 | /// @param kind Engine kind. |
3846 | /// @param index Engine index that should be between 0 and the count of |
3847 | /// engines of the requested kind. |
3848 | /// @returns #dnnl_success on success and a status describing the error |
3849 | /// otherwise. |
3850 | dnnl_status_t DNNL_API dnnl_engine_create( |
3851 | dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index); |
3852 | |
3853 | /// Returns the kind of an engine. |
3854 | /// |
3855 | /// @param engine Engine to query. |
3856 | /// @param kind Output engine kind. |
3857 | /// @returns #dnnl_success on success and a status describing the error |
3858 | /// otherwise. |
3859 | dnnl_status_t DNNL_API dnnl_engine_get_kind( |
3860 | dnnl_engine_t engine, dnnl_engine_kind_t *kind); |
3861 | |
3862 | /// Destroys an engine. |
3863 | /// |
3864 | /// @param engine Engine to destroy. |
3865 | /// @returns #dnnl_success on success and a status describing the error |
3866 | /// otherwise. |
3867 | dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine); |
3868 | |
3869 | /// @} dnnl_api_engine |
3870 | |
3871 | /// @addtogroup dnnl_api_stream |
3872 | /// @{ |
3873 | |
3874 | /// Creates an execution stream. |
3875 | /// |
3876 | /// @param stream Output execution stream. |
3877 | /// @param engine Engine to create the execution stream on. |
3878 | /// @param flags Stream behavior flags (@sa dnnl_stream_flags_t). |
3879 | /// @returns #dnnl_success on success and a status describing the error |
3880 | /// otherwise. |
3881 | dnnl_status_t DNNL_API dnnl_stream_create( |
3882 | dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags); |
3883 | |
3884 | /// Returns the engine of a stream object. |
3885 | /// |
3886 | /// @param stream Stream object. |
3887 | /// @param engine Output engine on which the stream is created. |
3888 | /// @returns #dnnl_success on success and a status describing the error |
3889 | /// otherwise. |
3890 | dnnl_status_t DNNL_API dnnl_stream_get_engine( |
3891 | const_dnnl_stream_t stream, dnnl_engine_t *engine); |
3892 | |
3893 | /// Waits for all primitives in the execution stream to finish computations. |
3894 | /// |
3895 | /// @param stream Execution stream. |
3896 | /// @returns #dnnl_success on success and a status describing the error |
3897 | /// otherwise. |
3898 | dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream); |
3899 | |
3900 | /// Destroys an execution stream. |
3901 | /// |
3902 | /// @param stream Execution stream to destroy. |
3903 | /// @returns #dnnl_success on success and a status describing the error |
3904 | /// otherwise. |
3905 | dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream); |
3906 | |
3907 | /// @} dnnl_api_stream |
3908 | |
3909 | /// @addtogroup dnnl_api_primitive_cache |
3910 | /// @{ |
3911 | |
3912 | /// Returns the number of primitives that can be held in the primitive cache |
3913 | /// at the same time. |
3914 | /// |
3915 | /// @param capacity Primitive cache capacity to query. Concurrently |
3916 | /// accessing @p capacity is safe. |
3917 | /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the |
3918 | /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on |
3919 | /// success. |
3920 | dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity); |
3921 | |
3922 | /// Sets a number of primitives that can be held in the primitive cache |
3923 | /// at a time. |
3924 | /// |
3925 | /// @param capacity Primitive cache capacity to set. If a new @p capacity is |
3926 | /// less than a number of primitives that the primitive cache already has |
3927 | /// then the excess entries will be evicted. Setting the @p capacity to 0 |
3928 | /// clears the primitive cache and disables it. Concurrently modifying |
3929 | /// @p capacity is safe. |
3930 | /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the |
3931 | /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on |
3932 | /// success. |
3933 | dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity); |
3934 | |
3935 | /// @} dnnl_api_primitive_cache |
3936 | |
3937 | /// @addtogroup dnnl_api_mathmode Floating-point Math Mode |
3938 | /// @{ |
3939 | |
3940 | /// Returns the floating-point math mode that will be used by default |
3941 | /// for all subsequently created primitives. |
3942 | /// |
3943 | /// @param mode Output FP math mode. |
3944 | /// @returns #dnnl_success on success and a status describing the error |
3945 | /// otherwise. |
3946 | dnnl_status_t DNNL_API dnnl_get_default_fpmath_mode(dnnl_fpmath_mode_t *mode); |
3947 | |
3948 | /// Sets the floating-point math mode that will be used by default |
3949 | /// for all subsequently created primitives. |
3950 | /// |
3951 | /// @param mode FP math mode. The possible values are: |
3952 | /// #dnnl_fpmath_mode_strict, |
3953 | /// #dnnl_fpmath_mode_bf16, |
3954 | /// #dnnl_fpmath_mode_f16, |
3955 | /// #dnnl_fpmath_mode_tf32, |
3956 | /// #dnnl_fpmath_mode_any. |
3957 | /// @returns #dnnl_success on success and a status describing the error |
3958 | /// otherwise. |
3959 | dnnl_status_t DNNL_API dnnl_set_default_fpmath_mode(dnnl_fpmath_mode_t mode); |
3960 | |
3961 | /// @} dnnl_api_mathmode |
3962 | |
3963 | /// @addtogroup dnnl_api_service |
3964 | /// @{ |
3965 | |
3966 | /// Configures verbose output to stdout. |
3967 | /// |
3968 | /// @note |
3969 | /// Enabling verbose output affects performance. |
3970 | /// This setting overrides the ONEDNN_VERBOSE environment variable. |
3971 | /// |
3972 | /// @param level Verbosity level: |
3973 | /// - 0: no verbose output (default), |
3974 | /// - 1: primitive information at execution, |
3975 | /// - 2: primitive information at creation and execution. |
3976 | /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the |
3977 | /// @p level value is invalid, and #dnnl_success/#dnnl::status::success on |
3978 | /// success. |
3979 | dnnl_status_t DNNL_API dnnl_set_verbose(int level); |
3980 | |
/// Configures dumping of JIT-generated code.
///
/// @note
/// This setting overrides the DNNL_JIT_DUMP environment variable.
///
/// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
/// @p enable value is invalid, and #dnnl_success/#dnnl::status::success on
/// success.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
3991 | |
3992 | /// Returns library version information. |
3993 | /// @returns Pointer to a constant structure containing |
3994 | /// - major: major version number, |
3995 | /// - minor: minor version number, |
3996 | /// - patch: patch release number, |
3997 | /// - hash: git commit hash. |
3998 | const dnnl_version_t DNNL_API *dnnl_version(void); |
3999 | |
4000 | /// Sets library profiling flags. The flags define which profilers are |
4001 | /// supported. |
4002 | /// |
4003 | /// @note |
4004 | /// This setting overrides DNNL_JIT_PROFILE environment variable. |
4005 | /// |
4006 | /// @sa @ref dev_guide_profilers |
4007 | /// |
4008 | /// @param flags Profiling flags that can contain the following bits: |
4009 | /// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Amplifier |
4010 | /// (on by default) |
4011 | /// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific |
4012 | /// jit-pid.dump output (off by default). The location of the output |
4013 | /// is controlled via JITDUMPDIR environment variable or via |
4014 | /// dnnl_set_jit_profiling_jitdumpdir() function. |
4015 | /// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific |
4016 | /// perf-pid.map output (off by default). The output is always placed |
4017 | /// into /tmp. |
4018 | /// |
4019 | /// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely. |
4020 | /// |
4021 | /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the |
4022 | /// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on |
4023 | /// success. |
4024 | dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags); |
4025 | |
4026 | /// Sets JIT dump output path. Only applicable to Linux and is only |
4027 | /// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set. |
4028 | /// |
4029 | /// After the first JIT kernel is generated, the jitdump output will be placed |
4030 | /// into temporary directory created using the mkdtemp template |
4031 | /// 'dir/.debug/jit/dnnl.XXXXXX'. |
4032 | /// |
4033 | /// @sa @ref dev_guide_profilers |
4034 | /// |
4035 | /// @note |
4036 | /// This setting overrides JITDUMPDIR environment variable. If |
4037 | /// JITDUMPDIR is not set, and this function is never called, the path |
4038 | /// defaults to HOME. Passing NULL reverts the value to default. |
4039 | /// |
4040 | /// @note |
4041 | /// The directory is accessed only when the first JIT kernel is being |
4042 | /// created. JIT profiling will be disabled in case of any errors |
4043 | /// accessing or creating this directory. |
4044 | /// |
4045 | /// @param dir JIT dump output path. |
4046 | /// @returns #dnnl_success/#dnnl::status::success if the |
4047 | /// output directory was set correctly and an error status otherwise. |
4048 | /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows. |
4049 | dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir); |
4050 | |
4051 | /// Sets the maximal ISA the library can dispatch to on the CPU. See |
4052 | /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by |
4053 | /// the C and C++ API functions respectively. |
4054 | /// |
4055 | /// This function has effect only once, and returns an error on subsequent |
4056 | /// calls. It should also be invoked before any other oneDNN API call, otherwise |
4057 | /// it may return an error. |
4058 | /// |
4059 | /// This function overrides the DNNL_MAX_CPU_ISA environment variable. The |
4060 | /// environment variable can be set to the desired maximal ISA name in upper |
4061 | /// case and with dnnl_cpu_isa prefix removed. For example: |
4062 | /// `DNNL_MAX_CPU_ISA=AVX2`. |
4063 | /// |
4064 | /// @note |
4065 | /// The ISAs are only partially ordered: |
4066 | /// - SSE41 < AVX < AVX2, |
4067 | /// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16 |
4068 | /// < AVX512_CORE_FP16 < AVX512_CORE_AMX, |
4069 | /// - AVX2 < AVX2_VNNI. |
4070 | /// |
4071 | /// @sa @ref dev_guide_cpu_dispatcher_control for more details |
4072 | /// |
4073 | /// @param isa Maximal ISA the library should dispatch to. Pass |
4074 | /// #dnnl_cpu_isa_all/#dnnl::cpu_isa::all to remove ISA restrictions |
4075 | /// (except for ISAs with initial support in the library). |
4076 | /// @returns #dnnl_success/#dnnl::status::success on success and a |
4077 | /// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa |
4078 | /// parameter is invalid or the ISA cannot be changed at this time. |
4079 | /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature |
4080 | /// was disabled at build time (see @ref dev_guide_build_options for more |
4081 | /// details). |
4082 | dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa); |
4083 | |
4084 | /// Gets the maximal ISA the library can dispatch to on the CPU. See |
4085 | /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by |
4086 | /// the C and C++ API functions respectively. |
4087 | /// |
4088 | /// @sa @ref dev_guide_cpu_dispatcher_control for more details |
4089 | /// |
4090 | /// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may |
4091 | /// dispatch to. |
4092 | dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void); |
4093 | |
4094 | /// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and |
4095 | /// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++ |
4096 | /// API functions respectively. |
4097 | /// |
4098 | /// This function has effect only once, and returns an error on subsequent |
4099 | /// calls. It should also be invoked before any other oneDNN API call, otherwise |
4100 | /// it may return an error. |
4101 | /// |
4102 | /// This function overrides the DNNL_CPU_ISA_HINTS environment variable. |
4103 | /// @sa @ref dev_guide_cpu_isa_hints for more details |
4104 | /// |
4105 | /// @param isa_hints CPU ISA hints to be passed over to the implementation. |
4106 | /// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use |
4107 | /// default features i.e. no hints. |
4108 | /// @returns #dnnl_success/#dnnl::status::success on success and a |
4109 | /// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot |
4110 | /// be specified at the current time. |
4111 | /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature |
4112 | /// was disabled at build time (see @ref dev_guide_build_options for more |
4113 | /// details). |
4114 | dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints); |
4115 | |
4116 | /// Gets the ISA specific hints that library can follow. See |
4117 | /// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values |
4118 | /// returned by the C and C++ API functions respectively. |
4119 | /// |
4120 | /// @sa @ref dev_guide_cpu_isa_hints for more details |
4121 | /// |
4122 | /// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the |
4123 | /// library can follow. |
4124 | dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void); |
4125 | |
4126 | /// @} dnnl_api_service |
4127 | |
4128 | /// @addtogroup dnnl_api_blas |
4129 | /// @{ |
4130 | |
4131 | /// Performs single-precision matrix-matrix multiply. |
4132 | /// |
4133 | /// The operation is defined as: |
4134 | /// |
4135 | /// `C := alpha * op( A ) * op( B ) + beta * C` |
4136 | /// |
4137 | /// where |
4138 | /// - `op( X ) = X` or `op( X ) = X**T`, |
4139 | /// - `alpha` and `beta` are scalars, and |
4140 | /// - `A`, `B`, and `C` are matrices: |
4141 | /// - `op( A )` is an `MxK` matrix, |
4142 | /// - `op( B )` is an `KxN` matrix, |
4143 | /// - `C` is an `MxN` matrix. |
4144 | /// |
4145 | /// The matrices are assumed to be stored in row-major order (the elements in |
4146 | /// each of the matrix rows are contiguous in memory). |
4147 | /// |
4148 | /// @note |
4149 | /// This API does not support XERBLA. Instead, unlike the standard BLAS |
4150 | /// functions, this one returns a dnnl_status_t value to allow error |
4151 | /// handling. |
4152 | /// |
4153 | /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not |
4154 | /// transposed, and 'T' or 't' means that A is transposed. |
4155 | /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not |
4156 | /// transposed, and 'T' or 't' means that B is transposed. |
4157 | /// @param M The M dimension. |
4158 | /// @param N The N dimension. |
4159 | /// @param K The K dimension. |
4160 | /// @param alpha The alpha parameter that is used to scale the product of |
4161 | /// matrices A and B. |
4162 | /// @param A A pointer to the A matrix data. |
4163 | /// @param lda The leading dimension for the matrix A. |
4164 | /// @param B A pointer to the B matrix data. |
4165 | /// @param ldb The leading dimension for the matrix B. |
4166 | /// @param beta The beta parameter that is used to scale the matrix C. |
4167 | /// @param C A pointer to the C matrix data. |
4168 | /// @param ldc The leading dimension for the matrix C. |
4169 | /// @returns #dnnl_success/#dnnl::status::success on success and a status |
4170 | /// describing the error otherwise. |
4171 | dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, |
4172 | dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, |
4173 | const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc); |
4174 | |
/// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
/// signed matrix B, and 32-bit signed resulting matrix C.
///
/// The operation is defined as:
///
/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
///
/// where
/// - `op( X ) = X` or `op( X ) = X**T`,
/// - `alpha` and `beta` are scalars, and
/// - `A`, `B`, and `C` are matrices:
///   - `op( A )` is an `MxK` matrix,
///   - `op( B )` is a `KxN` matrix,
///   - `C` is an `MxN` matrix.
/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
///   - if `offsetc = F`: the `len` must be at least `1`,
///   - if `offsetc = C`: the `len` must be at least `max(1, m)`,
///   - if `offsetc = R`: the `len` must be at least `max(1, n)`.
///
/// The matrices are assumed to be stored in row-major order (the elements in
/// each of the matrix rows are contiguous in memory).
///
/// @note
///     This API does not support XERBLA. Instead, unlike the standard BLAS
///     functions, this one returns a dnnl_status_t value to allow error
///     handling.
///
/// @warning
///     On some architectures saturation may happen during intermediate
///     computations, which would lead to unexpected results. For more
///     details, refer to @ref dev_guide_int8_computations.
///
/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
///     transposed, and 'T' or 't' means that A is transposed.
/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
///     transposed, and 'T' or 't' means that B is transposed.
/// @param offsetc Flag specifying how offsets should be applied to matrix C:
///     - 'F' means that the same offset will be applied to each element of
///         the matrix C,
///     - 'C' means that individual offset will be applied to each element
///         within each column,
///     - 'R' means that individual offset will be applied to each element
///         within each row.
/// @param M The M dimension.
/// @param N The N dimension.
/// @param K The K dimension.
/// @param alpha The alpha parameter that is used to scale the product of
///     matrices A and B.
/// @param A A pointer to the A matrix data.
/// @param lda The leading dimension for the matrix A.
/// @param ao The offset value for the matrix A.
/// @param B A pointer to the B matrix data.
/// @param ldb The leading dimension for the matrix B.
/// @param bo The offset value for the matrix B.
/// @param beta The beta parameter that is used to scale the matrix C.
/// @param C A pointer to the C matrix data.
/// @param ldc The leading dimension for the matrix C.
/// @param co An array of offset values for the matrix C. The number of
///     elements in the array depends on the value of @p offsetc.
/// @returns #dnnl_success/#dnnl::status::success on success and a status
///     describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
        dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
        dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
4242 | |
/// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
/// signed matrix B, and 32-bit signed resulting matrix C.
///
/// The operation is defined as:
///
/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
///
/// where
/// - `op( X ) = X` or `op( X ) = X**T`,
/// - `alpha` and `beta` are scalars, and
/// - `A`, `B`, and `C` are matrices:
///   - `op( A )` is an `MxK` matrix,
///   - `op( B )` is a `KxN` matrix,
///   - `C` is an `MxN` matrix.
/// - `A_offset` is an `MxK` matrix with every element equal to the `ao` value,
/// - `B_offset` is a `KxN` matrix with every element equal to the `bo` value,
/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
///   - if `offsetc = F`: the `len` must be at least `1`,
///   - if `offsetc = C`: the `len` must be at least `max(1, m)`,
///   - if `offsetc = R`: the `len` must be at least `max(1, n)`.
///
/// The matrices are assumed to be stored in row-major order (the elements in
/// each of the matrix rows are contiguous in memory).
///
/// @note
///     This API does not support XERBLA. Instead, unlike the standard BLAS
///     functions, this one returns a dnnl_status_t value to allow error
///     handling.
///
/// @warning
///     On some architectures saturation may happen during intermediate
///     computations, which would lead to unexpected results. For more
///     details, refer to @ref dev_guide_int8_computations.
///
/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
///     transposed, and 'T' or 't' means that A is transposed.
/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
///     transposed, and 'T' or 't' means that B is transposed.
/// @param offsetc Flag specifying how offsets should be applied to matrix C:
///     - 'F' means that the same offset will be applied to each element of
///         the matrix C,
///     - 'C' means that individual offset will be applied to each element
///         within each column,
///     - 'R' means that individual offset will be applied to each element
///         within each row.
/// @param M The M dimension.
/// @param N The N dimension.
/// @param K The K dimension.
/// @param alpha The alpha parameter that is used to scale the product of
///     matrices A and B.
/// @param A A pointer to the A matrix data.
/// @param lda The leading dimension for the matrix A.
/// @param ao The offset value for the matrix A.
/// @param B A pointer to the B matrix data.
/// @param ldb The leading dimension for the matrix B.
/// @param bo The offset value for the matrix B.
/// @param beta The beta parameter that is used to scale the matrix C.
/// @param C A pointer to the C matrix data.
/// @param ldc The leading dimension for the matrix C.
/// @param co An array of offset values for the matrix C. The number of
///     elements in the array depends on the value of @p offsetc.
/// @returns #dnnl_success/#dnnl::status::success on success and a status
///     describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
        dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
        dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
4310 | |
4311 | /// @} dnnl_api_blas |
4312 | |
4313 | /// @} dnnl_api |
4314 | |
4315 | #ifdef __cplusplus |
4316 | } |
4317 | #endif |
4318 | |
4319 | #endif /* ONEAPI_DNNL_DNNL_H */ |
4320 | |