/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "primitive_attr.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

using namespace dnnl::impl;
using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;

namespace dnnl {
namespace impl {

const primitive_attr_t &default_attr() {
    static const primitive_attr_t default_attr_instance;
    return default_attr_instance;
}

status_t scales_t::set(dim_t count, int mask, const float *scales) {
    cleanup();

    count_ = count;
    mask_ = mask;

    if (is_runtime_value(*scales)) {
        scales_ = scales_buf_;
        scales_[0] = *scales;
    } else if (count_ == 1) {
        scales_ = scales_buf_;
        utils::array_set(scales_, scales[0], scales_buf_size);
    } else {
        scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
        if (scales_ == nullptr) return status::out_of_memory;

        for (dim_t c = 0; c < count_; ++c)
            scales_[c] = scales[c];
    }

    return status::success;
}
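
// The set() above uses three storage strategies: a runtime placeholder or a
// single broadcast value lives in the small inline buffer, while a true
// per-element array of `count` scales is copied into 64-byte-aligned heap
// memory. A minimal sketch (hypothetical values, for illustration only):
//
//     scales_t s;
//     const float per_oc[] = {0.5f, 0.25f, 0.125f};
//     s.set(3, 1 << 1, per_oc); // mask bit 1 set: scales vary along dim 1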

status_t zero_points_t::get(int arg, int *mask) const {
    if (mask) *mask = get_mask(arg);
    return status::success;
}

status_t zero_points_t::set(int arg, int mask) {
    const bool supported_arg
            = utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST);
    if (!supported_arg) return status::unimplemented;

    switch (arg) {
        case DNNL_ARG_SRC:
            is_set_src = true;
            mask_src = mask;
            break;
        case DNNL_ARG_WEIGHTS:
            is_set_wei = true;
            mask_wei = mask;
            break;
        case DNNL_ARG_DST:
            is_set_dst = true;
            mask_dst = mask;
            break;
    }
    return status::success;
}
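
// Only DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, and DNNL_ARG_DST accept zero points;
// any other argument returns status::unimplemented. Illustrative use via the
// public C API wrapper defined later in this file (hypothetical arguments):
//
//     // a single zero point shared by the whole src tensor (mask == 0)
//     dnnl_primitive_attr_set_zero_points_mask(attr, DNNL_ARG_SRC, 0);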

} // namespace impl
} // namespace dnnl

bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
        dnnl::impl::data_type_t dst_dt) const {
    using smask_t = skip_mask_t;
    // Prepare the mask for the runtime-parameters check.
    smask_t defined_mask = smask_t::none;
    if ((mask & smask_t::oscale_runtime) == smask_t::oscale_runtime)
        defined_mask |= smask_t::oscale;
    if ((mask & smask_t::scales_runtime) == smask_t::scales_runtime)
        defined_mask |= smask_t::scales;
    if ((mask & smask_t::zero_points_runtime) == smask_t::zero_points_runtime)
        defined_mask |= smask_t::zero_points;
    bool ok = true;

#define CHECK_ARG(x) ok = ok && (x)
#define CHECK_MASK(mask_name, mask_field) \
    CHECK_ARG(IMPLICATION( \
            (bool)(~mask & (mask_name)), (mask_field).has_default_values()))
    CHECK_MASK(smask_t::oscale_runtime, output_scales_);
    CHECK_MASK(smask_t::scales, scales_);
    CHECK_MASK(smask_t::zero_points, zero_points_);
    CHECK_MASK(smask_t::post_ops, post_ops_);
    CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_);
    CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
    CHECK_MASK(smask_t::rnn_weights_projection_qparams,
            rnn_weights_projection_qparams_);
    CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::sum_dt),
            post_ops_.sum_with_default_dt(dst_dt)));
    bool gpu_attr_ok = IMPLICATION((bool)(~mask & smask_t::gpu_attr),
            !gpu_attr_ || gpu_attr_->has_default_values());
    CHECK_ARG(gpu_attr_ok);
    CHECK_ARG(this->defined(defined_mask));
    return ok;
#undef CHECK_MASK
#undef CHECK_ARG
}
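
// In has_default_values(), a bit set in `mask` marks an attribute the caller
// explicitly supports, so that check is skipped (note the `~mask` in
// CHECK_MASK). A hypothetical sketch of how an implementation might use it
// (identifiers illustrative only):
//
//     using smask_t = primitive_attr_t::skip_mask_t;
//     // accept post-ops and scales; everything else must be default
//     if (!attr->has_default_values(smask_t::post_ops | smask_t::scales))
//         return status::unimplemented;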

bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
    using smask_t = skip_mask_t;
    bool ok = true;
#define CHECK_ARG(x) ok = ok && (x)
#define CHECK_MASK(mask_name, mask_field) \
    CHECK_ARG(IMPLICATION((bool)(~mask & (mask_name)), (mask_field).defined()))
    CHECK_MASK(smask_t::oscale, output_scales_);
    CHECK_MASK(smask_t::scales, scales_);
    CHECK_MASK(smask_t::zero_points, zero_points_);
    CHECK_MASK(smask_t::post_ops, post_ops_);
    CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_);
    CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
    CHECK_MASK(smask_t::rnn_weights_projection_qparams,
            rnn_weights_projection_qparams_);
    return ok;
#undef CHECK_MASK
#undef CHECK_ARG
}

status_t post_ops_t::append_sum(
        float scale, int32_t zero_point, data_type_t dt) {
    if (len() == post_ops_limit) return out_of_memory;
    entry_.emplace_back();
    auto &e = entry_.back();
    e.kind = primitive_kind::sum;
    e.sum.scale = scale;
    e.sum.zero_point = zero_point;
    e.sum.dt = dt;
    return success;
}
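
// Illustrative use through the C API wrapper defined later in this file
// (hypothetical values): accumulate into dst with a 0.5 scale, no zero
// point, and the sum data type deduced from dst:
//
//     dnnl_post_ops_append_sum(post_ops, 0.5f, 0, dnnl_data_type_undef);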

status_t post_ops_t::append_eltwise(
        float scale, alg_kind_t alg, float alpha, float beta) {
    if (len() == post_ops_limit) return out_of_memory;
    if (!math::is_eltwise_ok(data_type::f32, alg, alpha, beta))
        return invalid_arguments;

    entry_.emplace_back();
    auto &e = entry_.back();
    e.kind = primitive_kind::eltwise;
    e.eltwise.scale = scale;
    e.eltwise.alg = alg;
    e.eltwise.alpha = alpha;
    e.eltwise.beta = beta;
    return success;
}

status_t post_ops_t::append_dw(data_type_t wei_dt, data_type_t bias_dt,
        data_type_t dst_dt, dim_t kernel_size, dim_t stride_size,
        dim_t padding_l_size) {
    if (len() == post_ops_limit) return out_of_memory;

    bool ok = wei_dt != data_type::undef && dst_dt != data_type::undef;
    if (!ok) return invalid_arguments;

    ok = kernel_size > 0 && stride_size > 0;
    if (!ok) return invalid_arguments;

    // Avoid cases where the kernel would be applied entirely inside the
    // left padding area.
    ok = (padding_l_size + 1) <= kernel_size;
    if (!ok) return invalid_arguments;

    entry_.emplace_back();
    auto &e = entry_.back();
    e.kind = primitive_kind::convolution;
    auto &d = e.depthwise_conv;
    d.kernel = kernel_size;
    d.stride = stride_size;
    d.padding = padding_l_size;
    d.wei_dt = wei_dt;
    d.bias_dt = bias_dt;
    d.dst_dt = dst_dt;

    return success;
}
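
// Illustrative use (hypothetical parameters): fuse a 3x3, stride-2 depthwise
// convolution with s8 weights, f32 bias, and u8 destination; padding_l = 1
// satisfies the (padding_l_size + 1) <= kernel_size check above:
//
//     dnnl_post_ops_append_dw(post_ops, dnnl_s8, dnnl_f32, dnnl_u8,
//             /*kernel*/ 3, /*stride*/ 2, /*padding_l*/ 1);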

status_t post_ops_t::validate_binary(
        alg_kind_t alg, const memory_desc_t *user_src1_desc) const {

    if (len() == post_ops_limit) return out_of_memory;
    using namespace alg_kind;
    bool alg_ok = one_of(alg, binary_add, binary_mul, binary_max, binary_min,
            binary_div, binary_sub, binary_ge, binary_gt, binary_le, binary_lt,
            binary_eq, binary_ne);
    if (!alg_ok) return invalid_arguments;
    if (!memory_desc_sanity_check(*user_src1_desc)) return invalid_arguments;

    // Additional check to restrict run-time dimension usage until supported.
    for (int d = 0; d < user_src1_desc->ndims; ++d) {
        if (user_src1_desc->dims[d] == DNNL_RUNTIME_DIM_VAL)
            return invalid_arguments;
    }

    return success;
}

status_t post_ops_t::append_binary(
        alg_kind_t alg, const memory_desc_t *user_src1_desc) {
    auto status = validate_binary(alg, user_src1_desc);
    if (status != success) return status;

    entry_.emplace_back();
    auto &e = entry_.back();
    e.kind = primitive_kind::binary;
    e.binary.alg = alg;
    e.binary.user_src1_desc = *user_src1_desc;
    e.binary.src1_desc = *user_src1_desc;
    return success;
}
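
// Illustrative use (hypothetical descriptor): element-wise-add a second
// tensor after the main operation. The descriptor must pass the sanity check
// and contain no DNNL_RUNTIME_DIM_VAL dimensions (see validate_binary()):
//
//     dnnl_post_ops_append_binary(post_ops, dnnl_binary_add, src1_md);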

status_t post_ops_t::prepend_binary(
        alg_kind_t alg, const memory_desc_t *user_src1_desc) {
    auto status = validate_binary(alg, user_src1_desc);
    if (status != success) return status;

    entry_.emplace(entry_.begin());
    auto &e = entry_[0];
    e.kind = primitive_kind::binary;
    e.binary.alg = alg;
    e.binary.user_src1_desc = *user_src1_desc;
    e.binary.src1_desc = *user_src1_desc;
    return success;
}

status_t post_ops_t::append_prelu(int mask) {
    if (len() == post_ops_limit) return out_of_memory;

    auto it_entry = entry_.emplace(entry_.end());
    it_entry->kind = primitive_kind::prelu;
    it_entry->prelu.mask = mask;

    return success;
}

bool post_ops_t::defined() const {
    for (int idx = 0; idx < len(); ++idx) {
        auto kind = entry_[idx].kind;
        if (kind == primitive_kind::sum) {
            if (is_runtime_value(entry_[idx].sum.scale)) return false;
        } else if (kind == primitive_kind::eltwise) {
            const auto &e = entry_[idx].eltwise;
            if (is_runtime_value(e.scale) || is_runtime_value(e.alpha)
                    || is_runtime_value(e.beta))
                return false;
        } else if (utils::one_of(kind, primitive_kind::binary,
                           primitive_kind::prelu,
                           primitive_kind::convolution)) {
            // binary, prelu, and depthwise convolution entries carry no
            // runtime values, so they are always defined
        } else {
            assert(!"unreachable");
        }
    }
    return true;
}

status_t post_ops_t::set_default_formats(const memory_desc_t *dst_md) {
    for (int idx = 0; idx < len(); ++idx) {
        if (!contain(primitive_kind::binary, idx)) continue;

        auto &src1_md = entry_[idx].binary.src1_desc;
        const memory_desc_wrapper src1_mdw(src1_md);
        if (!src1_mdw.format_any()) continue;

        const memory_desc_wrapper dst_mdw(dst_md);
        assert(!dst_mdw.format_any());

        // Effectively 1D tensors should get a plain abx layout; everything
        // else inherits dst's blocked layout.
        if (src1_mdw.count_non_unit_dims(1))
            CHECK(memory_desc_init_by_strides(src1_md, nullptr));
        else
            CHECK(memory_desc_init_by_blocking_desc(
                    src1_md, dst_mdw.blocking_desc()));
    }

    return status::success;
}
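
// Illustrative effect (hypothetical shapes): with a blocked {N, C, H, W} dst
// and a binary src1 of shape {1, C, 1, 1} created with format_tag::any, src1
// resolves to a plain strided layout because only one of its dims is
// non-unit; a full {N, C, H, W} src1 would instead inherit dst's blocking
// descriptor.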

bool post_ops_t::check_sum_consistent_dt(
        const data_type_t dst_dt, const bool diverse_sum_dt_allowed) const {
    int sum_ind = find(dnnl::impl::primitive_kind::sum);
    if (sum_ind == -1) return true;
    const auto sum_dt = entry_[sum_ind].sum.dt;

    // sum dt and dst dt must have the same size
    const bool compatible_dt_size = IMPLICATION(
            !utils::one_of(dnnl_data_type_undef, sum_dt, dst_dt),
            types::data_type_size(dst_dt) == types::data_type_size(sum_dt));
    if (!compatible_dt_size) return false;
    if (diverse_sum_dt_allowed) return true;

    bool ok = true;
    while ((sum_ind = find(dnnl::impl::primitive_kind::sum, sum_ind + 1)) != -1)
        ok = ok && entry_[sum_ind].sum.dt == sum_dt;
    return ok;
}
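
// Example of the size rule above (values illustrative): dst_dt == f32 with
// sum_dt == s8 fails, since data_type_size() yields 4 vs. 1 byte, whereas
// dst_dt == f32 with sum_dt == s32 passes (4 bytes each). When
// diverse_sum_dt_allowed is false, all sum entries must also match the data
// type of the first sum entry.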

status_t primitive_attr_t::set_fpmath_mode(fpmath_mode_t fpmath_mode) {
    auto st = check_fpmath_mode(fpmath_mode);
    if (st == success) fpmath_mode_ = fpmath_mode;
    return st;
}

status_t primitive_attr_t::set_scratchpad_mode(
        scratchpad_mode_t scratchpad_mode) {
    const bool ok = one_of(
            scratchpad_mode, scratchpad_mode::library, scratchpad_mode::user);
    if (!ok) return invalid_arguments;

    scratchpad_mode_ = scratchpad_mode;
    return success;
}

status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
    return post_ops_.copy_from(post_ops);
}

status_t primitive_attr_t::set_default_formats(const memory_desc_t *dst_md) {
    return post_ops_.set_default_formats(dst_md);
}

status_t primitive_attr_t::set_gpu_attr(const primitive_attr_item_t &gpu_attr) {
    gpu_attr_ = gpu_attr.clone();
    return status::success;
}

/* Public C API */

status_t dnnl_primitive_attr_create(primitive_attr_t **attr) {
    if (attr == nullptr) return invalid_arguments;

    return safe_ptr_assign(*attr, new dnnl_primitive_attr);
}
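
// Sketch of the attribute lifecycle through this C API (error handling
// omitted; values illustrative):
//
//     dnnl_primitive_attr_t attr;
//     dnnl_primitive_attr_create(&attr);
//     dnnl_primitive_attr_set_scratchpad_mode(attr, dnnl_scratchpad_mode_user);
//     // ... pass attr when creating a primitive descriptor ...
//     dnnl_primitive_attr_destroy(attr);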

status_t dnnl_primitive_attr_clone(
        primitive_attr_t **attr, const primitive_attr_t *existing_attr) {
    if (any_null(attr, existing_attr)) return invalid_arguments;

    auto new_attr = utils::make_unique<primitive_attr_t>(*existing_attr);
    if (!new_attr->is_initialized()) return out_of_memory;

    return safe_ptr_assign(*attr, new_attr.release());
}

status_t dnnl_primitive_attr_destroy(primitive_attr_t *attr) {
    delete attr;

    return success;
}

status_t dnnl_primitive_attr_get_fpmath_mode(
        const primitive_attr_t *attr, fpmath_mode_t *mode) {
    if (any_null(attr, mode)) return invalid_arguments;
    *mode = attr->fpmath_mode_;
    return success;
}

status_t dnnl_primitive_attr_set_fpmath_mode(
        primitive_attr_t *attr, fpmath_mode_t mode) {
    if (any_null(attr)) return invalid_arguments;
    return attr->set_fpmath_mode(mode);
}

status_t dnnl_primitive_attr_get_scratchpad_mode(
        const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
    if (any_null(attr, scratchpad_mode)) return invalid_arguments;

    *scratchpad_mode = attr->scratchpad_mode_;

    return success;
}

status_t dnnl_primitive_attr_set_scratchpad_mode(
        primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
    if (any_null(attr)) return invalid_arguments;

    return attr->set_scratchpad_mode(scratchpad_mode);
}

status_t dnnl_primitive_attr_set_scales_mask(
        primitive_attr_t *attr, int arg, int mask) {
    bool ok = attr && mask >= 0 && arg >= 0
            && attr->output_scales_.has_default_values();
    if (!ok) return invalid_arguments;
    return attr->scales_.set(arg, mask);
}
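
// Mask semantics: bit i set in `mask` requests a distinct scale along tensor
// dimension i. Illustrative calls (hypothetical arguments):
//
//     // one common scale for the whole dst tensor
//     dnnl_primitive_attr_set_scales_mask(attr, DNNL_ARG_DST, 0);
//     // per-output-channel weights scales (dim 0 of non-grouped weights)
//     dnnl_primitive_attr_set_scales_mask(attr, DNNL_ARG_WEIGHTS, 1 << 0);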

status_t dnnl_primitive_attr_set_zero_points_mask(
        primitive_attr_t *attr, int arg, int mask) {
    bool ok = attr && mask >= 0;
    if (!ok) return invalid_arguments;

    return attr->zero_points_.set(arg, mask);
}

status_t dnnl_primitive_attr_get_post_ops(
        const primitive_attr_t *attr, const post_ops_t **post_ops) {
    if (any_null(attr, post_ops)) return invalid_arguments;

    *post_ops = &attr->post_ops_;
    return success;
}

status_t dnnl_primitive_attr_set_post_ops(
        primitive_attr_t *attr, const post_ops_t *post_ops) {
    if (any_null(attr, post_ops)) return invalid_arguments;

    return attr->set_post_ops(*post_ops);
}

status_t dnnl_post_ops_create(post_ops_t **post_ops) {
    if (post_ops == nullptr) return invalid_arguments;

    return safe_ptr_assign(*post_ops, new dnnl_post_ops);
}

status_t dnnl_post_ops_clone(
        post_ops_t **post_ops, const post_ops_t *existing_post_ops) {
    if (any_null(post_ops, existing_post_ops)) return invalid_arguments;

    auto new_post_ops = utils::make_unique<post_ops_t>(*existing_post_ops);
    if (!new_post_ops->is_initialized()) return out_of_memory;

    return safe_ptr_assign(*post_ops, new_post_ops.release());
}

status_t dnnl_post_ops_destroy(post_ops_t *post_ops) {
    delete post_ops;

    return success;
}

int dnnl_post_ops_len(const post_ops_t *post_ops) {
    if (post_ops) return post_ops->len();

    return 0;
}

primitive_kind_t dnnl_post_ops_get_kind(const post_ops_t *post_ops, int index) {
    bool ok = post_ops && 0 <= index && index < post_ops->len();
    if (!ok) return primitive_kind::undefined;

    return post_ops->entry_[index].kind;
}

status_t dnnl_post_ops_append_sum(
        post_ops_t *post_ops, float scale, int32_t zero_point, data_type_t dt) {
    if (post_ops == nullptr) return invalid_arguments;

    return post_ops->append_sum(scale, zero_point, dt);
}

namespace {
bool simple_get_params_check(
        const post_ops_t *post_ops, int index, primitive_kind_t kind) {
    return post_ops != nullptr && 0 <= index && index < post_ops->len()
            && post_ops->entry_[index].kind == kind;
}
} // namespace

status_t dnnl_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
        float *scale, int32_t *zero_point, data_type_t *dt) {
    if (!simple_get_params_check(post_ops, index, primitive_kind::sum))
        return invalid_arguments;

    if (scale) *scale = post_ops->entry_[index].sum.scale;
    if (zero_point) *zero_point = post_ops->entry_[index].sum.zero_point;
    if (dt) *dt = post_ops->entry_[index].sum.dt;
    return success;
}

status_t dnnl_post_ops_append_eltwise(
        post_ops_t *post_ops, alg_kind_t kind, float alpha, float beta) {
    if (post_ops == nullptr) return invalid_arguments;

    return post_ops->append_eltwise(1.0f, kind, alpha, beta);
}
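
// Illustrative use (hypothetical parameters): append a ReLU; this wrapper
// fixes the internal eltwise scale at 1.0f:
//
//     dnnl_post_ops_append_eltwise(
//             post_ops, dnnl_eltwise_relu, /*alpha*/ 0.f, /*beta*/ 0.f);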

status_t dnnl_post_ops_get_params_eltwise(const post_ops_t *post_ops, int index,
        alg_kind_t *alg, float *alpha, float *beta) {
    bool ok = simple_get_params_check(post_ops, index, primitive_kind::eltwise)
            && !any_null(alg, alpha, beta);
    if (!ok) return invalid_arguments;

    const auto &e = post_ops->entry_[index].eltwise;
    *alg = e.alg;
    *alpha = e.alpha;
    *beta = e.beta;

    return success;
}

status_t dnnl_post_ops_append_dw(post_ops_t *post_ops, data_type_t wei_dt,
        data_type_t bias_dt, data_type_t dst_dt, dim_t kernel_size,
        dim_t stride_size, dim_t padding_l_size) {
    if (post_ops == nullptr) return invalid_arguments;

    return post_ops->append_dw(
            wei_dt, bias_dt, dst_dt, kernel_size, stride_size, padding_l_size);
}

status_t dnnl_post_ops_get_params_dw(const post_ops_t *post_ops, int index,
        data_type_t *wei_dt, data_type_t *bias_dt, data_type_t *dst_dt,
        dim_t *kernel, dim_t *stride, dim_t *padding) {

    if (!simple_get_params_check(post_ops, index, primitive_kind::convolution))
        return invalid_arguments;

    const auto &d = post_ops->entry_[index].depthwise_conv;
    if (wei_dt) *wei_dt = d.wei_dt;
    if (bias_dt) *bias_dt = d.bias_dt;
    if (dst_dt) *dst_dt = d.dst_dt;
    if (kernel) *kernel = d.kernel;
    if (stride) *stride = d.stride;
    if (padding) *padding = d.padding;

    return success;
}

status_t dnnl_post_ops_append_binary(post_ops_t *post_ops, alg_kind_t alg_kind,
        const memory_desc_t *user_src1_desc) {
    if (post_ops == nullptr) return invalid_arguments;

    return post_ops->append_binary(alg_kind, user_src1_desc);
}

status_t dnnl_post_ops_get_params_binary(const post_ops_t *post_ops, int index,
        alg_kind_t *alg_kind, const memory_desc_t **user_src1_desc) {
    if (!simple_get_params_check(post_ops, index, primitive_kind::binary))
        return invalid_arguments;

    const auto &b = post_ops->entry_[index].binary;
    if (alg_kind) *alg_kind = b.alg;
    if (user_src1_desc) *user_src1_desc = &b.user_src1_desc;

    return success;
}

status_t dnnl_post_ops_append_prelu(post_ops_t *post_ops, int mask) {
    if (post_ops == nullptr) return invalid_arguments;

    return post_ops->append_prelu(mask);
}

status_t dnnl_post_ops_get_params_prelu(
        const post_ops_t *post_ops, int index, int *mask) {
    if (!simple_get_params_check(post_ops, index, primitive_kind::prelu))
        return invalid_arguments;

    const auto &prelu_entry = post_ops->entry_[index].prelu;
    if (mask) *mask = prelu_entry.mask;

    return success;
}

status_t dnnl_primitive_attr_set_rnn_data_qparams(
        primitive_attr_t *attr, const float scale, const float shift) {
    if (attr == nullptr) return invalid_arguments;

    return attr->rnn_data_qparams_.set(scale, shift);
}

status_t dnnl_primitive_attr_get_rnn_data_qparams(
        const primitive_attr_t *attr, float *scale, float *shift) {
    if (attr == nullptr) return invalid_arguments;

    const auto &qparams = attr->rnn_data_qparams_;
    if (scale) *scale = qparams.scale_;
    if (shift) *shift = qparams.shift_;

    return success;
}
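
// Example of the int8 RNN data quantization these parameters drive (values
// illustrative, following the library's documented formula
// data_u8 = saturate(scale * data_f32 + shift)): with scale = 63.f and
// shift = 64.f, the tanh output range [-1, 1] maps to roughly [1, 127]:
//
//     dnnl_primitive_attr_set_rnn_data_qparams(attr, 63.f, 64.f);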

status_t dnnl_primitive_attr_set_rnn_weights_qparams(
        primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
    bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
    if (!ok) return invalid_arguments;

    return attr->rnn_weights_qparams_.set(count, mask, scales);
}

status_t dnnl_primitive_attr_get_rnn_weights_qparams(
        const primitive_attr_t *attr, dim_t *count, int *mask,
        const float **scales) {
    if (attr == nullptr) return invalid_arguments;

    const auto &qparams = attr->rnn_weights_qparams_;
    if (count) *count = qparams.count_;
    if (mask) *mask = qparams.mask_;
    if (scales) *scales = qparams.scales_;

    return success;
}

status_t dnnl_primitive_attr_set_rnn_weights_projection_qparams(
        primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
    bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
    if (!ok) return invalid_arguments;

    return attr->rnn_weights_projection_qparams_.set(count, mask, scales);
}

status_t dnnl_primitive_attr_get_rnn_weights_projection_qparams(
        const primitive_attr_t *attr, dim_t *count, int *mask,
        const float **scales) {
    if (attr == nullptr) return invalid_arguments;

    const auto &qparams = attr->rnn_weights_projection_qparams_;
    if (count) *count = qparams.count_;
    if (mask) *mask = qparams.mask_;
    if (scales) *scales = qparams.scales_;

    return success;
}

status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(
        dnnl_primitive_attr_t attr, bool mode, dim_t ngates,
        const float *scales, float cscale) {
    if (attr == nullptr) return invalid_arguments;

    return attr->rnn_tparams_.set(mode, ngates, scales, cscale);
}