1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl.h" |
18 | |
19 | #include "c_types_map.hpp" |
20 | #include "primitive_attr.hpp" |
21 | #include "type_helpers.hpp" |
22 | #include "utils.hpp" |
23 | |
24 | using namespace dnnl::impl; |
25 | using namespace dnnl::impl::status; |
26 | using namespace dnnl::impl::utils; |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | |
31 | const primitive_attr_t &default_attr() { |
32 | static const primitive_attr_t default_attr_instance; |
33 | return default_attr_instance; |
34 | } |
35 | |
36 | status_t scales_t::set(dim_t count, int mask, const float *scales) { |
37 | cleanup(); |
38 | |
39 | count_ = count; |
40 | mask_ = mask; |
41 | |
42 | if (is_runtime_value(*scales)) { |
43 | scales_ = scales_buf_; |
44 | scales_[0] = *scales; |
45 | } else if (count_ == 1) { |
46 | scales_ = scales_buf_; |
47 | utils::array_set(scales_, scales[0], scales_buf_size); |
48 | } else { |
49 | scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); |
50 | if (scales_ == nullptr) return status::out_of_memory; |
51 | |
52 | for (dim_t c = 0; c < count_; ++c) |
53 | scales_[c] = scales[c]; |
54 | } |
55 | |
56 | return status::success; |
57 | } |
58 | |
59 | status_t zero_points_t::get(int arg, int *mask) const { |
60 | if (mask) *mask = get_mask(arg); |
61 | return status::success; |
62 | } |
63 | |
64 | status_t zero_points_t::set(int arg, int mask) { |
65 | const bool supported_arg |
66 | = utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST); |
67 | if (!supported_arg) return status::unimplemented; |
68 | |
69 | switch (arg) { |
70 | case DNNL_ARG_SRC: |
71 | is_set_src = true; |
72 | mask_src = mask; |
73 | break; |
74 | case DNNL_ARG_WEIGHTS: |
75 | is_set_wei = true; |
76 | mask_wei = mask; |
77 | break; |
78 | case DNNL_ARG_DST: |
79 | is_set_dst = true; |
80 | mask_dst = mask; |
81 | break; |
82 | } |
83 | return status::success; |
84 | } |
85 | |
86 | } // namespace impl |
87 | } // namespace dnnl |
88 | |
89 | bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask, |
90 | dnnl::impl::data_type_t dst_dt) const { |
91 | using smask_t = skip_mask_t; |
92 | // prepare mask for runtime-parameters check |
93 | smask_t defined_mask = smask_t::none; |
94 | if ((mask & smask_t::oscale_runtime) == smask_t::oscale_runtime) |
95 | defined_mask |= smask_t::oscale; |
96 | if ((mask & smask_t::scales_runtime) == smask_t::scales_runtime) |
97 | defined_mask |= smask_t::scales; |
98 | if ((mask & smask_t::zero_points_runtime) == smask_t::zero_points_runtime) |
99 | defined_mask |= smask_t::zero_points; |
100 | bool ok = true; |
101 | |
102 | #define CHECK_ARG(x) ok = ok && (x) |
103 | #define CHECK_MASK(mask_name, mask_field) \ |
104 | CHECK_ARG(IMPLICATION( \ |
105 | (bool)(~mask & (mask_name)), (mask_field).has_default_values())) |
106 | CHECK_MASK(smask_t::oscale_runtime, output_scales_); |
107 | CHECK_MASK(smask_t::scales, scales_); |
108 | CHECK_MASK(smask_t::zero_points, zero_points_); |
109 | CHECK_MASK(smask_t::post_ops, post_ops_); |
110 | CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_); |
111 | CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_); |
112 | CHECK_MASK(smask_t::rnn_weights_projection_qparams, |
113 | rnn_weights_projection_qparams_); |
114 | CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::sum_dt), |
115 | post_ops_.sum_with_default_dt(dst_dt))); |
116 | bool gpu_attr_ok = IMPLICATION((bool)(~mask & smask_t::gpu_attr), |
117 | !gpu_attr_ || gpu_attr_->has_default_values()); |
118 | CHECK_ARG(gpu_attr_ok); |
119 | CHECK_ARG(this->defined(defined_mask)); |
120 | return ok; |
121 | #undef CHECK_MASK |
122 | #undef CHECK_ARG |
123 | } |
124 | |
125 | bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const { |
126 | using smask_t = skip_mask_t; |
127 | bool ok = true; |
128 | #define CHECK_ARG(x) ok = ok && (x) |
129 | #define CHECK_MASK(mask_name, mask_field) \ |
130 | CHECK_ARG(IMPLICATION((bool)(~mask & (mask_name)), (mask_field).defined())) |
131 | CHECK_MASK(smask_t::oscale, output_scales_); |
132 | CHECK_MASK(smask_t::scales, scales_); |
133 | CHECK_MASK(smask_t::zero_points, zero_points_); |
134 | CHECK_MASK(smask_t::post_ops, post_ops_); |
135 | CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_); |
136 | CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_); |
137 | CHECK_MASK(smask_t::rnn_weights_projection_qparams, |
138 | rnn_weights_projection_qparams_); |
139 | return ok; |
140 | #undef CHECK_MASK |
141 | #undef CHECK_ARG |
142 | } |
143 | |
144 | status_t post_ops_t::append_sum( |
145 | float scale, int32_t zero_point, data_type_t dt) { |
146 | if (len() == post_ops_limit) return out_of_memory; |
147 | entry_.emplace_back(); |
148 | auto &e = entry_.back(); |
149 | e.kind = primitive_kind::sum; |
150 | e.sum.scale = scale; |
151 | e.sum.zero_point = zero_point; |
152 | e.sum.dt = dt; |
153 | return success; |
154 | } |
155 | |
156 | status_t post_ops_t::append_eltwise( |
157 | float scale, alg_kind_t alg, float alpha, float beta) { |
158 | if (len() == post_ops_limit) return out_of_memory; |
159 | if (!math::is_eltwise_ok(data_type::f32, alg, alpha, beta)) |
160 | return invalid_arguments; |
161 | |
162 | entry_.emplace_back(); |
163 | auto &e = entry_.back(); |
164 | e.kind = primitive_kind::eltwise; |
165 | e.eltwise.scale = scale; |
166 | e.eltwise.alg = alg; |
167 | e.eltwise.alpha = alpha; |
168 | e.eltwise.beta = beta; |
169 | return success; |
170 | } |
171 | |
172 | status_t post_ops_t::append_dw(data_type_t wei_dt, data_type_t bias_dt, |
173 | data_type_t dst_dt, dim_t kernel_size, dim_t stride_size, |
174 | dim_t padding_l_size) { |
175 | if (len() == post_ops_limit) return out_of_memory; |
176 | bool ok = wei_dt != data_type::undef && dst_dt != data_type::undef; |
177 | if (!ok) return invalid_arguments; |
178 | |
179 | ok = ok && kernel_size > 0 && stride_size > 0; |
180 | if (!ok) return invalid_arguments; |
181 | |
182 | // Avoiding cases when kernel in pad area |
183 | ok = ok && (padding_l_size + 1) <= kernel_size; |
184 | if (!ok) return invalid_arguments; |
185 | |
186 | entry_.emplace_back(); |
187 | auto &e = entry_.back(); |
188 | e.kind = primitive_kind::convolution; |
189 | auto &d = e.depthwise_conv; |
190 | d.kernel = kernel_size; |
191 | d.stride = stride_size; |
192 | d.padding = padding_l_size; |
193 | d.wei_dt = wei_dt; |
194 | d.bias_dt = bias_dt; |
195 | d.dst_dt = dst_dt; |
196 | |
197 | return success; |
198 | } |
199 | |
200 | status_t post_ops_t::validate_binary( |
201 | alg_kind_t alg, const memory_desc_t *user_src1_desc) const { |
202 | |
203 | if (len() == post_ops_limit) return out_of_memory; |
204 | using namespace alg_kind; |
205 | bool alg_ok = one_of(alg, binary_add, binary_mul, binary_max, binary_min, |
206 | binary_div, binary_sub, binary_ge, binary_gt, binary_le, binary_lt, |
207 | binary_eq, binary_ne); |
208 | if (!alg_ok) return invalid_arguments; |
209 | if (!memory_desc_sanity_check(*user_src1_desc)) return invalid_arguments; |
210 | |
211 | // Additional check to restrict run-time dimension usage until supported. |
212 | for (int d = 0; d < user_src1_desc->ndims; ++d) { |
213 | if (user_src1_desc->dims[d] == DNNL_RUNTIME_DIM_VAL) |
214 | return invalid_arguments; |
215 | } |
216 | |
217 | return success; |
218 | } |
219 | |
220 | status_t post_ops_t::append_binary( |
221 | alg_kind_t alg, const memory_desc_t *user_src1_desc) { |
222 | auto status = validate_binary(alg, user_src1_desc); |
223 | if (status != success) return status; |
224 | |
225 | entry_.emplace_back(); |
226 | auto &e = entry_.back(); |
227 | e.kind = primitive_kind::binary; |
228 | e.binary.alg = alg; |
229 | e.binary.user_src1_desc = *user_src1_desc; |
230 | e.binary.src1_desc = *user_src1_desc; |
231 | return success; |
232 | } |
233 | |
234 | status_t post_ops_t::prepend_binary( |
235 | alg_kind_t alg, const memory_desc_t *user_src1_desc) { |
236 | auto status = validate_binary(alg, user_src1_desc); |
237 | if (status != success) return status; |
238 | |
239 | entry_.emplace(entry_.begin()); |
240 | auto &e = entry_[0]; |
241 | e.kind = primitive_kind::binary; |
242 | e.binary.alg = alg; |
243 | e.binary.user_src1_desc = *user_src1_desc; |
244 | e.binary.src1_desc = *user_src1_desc; |
245 | return success; |
246 | } |
247 | |
248 | status_t post_ops_t::append_prelu(int mask) { |
249 | if (len() == post_ops_limit) return out_of_memory; |
250 | |
251 | auto it_entry = entry_.emplace(entry_.end()); |
252 | it_entry->kind = primitive_kind::prelu; |
253 | it_entry->prelu.mask = mask; |
254 | |
255 | return success; |
256 | } |
257 | |
258 | bool post_ops_t::defined() const { |
259 | for (int idx = 0; idx < len(); ++idx) { |
260 | auto kind = entry_[idx].kind; |
261 | if (kind == primitive_kind::sum) { |
262 | if (is_runtime_value(entry_[idx].sum.scale)) return false; |
263 | } else if (kind == primitive_kind::eltwise) { |
264 | const auto &e = entry_[idx].eltwise; |
265 | if (is_runtime_value(e.scale) || is_runtime_value(e.alpha) |
266 | || is_runtime_value(e.beta)) |
267 | return false; |
268 | } else if (utils::one_of(kind, primitive_kind::binary, |
269 | primitive_kind::prelu, |
270 | primitive_kind::convolution)) { |
271 | // binary is always defined |
272 | } else { |
273 | assert(!"unreachable" ); |
274 | } |
275 | } |
276 | return true; |
277 | } |
278 | |
279 | status_t post_ops_t::set_default_formats(const memory_desc_t *dst_md) { |
280 | for (int idx = 0; idx < len(); ++idx) { |
281 | if (!contain(primitive_kind::binary, idx)) continue; |
282 | |
283 | auto &src1_md = entry_[idx].binary.src1_desc; |
284 | const memory_desc_wrapper src1_mdw(src1_md); |
285 | if (!src1_mdw.format_any()) continue; |
286 | |
287 | const memory_desc_wrapper dst_mdw(dst_md); |
288 | assert(!dst_mdw.format_any()); |
289 | |
290 | // 1D tensors should be plain abx. |
291 | if (src1_mdw.count_non_unit_dims(1)) |
292 | CHECK(memory_desc_init_by_strides(src1_md, nullptr)); |
293 | else |
294 | CHECK(memory_desc_init_by_blocking_desc( |
295 | src1_md, dst_mdw.blocking_desc())); |
296 | } |
297 | |
298 | return status::success; |
299 | } |
300 | |
301 | bool post_ops_t::check_sum_consistent_dt( |
302 | const data_type_t dst_dt, const bool diverse_sum_dt_allowed) const { |
303 | int sum_ind = find(dnnl::impl::primitive_kind::sum); |
304 | if (sum_ind == -1) return true; |
305 | const auto sum_dt = entry_[sum_ind].sum.dt; |
306 | |
307 | // sum dt and dst dt must have the same size |
308 | const bool compatible_dt_size = IMPLICATION( |
309 | !utils::one_of(dnnl_data_type_undef, sum_dt, dst_dt), |
310 | types::data_type_size(dst_dt) == types::data_type_size(sum_dt)); |
311 | if (!compatible_dt_size) return false; |
312 | if (diverse_sum_dt_allowed) return true; |
313 | |
314 | bool ok = true; |
315 | while ((sum_ind = find(dnnl::impl::primitive_kind::sum, sum_ind + 1)) != -1) |
316 | ok = ok && entry_[sum_ind].sum.dt == sum_dt; |
317 | return ok; |
318 | } |
319 | |
320 | status_t primitive_attr_t::set_fpmath_mode(fpmath_mode_t fpmath_mode) { |
321 | auto st = check_fpmath_mode(fpmath_mode); |
322 | if (st == success) fpmath_mode_ = fpmath_mode; |
323 | return st; |
324 | } |
325 | |
326 | status_t primitive_attr_t::set_scratchpad_mode( |
327 | scratchpad_mode_t scratchpad_mode) { |
328 | const bool ok = one_of( |
329 | scratchpad_mode, scratchpad_mode::library, scratchpad_mode::user); |
330 | if (!ok) return invalid_arguments; |
331 | |
332 | scratchpad_mode_ = scratchpad_mode; |
333 | return success; |
334 | } |
335 | |
336 | status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { |
337 | return post_ops_.copy_from(post_ops); |
338 | } |
339 | |
340 | status_t primitive_attr_t::set_default_formats(const memory_desc_t *dst_md) { |
341 | return post_ops_.set_default_formats(dst_md); |
342 | } |
343 | |
344 | status_t primitive_attr_t::set_gpu_attr(const primitive_attr_item_t &gpu_attr) { |
345 | gpu_attr_ = gpu_attr.clone(); |
346 | return status::success; |
347 | } |
348 | |
349 | /* Public C API */ |
350 | |
351 | status_t dnnl_primitive_attr_create(primitive_attr_t **attr) { |
352 | if (attr == nullptr) return invalid_arguments; |
353 | |
354 | return safe_ptr_assign(*attr, new dnnl_primitive_attr); |
355 | } |
356 | |
357 | status_t dnnl_primitive_attr_clone( |
358 | primitive_attr_t **attr, const primitive_attr_t *existing_attr) { |
359 | if (any_null(attr, existing_attr)) return invalid_arguments; |
360 | |
361 | auto new_attr = utils::make_unique<primitive_attr_t>(*existing_attr); |
362 | if (!new_attr->is_initialized()) return out_of_memory; |
363 | |
364 | return safe_ptr_assign(*attr, new_attr.release()); |
365 | } |
366 | |
367 | status_t dnnl_primitive_attr_destroy(primitive_attr_t *attr) { |
368 | delete attr; |
369 | |
370 | return success; |
371 | } |
372 | |
373 | status_t dnnl_primitive_attr_get_fpmath_mode( |
374 | const primitive_attr_t *attr, fpmath_mode_t *mode) { |
375 | if (any_null(attr, mode)) return invalid_arguments; |
376 | *mode = attr->fpmath_mode_; |
377 | return success; |
378 | } |
379 | |
380 | status_t dnnl_primitive_attr_set_fpmath_mode( |
381 | primitive_attr_t *attr, fpmath_mode_t mode) { |
382 | if (any_null(attr)) return invalid_arguments; |
383 | return attr->set_fpmath_mode(mode); |
384 | } |
385 | |
386 | status_t dnnl_primitive_attr_get_scratchpad_mode( |
387 | const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { |
388 | if (any_null(attr, scratchpad_mode)) return invalid_arguments; |
389 | |
390 | *scratchpad_mode = attr->scratchpad_mode_; |
391 | |
392 | return success; |
393 | } |
394 | |
395 | status_t dnnl_primitive_attr_set_scratchpad_mode( |
396 | primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { |
397 | if (any_null(attr)) return invalid_arguments; |
398 | |
399 | return attr->set_scratchpad_mode(scratchpad_mode); |
400 | } |
401 | |
402 | status_t dnnl_primitive_attr_set_scales_mask( |
403 | primitive_attr_t *attr, int arg, int mask) { |
404 | bool ok = attr && mask >= 0 && arg >= 0 |
405 | && attr->output_scales_.has_default_values(); |
406 | if (!ok) return invalid_arguments; |
407 | return attr->scales_.set(arg, mask); |
408 | } |
409 | |
410 | status_t dnnl_primitive_attr_set_zero_points_mask( |
411 | primitive_attr_t *attr, int arg, int mask) { |
412 | bool ok = attr && mask >= 0; |
413 | if (!ok) return invalid_arguments; |
414 | |
415 | return attr->zero_points_.set(arg, mask); |
416 | } |
417 | |
418 | status_t dnnl_primitive_attr_get_post_ops( |
419 | const primitive_attr_t *attr, const post_ops_t **post_ops) { |
420 | if (any_null(attr, post_ops)) return invalid_arguments; |
421 | |
422 | *post_ops = &attr->post_ops_; |
423 | return success; |
424 | } |
425 | |
426 | status_t dnnl_primitive_attr_set_post_ops( |
427 | primitive_attr_t *attr, const post_ops_t *post_ops) { |
428 | if (any_null(attr, post_ops)) return invalid_arguments; |
429 | |
430 | return attr->set_post_ops(*post_ops); |
431 | } |
432 | |
433 | status_t dnnl_post_ops_create(post_ops_t **post_ops) { |
434 | if (post_ops == nullptr) return invalid_arguments; |
435 | |
436 | return safe_ptr_assign(*post_ops, new dnnl_post_ops); |
437 | } |
438 | |
439 | status_t dnnl_post_ops_clone( |
440 | post_ops_t **post_ops, const post_ops_t *existing_post_ops) { |
441 | if (any_null(post_ops, existing_post_ops)) return invalid_arguments; |
442 | |
443 | auto new_post_ops = utils::make_unique<post_ops_t>(*existing_post_ops); |
444 | if (!new_post_ops->is_initialized()) return out_of_memory; |
445 | |
446 | return safe_ptr_assign(*post_ops, new_post_ops.release()); |
447 | } |
448 | |
449 | status_t dnnl_post_ops_destroy(post_ops_t *post_ops) { |
450 | delete post_ops; |
451 | |
452 | return success; |
453 | } |
454 | |
455 | int dnnl_post_ops_len(const post_ops_t *post_ops) { |
456 | if (post_ops) return post_ops->len(); |
457 | |
458 | return 0; |
459 | } |
460 | |
461 | primitive_kind_t dnnl_post_ops_get_kind(const post_ops_t *post_ops, int index) { |
462 | bool ok = post_ops && 0 <= index && index < post_ops->len(); |
463 | if (!ok) return primitive_kind::undefined; |
464 | |
465 | return post_ops->entry_[index].kind; |
466 | } |
467 | |
468 | status_t dnnl_post_ops_append_sum( |
469 | post_ops_t *post_ops, float scale, int32_t zero_point, data_type_t dt) { |
470 | if (post_ops == nullptr) return invalid_arguments; |
471 | |
472 | return post_ops->append_sum(scale, zero_point, dt); |
473 | } |
474 | |
475 | namespace { |
476 | bool simple_get_params_check( |
477 | const post_ops_t *post_ops, int index, primitive_kind_t kind) { |
478 | bool ok = true && post_ops != nullptr && 0 <= index |
479 | && index < post_ops->len() && post_ops->entry_[index].kind == kind; |
480 | return ok; |
481 | } |
482 | } // namespace |
483 | |
484 | status_t dnnl_post_ops_get_params_sum(const post_ops_t *post_ops, int index, |
485 | float *scale, int32_t *zero_point, data_type_t *dt) { |
486 | bool ok = true |
487 | && simple_get_params_check(post_ops, index, primitive_kind::sum); |
488 | if (!ok) return invalid_arguments; |
489 | |
490 | if (scale) *scale = post_ops->entry_[index].sum.scale; |
491 | if (zero_point) *zero_point = post_ops->entry_[index].sum.zero_point; |
492 | if (dt) *dt = post_ops->entry_[index].sum.dt; |
493 | return success; |
494 | } |
495 | |
496 | status_t dnnl_post_ops_append_eltwise( |
497 | post_ops_t *post_ops, alg_kind_t kind, float alpha, float beta) { |
498 | if (post_ops == nullptr) return invalid_arguments; |
499 | |
500 | return post_ops->append_eltwise(1.0f, kind, alpha, beta); |
501 | } |
502 | |
503 | status_t dnnl_post_ops_get_params_eltwise(const post_ops_t *post_ops, int index, |
504 | alg_kind_t *alg, float *alpha, float *beta) { |
505 | bool ok = true |
506 | && simple_get_params_check(post_ops, index, primitive_kind::eltwise) |
507 | && !any_null(alpha, beta); |
508 | if (!ok) return invalid_arguments; |
509 | |
510 | const auto &e = post_ops->entry_[index].eltwise; |
511 | *alg = e.alg; |
512 | *alpha = e.alpha; |
513 | *beta = e.beta; |
514 | |
515 | return success; |
516 | } |
517 | |
518 | status_t dnnl_post_ops_append_dw(post_ops_t *post_ops, data_type_t wei_dt, |
519 | data_type_t bias_dt, data_type_t dst_dt, dim_t kernel_size, |
520 | dim_t stride_size, dim_t padding_l_size) { |
521 | if (post_ops == nullptr) return invalid_arguments; |
522 | |
523 | return post_ops->append_dw( |
524 | wei_dt, bias_dt, dst_dt, kernel_size, stride_size, padding_l_size); |
525 | } |
526 | |
527 | status_t dnnl_post_ops_get_params_dw(const post_ops_t *post_ops, int index, |
528 | data_type_t *wei_dt, data_type_t *bias_dt, data_type_t *dst_dt, |
529 | dim_t *kernel, dim_t *stride, dim_t *padding) { |
530 | |
531 | if (!simple_get_params_check(post_ops, index, primitive_kind::convolution)) |
532 | return invalid_arguments; |
533 | |
534 | const auto &d = post_ops->entry_[index].depthwise_conv; |
535 | if (wei_dt) *wei_dt = d.wei_dt; |
536 | if (bias_dt) *bias_dt = d.bias_dt; |
537 | if (dst_dt) *dst_dt = d.dst_dt; |
538 | if (kernel) *kernel = d.kernel; |
539 | if (stride) *stride = d.stride; |
540 | if (padding) *padding = d.padding; |
541 | |
542 | return success; |
543 | } |
544 | |
545 | status_t dnnl_post_ops_append_binary(post_ops_t *post_ops, alg_kind_t alg_kind, |
546 | const memory_desc_t *user_src1_desc) { |
547 | if (post_ops == nullptr) return invalid_arguments; |
548 | |
549 | return post_ops->append_binary(alg_kind, user_src1_desc); |
550 | } |
551 | |
552 | status_t dnnl_post_ops_get_params_binary(const post_ops_t *post_ops, int index, |
553 | alg_kind_t *alg_kind, const memory_desc_t **user_src1_desc) { |
554 | if (!simple_get_params_check(post_ops, index, primitive_kind::binary)) |
555 | return invalid_arguments; |
556 | |
557 | const auto &b = post_ops->entry_[index].binary; |
558 | if (alg_kind) *alg_kind = b.alg; |
559 | if (user_src1_desc) *user_src1_desc = &b.user_src1_desc; |
560 | |
561 | return success; |
562 | } |
563 | |
564 | status_t dnnl_post_ops_append_prelu(post_ops_t *post_ops, int mask) { |
565 | if (post_ops == nullptr) return invalid_arguments; |
566 | |
567 | return post_ops->append_prelu(mask); |
568 | } |
569 | |
570 | status_t dnnl_post_ops_get_params_prelu( |
571 | const post_ops_t *post_ops, int index, int *mask) { |
572 | if (post_ops == nullptr || index >= post_ops->len()) |
573 | return invalid_arguments; |
574 | |
575 | const auto &prelu_entry = post_ops->entry_[index].prelu; |
576 | if (mask) *mask = prelu_entry.mask; |
577 | |
578 | return success; |
579 | } |
580 | |
581 | status_t dnnl_primitive_attr_set_rnn_data_qparams( |
582 | primitive_attr_t *attr, const float scale, const float shift) { |
583 | if (attr == nullptr) return invalid_arguments; |
584 | |
585 | return attr->rnn_data_qparams_.set(scale, shift); |
586 | } |
587 | |
588 | status_t dnnl_primitive_attr_get_rnn_data_qparams( |
589 | const primitive_attr_t *attr, float *scale, float *shift) { |
590 | if (attr == nullptr) return invalid_arguments; |
591 | |
592 | const auto qparams = attr->rnn_data_qparams_; |
593 | if (scale) *scale = qparams.scale_; |
594 | if (shift) *shift = qparams.shift_; |
595 | |
596 | return success; |
597 | } |
598 | |
599 | status_t dnnl_primitive_attr_set_rnn_weights_qparams( |
600 | primitive_attr_t *attr, dim_t count, int mask, const float *scales) { |
601 | bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; |
602 | if (!ok) return invalid_arguments; |
603 | |
604 | return attr->rnn_weights_qparams_.set(count, mask, scales); |
605 | } |
606 | |
607 | status_t dnnl_primitive_attr_get_rnn_weights_qparams( |
608 | const primitive_attr_t *attr, dim_t *count, int *mask, |
609 | const float **scales) { |
610 | if (attr == nullptr) return invalid_arguments; |
611 | |
612 | const auto &qparams = attr->rnn_weights_qparams_; |
613 | if (count) *count = qparams.count_; |
614 | if (mask) *mask = qparams.mask_; |
615 | if (scales) *scales = qparams.scales_; |
616 | |
617 | return success; |
618 | } |
619 | |
620 | status_t dnnl_primitive_attr_set_rnn_weights_projection_qparams( |
621 | primitive_attr_t *attr, dim_t count, int mask, const float *scales) { |
622 | bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; |
623 | if (!ok) return invalid_arguments; |
624 | |
625 | return attr->rnn_weights_projection_qparams_.set(count, mask, scales); |
626 | } |
627 | |
628 | status_t dnnl_primitive_attr_get_rnn_weights_projection_qparams( |
629 | const primitive_attr_t *attr, dim_t *count, int *mask, |
630 | const float **scales) { |
631 | if (attr == nullptr) return invalid_arguments; |
632 | |
633 | const auto &qparams = attr->rnn_weights_projection_qparams_; |
634 | if (count) *count = qparams.count_; |
635 | if (mask) *mask = qparams.mask_; |
636 | if (scales) *scales = qparams.scales_; |
637 | |
638 | return success; |
639 | } |
640 | |
641 | status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams( |
642 | dnnl_primitive_attr_t attr, bool mode, dim_t ngates, |
643 | const float *scales, float cscale) { |
644 | if (attr == nullptr) return invalid_arguments; |
645 | |
646 | return attr->rnn_tparams_.set(mode, ngates, scales, cscale); |
647 | } |
648 | |