1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_ATTR_HPP
18#define COMMON_PRIMITIVE_ATTR_HPP
19
20#include <map>
21#include <initializer_list>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "c_types_map.hpp"
26#include "nstl.hpp"
27#include "type_helpers.hpp"
28#include "utils.hpp"
29
30namespace dnnl {
31namespace impl {
32
33const primitive_attr_t &default_attr();
34
35struct rnn_data_qparams_t : public c_compatible {
36 rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
37 bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
38 bool defined() const {
39 return !is_runtime_value(scale_) && !is_runtime_value(shift_);
40 }
41
42 status_t set(float scale, float shift) {
43 scale_ = scale;
44 shift_ = shift;
45 return status::success;
46 }
47
48 bool operator==(const rnn_data_qparams_t &rhs) const {
49 using namespace utils;
50 return equal_with_nan(scale_, rhs.scale_)
51 && equal_with_nan(shift_, rhs.shift_);
52 }
53
54 float scale_;
55 float shift_;
56};
57
58struct rnn_tparams_t : public c_compatible {
59 rnn_tparams_t()
60 : test_mode_(false), scales_(nullptr), ngates_(0), cscale_(0.0f) {}
61
62 ~rnn_tparams_t() {
63 test_mode_ = false;
64 if (scales_ != nullptr) impl::free(scales_);
65 scales_ = nullptr;
66 ngates_ = 0;
67 cscale_ = 0.0f;
68 }
69
70 bool operator==(const rnn_tparams_t &rhs) const {
71 using namespace utils;
72
73 bool ret = test_mode_ == rhs.test_mode_ && ngates_ == rhs.ngates_
74 && equal_with_nan(cscale_, rhs.cscale_);
75
76 if (!ret) return ret;
77
78 if (scales_) {
79 if (std::memcmp(scales_, rhs.scales_, sizeof(float) * ngates_))
80 return false;
81 }
82 return true;
83 }
84
85 bool has_default_values() const {
86 return (test_mode_ == false && scales_ == nullptr && ngates_ == 0
87 && cscale_ == 0.0f);
88 }
89
90 status_t set(bool mode, dim_t ngates, const float *scales, float cscale) {
91 test_mode_ = mode;
92 ngates_ = ngates;
93 scales_ = nullptr;
94 if (scales != nullptr) {
95 scales_ = (float *)impl::malloc(ngates_ * sizeof(*scales_), 64);
96 if (scales_ == nullptr) return status::out_of_memory;
97 utils::array_copy(scales_, scales, ngates_);
98 }
99
100 cscale_ = cscale;
101
102 return status::success;
103 }
104
105 // copy_from() functions are used for each attribute member instead of
106 // operator= in order to return a status.
107 // TODO: consider replacing copy_from() functions with copy-constructors and
108 // std::move, since there are only a few places in the library that actually
109 // use them.
110 status_t copy_from(const rnn_tparams_t &other) {
111 return set(
112 other.test_mode_, other.ngates_, other.scales_, other.cscale_);
113 }
114
115 bool test_mode_; /* we could also use scale_ == nullptr as a test to check test_mode*/
116 float *scales_;
117 dim_t ngates_; /* ngates is equel to the number of scales */
118 float cscale_; /* =0.0f if no c state */
119
120private:
121 DNNL_DISALLOW_COPY_AND_ASSIGN(rnn_tparams_t);
122};
123
124// Note: keep for RNN quantization
125struct scales_t : public c_compatible {
126 scales_t() : count_(1), mask_(0), scales_(scales_buf_) { set(1.); }
127 scales_t(dim_t count, int mask, const float *scales)
128 : scales_(scales_buf_) {
129 set(count, mask, scales);
130 }
131
132 ~scales_t() { cleanup(); }
133
134 bool operator==(const scales_t &rhs) const {
135 bool ret = count_ == rhs.count_ && mask_ == rhs.mask_
136 && !utils::any_null(scales_, rhs.scales_)
137 && defined() == rhs.defined()
138 && IMPLICATION(defined(),
139 !std::memcmp(
140 scales_, rhs.scales_, sizeof(float) * count_));
141 return ret;
142 }
143
144 bool has_default_values() const {
145 for (dim_t c = 0; c < count_; ++c) {
146 if (scales_[c] != 1.) return false;
147 }
148 return true;
149 }
150
151 bool defined() const { return !is_runtime_value(scales_[0]); }
152
153 status_t set(dim_t count, int mask, const float *scales);
154 status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
155
156 status_t copy_from(const scales_t &other) {
157 return set(other.count_, other.mask_, other.scales_);
158 }
159
160 dim_t count_;
161 int mask_;
162 float *scales_;
163
164private:
165 enum { scales_buf_size = 16 };
166 float scales_buf_[scales_buf_size];
167
168 void cleanup() {
169 if (scales_ != scales_buf_ && scales_ != nullptr) impl::free(scales_);
170
171 count_ = 1;
172 mask_ = 0;
173 scales_ = scales_buf_;
174 }
175
176 DNNL_DISALLOW_COPY_AND_ASSIGN(scales_t);
177};
178
179struct runtime_scales_t : public c_compatible {
180 // Clang-3.8.1 raises an error for a default initialization of a const
181 // object. Const runtime_scales_t object is used as default_scales.
182 // runtime_scales_t() = default;
183 runtime_scales_t() {}
184
185 status_t set(int mask) {
186 mask_ = mask;
187 is_set_ = true;
188 return status::success;
189 }
190
191 bool operator==(const runtime_scales_t &rhs) const {
192 return mask_ == rhs.mask_ && is_set_ == rhs.is_set_;
193 }
194
195 bool has_default_values() const { return !is_set_; }
196
197 bool defined() const { return has_default_values(); }
198
199 void reset() {
200 mask_ = 0;
201 is_set_ = false;
202 }
203
204 // TODO: replace with `-1` to remove `is_set_`.
205 // Hide `mask_` under `private:` to force interface usage.
206 int mask_ = 0;
207 bool is_set_ = false;
208};
209
210struct arg_scales_t : public c_compatible {
211 arg_scales_t() = default;
212
213 const runtime_scales_t &get(int arg) const {
214 static const runtime_scales_t default_scales;
215 const auto it = scales_.find(arg);
216 if (it == scales_.end()) return default_scales;
217 return it->second;
218 }
219
220 bool operator==(const arg_scales_t &rhs) const {
221 return scales_ == rhs.scales_;
222 }
223
224 bool has_default_values(const std::vector<int> &skip_args = {}) const {
225 for (const auto &s : scales_) {
226 if (!s.second.has_default_values()) {
227 bool skip = false;
228 for (const auto &skip_a : skip_args)
229 if (s.first == skip_a) {
230 skip = true;
231 break;
232 }
233 if (skip) continue;
234 return false;
235 }
236 }
237 return true;
238 }
239
240 status_t set(int arg, int mask) {
241 if (!check_arg(arg)) return status::invalid_arguments;
242 return scales_[arg].set(mask);
243 }
244
245 status_t get(int arg, int *mask, bool *is_set) const {
246 if (!check_arg(arg)) return status::invalid_arguments;
247 const auto &s = get(arg);
248 if (mask) *mask = s.mask_;
249 if (is_set) *is_set = s.is_set_;
250 return status::success;
251 }
252
253 status_t reset(int arg) {
254 if (!check_arg(arg)) return status::invalid_arguments;
255 const auto it = scales_.find(arg);
256 if (it != scales_.end()) scales_.erase(it);
257 return status::success;
258 }
259
260 bool defined() const { return has_default_values(); }
261
262 status_t copy_from(const arg_scales_t &other) {
263 for (auto it = other.scales_.begin(); it != other.scales_.end(); ++it) {
264 // Find an entry that can match the arguments without constructing a
265 // new object.
266 if (scales_.count(it->first) == 1) {
267 auto &entry = scales_[it->first];
268 bool exists = entry.mask_ == it->second.mask_;
269 if (exists) continue;
270 }
271
272 CHECK(set(it->first, it->second.mask_));
273 }
274 return status::success;
275 }
276
277 std::map<int, runtime_scales_t> scales_;
278
279private:
280 bool check_arg(int arg) const {
281 // binary
282 for (const auto &sa : {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}) {
283 if (arg == sa) return true;
284 }
285 // concat
286 if (arg & DNNL_ARG_MULTIPLE_SRC) return true;
287 // convolution
288 for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
289 if (arg == sa) return true;
290 }
291 // depth-wise convolution post op
292 for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
293 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | sa)) return true;
294 }
295 return false;
296 }
297};
298
299struct zero_points_t : public c_compatible {
300 bool operator==(const zero_points_t &rhs) const {
301 return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei
302 && mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src
303 && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst;
304 }
305
306 // arg-specific checks
307 bool common(int arg) const { return get_mask(arg) == 0; }
308 bool defined(int arg) const { return has_default_values(arg); }
309 bool has_default_values(int arg) const { return is_set(arg) == false; }
310
311 // same checks but for all supported arguments at once
312 bool common() const { return check_all(&zero_points_t::common); }
313 bool defined() const { return has_default_values(); }
314 bool has_default_values() const {
315 return check_all(&zero_points_t::has_default_values);
316 }
317
318 status_t get(int arg, int *mask) const;
319
320 status_t set(int arg, int mask);
321 status_t set(int arg) { return set(arg, 0); }
322
323private:
324 bool is_set_src = false, is_set_wei = false, is_set_dst = false;
325 int mask_src = 0, mask_wei = 0, mask_dst = 0;
326
327 int get_mask(int arg) const {
328 int mask = 0;
329 switch (arg) {
330 case DNNL_ARG_SRC: mask = mask_src; break;
331 case DNNL_ARG_WEIGHTS: mask = mask_wei; break;
332 case DNNL_ARG_DST: mask = mask_dst; break;
333 default: mask = 0;
334 }
335 return mask;
336 }
337
338 bool is_set(int arg) const {
339 bool arg_is_set = false;
340 switch (arg) {
341 case DNNL_ARG_SRC: arg_is_set = is_set_src; break;
342 case DNNL_ARG_WEIGHTS: arg_is_set = is_set_wei; break;
343 case DNNL_ARG_DST: arg_is_set = is_set_dst; break;
344 default: arg_is_set = 0;
345 }
346 return arg_is_set;
347 }
348
349 bool check_all(bool (zero_points_t::*f)(int) const) const {
350 for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
351 if (!(this->*f)(arg)) return false;
352 return true;
353 }
354};
355
356struct serialization_stream_t;
357
// Interface for an opaque, backend-specific attribute payload (stored in
// dnnl_primitive_attr::gpu_attr_). Concrete implementations are not visible
// in this header.
struct primitive_attr_item_t {
    // Returns a copy of the item; used by dnnl_primitive_attr::copy_from().
    virtual std::unique_ptr<primitive_attr_item_t> clone() const = 0;
    // True when the item matches a default-constructed state.
    virtual bool has_default_values() const = 0;
    // Equality check; used by dnnl_primitive_attr::operator==.
    virtual bool is_equal(const primitive_attr_item_t &other) const = 0;
    // Hash of the item state — presumably paired with is_equal() for caching.
    virtual size_t get_hash() const = 0;
    // Writes the item state into `stream` (serialization_stream_t is only
    // forward-declared here).
    virtual void serialize(serialization_stream_t &stream) const = 0;
    virtual ~primitive_attr_item_t() = default;
};
366
367} // namespace impl
368} // namespace dnnl
369
370struct dnnl_post_ops : public dnnl::impl::c_compatible {
371 struct entry_t {
372 entry_t() : kind(dnnl::impl::primitive_kind::undefined) {}
373 entry_t(const entry_t &other) { copy_from(other); }
374
375 dnnl::impl::status_t copy_from(const entry_t &other) {
376 return set(other);
377 }
378
379 // TODO: This operator has to be deleted, and its usage has to be
380 // replaced with copy_from() or copy/move constructors in order to
381 // extract a status.
382 entry_t &operator=(const entry_t &other) {
383 DNNL_SHORT_CIRCUIT_SELF_ASSIGN(other);
384 set(other);
385 return *this;
386 }
387
388 struct eltwise_t {
389 dnnl::impl::alg_kind_t alg;
390 float scale, alpha, beta;
391 };
392
393 struct depthwise_conv_t {
394 dnnl::impl::dim_t kernel;
395 dnnl::impl::dim_t stride;
396 dnnl::impl::dim_t padding;
397 dnnl::impl::data_type_t wei_dt;
398 dnnl::impl::data_type_t bias_dt;
399 dnnl::impl::data_type_t dst_dt;
400 };
401
402 struct binary_t {
403 dnnl::impl::alg_kind_t alg;
404 // This is an unmodifiable user copy of attributes which is used in
405 // caching mechanism. Not to be used internally.
406 dnnl::impl::memory_desc_t user_src1_desc;
407 // This is a modifiable copy of memory desc. It changes format kind
408 // and tag of md in case user passed format_kind::any. To be used
409 // everywhere internally.
410 dnnl::impl::memory_desc_t src1_desc;
411 };
412
413 struct prelu_t {
414 int mask;
415 };
416
417 dnnl::impl::primitive_kind_t kind
418 = dnnl::impl::primitive_kind::undefined;
419 union {
420 struct {
421 float scale;
422 int32_t zero_point;
423 dnnl::impl::data_type_t dt;
424 } sum;
425 eltwise_t eltwise;
426 depthwise_conv_t depthwise_conv;
427 binary_t binary;
428 prelu_t prelu;
429 };
430
431 bool is_eltwise(bool require_scale_one = false) const {
432 using namespace dnnl::impl;
433 return kind == primitive_kind::eltwise
434 && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
435 }
436
437 bool is_relu(bool require_scale_one = true,
438 bool require_nslope_zero = true) const {
439 using namespace dnnl::impl;
440 return is_eltwise(require_scale_one)
441 && eltwise.alg == alg_kind::eltwise_relu
442 && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
443 }
444
445 bool is_sum(bool require_scale_one = true,
446 bool require_zp_zero = true) const {
447 using namespace dnnl::impl;
448 return kind == primitive_kind::sum
449 && IMPLICATION(require_scale_one, sum.scale == 1.f)
450 && IMPLICATION(require_zp_zero, sum.zero_point == 0);
451 }
452
453 bool is_convolution() const {
454 using namespace dnnl::impl;
455 return kind == primitive_kind::convolution;
456 }
457
458 bool is_binary() const {
459 return kind == dnnl::impl::primitive_kind::binary;
460 }
461
462 bool is_prelu() const {
463 return kind == dnnl::impl::primitive_kind::prelu;
464 }
465
466 dnnl::impl::status_t set_depthwise_scales(const float *scales);
467
468 bool operator==(const entry_t &rhs) const {
469 using namespace dnnl::impl;
470 using namespace dnnl::impl::utils;
471 if (kind != rhs.kind) { return false; }
472 bool ret = true;
473 switch (kind) {
474 case primitive_kind::eltwise:
475 ret = eltwise.alg == rhs.eltwise.alg
476 && equal_with_nan(eltwise.scale, rhs.eltwise.scale)
477 && equal_with_nan(eltwise.alpha, rhs.eltwise.alpha)
478 && equal_with_nan(eltwise.beta, rhs.eltwise.beta);
479 break;
480 case primitive_kind::sum:
481 ret = equal_with_nan(sum.scale, rhs.sum.scale)
482 && sum.zero_point == rhs.sum.zero_point
483 && sum.dt == rhs.sum.dt;
484 break;
485 case primitive_kind::convolution:
486 // Depthwise Only
487 ret = depthwise_conv.kernel == rhs.depthwise_conv.kernel
488 && depthwise_conv.stride
489 == rhs.depthwise_conv.stride
490 && depthwise_conv.padding
491 == rhs.depthwise_conv.padding
492 && depthwise_conv.wei_dt
493 == rhs.depthwise_conv.wei_dt
494 && depthwise_conv.bias_dt
495 == rhs.depthwise_conv.bias_dt
496 && depthwise_conv.dst_dt
497 == rhs.depthwise_conv.dst_dt;
498 break;
499 case primitive_kind::binary:
500 ret = binary.alg == rhs.binary.alg
501 && binary.user_src1_desc
502 == rhs.binary.user_src1_desc;
503 break;
504 case primitive_kind::prelu:
505 ret = prelu.mask == rhs.prelu.mask;
506 break;
507 default: assert(!"unsupported post_op");
508 }
509 return ret;
510 }
511
512 bool operator!=(const entry_t &rhs) const {
513 return !this->operator==(rhs);
514 }
515
516 private:
517 dnnl::impl::status_t set(const entry_t &other) {
518 // Copying by if (is_convolution()) {} else if(is_sum()) {}
519 // else if(is_relu()) {} seems to be unreliable. memcpying for now.
520 dnnl::impl::utils::array_copy(
521 (char *)this, (char *)&other, sizeof(*this));
522 return dnnl::impl::status::success;
523 }
524 };
525
526 dnnl_post_ops() : entry_() {}
527
528 dnnl_post_ops(const dnnl_post_ops &other) {
529 if (copy_from(other) != dnnl::impl::status::success)
530 is_initialized_ = false;
531 }
532
533 dnnl::impl::status_t append_sum(float scale, int32_t zero_point = 0,
534 dnnl::impl::data_type_t dt = dnnl_data_type_undef);
535 dnnl::impl::status_t append_eltwise(
536 float scale, dnnl::impl::alg_kind_t alg, float alpha, float beta);
537 dnnl::impl::status_t append_dw(dnnl::impl::data_type_t wei_dt,
538 dnnl::impl::data_type_t bias_dt, dnnl::impl::data_type_t dst_dt,
539 dnnl::impl::dim_t kernel_size, dnnl::impl::dim_t stride_size,
540 dnnl::impl::dim_t padding_l_size);
541 dnnl::impl::status_t append_binary(dnnl::impl::alg_kind_t alg,
542 const dnnl::impl::memory_desc_t *user_src1_desc);
543 dnnl::impl::status_t append_prelu(int mask);
544
545 dnnl::impl::status_t prepend_binary(dnnl::impl::alg_kind_t alg,
546 const dnnl::impl::memory_desc_t *user_src1_desc);
547
548 int find(dnnl::impl::primitive_kind_t kind, int start = 0,
549 int stop = -1) const {
550 if (stop == -1) stop = len();
551 stop = dnnl::impl::nstl::min(stop, len());
552 for (int idx = start; idx < stop; ++idx)
553 if (entry_[idx].kind == kind) return idx;
554 return -1;
555 }
556
557 dnnl::impl::data_type_t get_sum_dt(
558 const dnnl::impl::data_type_t dst_dt) const {
559 const int sum_ind = find(dnnl::impl::primitive_kind::sum);
560 if (sum_ind == -1) return dst_dt;
561 const auto sum_dt = entry_[sum_ind].sum.dt;
562 if (sum_dt != dnnl::impl::data_type::undef) return sum_dt;
563 return dst_dt;
564 }
565
566 bool defined() const;
567 int len() const { return (int)entry_.size(); }
568 bool has_default_values() const { return len() == 0; }
569
570 dnnl::impl::status_t set_default_formats(
571 const dnnl::impl::memory_desc_t *dst_md);
572
573 bool check_sum_consistent_dt(const dnnl::impl::data_type_t dst_dt,
574 const bool diverse_sum_dt_allowed = false) const;
575
576 bool sum_with_default_dt(
577 dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const {
578 int sum_ind = find(dnnl::impl::primitive_kind::sum);
579 return sum_ind == -1 || entry_[sum_ind].sum.dt == dnnl_data_type_undef
580 || entry_[sum_ind].sum.dt == dst_dt;
581 }
582
583 bool contain(dnnl::impl::primitive_kind_t kind, int index) const {
584 return find(kind, index, index + 1) == index;
585 }
586
587 bool operator==(const dnnl_post_ops &rhs) const {
588 bool ret = len() == rhs.len();
589 for (int i = 0; i < len(); ++i)
590 ret = ret && entry_[i] == rhs.entry_[i];
591 return ret;
592 }
593
594 dnnl::impl::status_t copy_from(const dnnl_post_ops &other) {
595 using namespace dnnl::impl;
596
597 for (int idx = 0; idx < other.len(); ++idx) {
598 if (len() > idx) {
599 if (entry_[idx] == other.entry_[idx]) continue;
600 } else {
601 entry_.emplace_back();
602 }
603 CHECK(entry_[idx].copy_from(other.entry_[idx]));
604 }
605
606 return status::success;
607 }
608
609 bool is_initialized() const { return is_initialized_; }
610
611 std::vector<entry_t> entry_;
612
613 // Since binary post op accepts no more than 32 memory arguments by
614 // design, we limit the amount of post-ops to 32.
615 static constexpr int post_ops_limit = 32;
616
617private:
618 dnnl::impl::status_t validate_binary(dnnl::impl::alg_kind_t alg,
619 const dnnl::impl::memory_desc_t *user_src1_desc) const;
620};
621
622struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
623 dnnl_primitive_attr()
624 : scratchpad_mode_(dnnl::impl::scratchpad_mode::library)
625 , fpmath_mode_(dnnl::impl::get_fpmath_mode()) {}
626
627 dnnl_primitive_attr *clone() const {
628 return new dnnl_primitive_attr(*this);
629 }
630
631 dnnl_primitive_attr(const dnnl_primitive_attr &other) {
632 if (copy_from(other) != dnnl::impl::status::success)
633 is_initialized_ = false;
634 }
635
636 dnnl::impl::status_t copy_from(const dnnl_primitive_attr &other) {
637 using namespace dnnl::impl;
638
639 output_scales_ = other.output_scales_;
640 scales_ = other.scales_;
641 zero_points_ = other.zero_points_;
642 scratchpad_mode_ = other.scratchpad_mode_;
643 fpmath_mode_ = other.fpmath_mode_;
644 CHECK(post_ops_.copy_from(other.post_ops_));
645 rnn_data_qparams_ = other.rnn_data_qparams_;
646 CHECK(rnn_weights_qparams_.copy_from(other.rnn_weights_qparams_));
647 CHECK(rnn_weights_projection_qparams_.copy_from(
648 other.rnn_weights_projection_qparams_));
649 CHECK(rnn_tparams_.copy_from(other.rnn_tparams_));
650 if (other.gpu_attr_) gpu_attr_ = other.gpu_attr_->clone();
651
652 return status::success;
653 }
654
655 bool is_initialized() const { return is_initialized_; }
656
657 enum class skip_mask_t : unsigned {
658 none = 0,
659 oscale = 1u << 0,
660 oscale_runtime = 1u << 1,
661 scales = 1u << 2,
662 scales_runtime = (unsigned)scales | (1u << 3),
663 zero_points = 1u << 4,
664 zero_points_runtime = (unsigned)zero_points | (1u << 5),
665 post_ops = 1u << 6,
666 rnn_data_qparams = 1u << 7,
667 rnn_weights_qparams = 1u << 8,
668 rnn_tparams = 1u << 9,
669 sum_dt = 1u << 10,
670 rnn_weights_projection_qparams = 1u << 11,
671 gpu_attr = 1u << 12
672 };
673
674 /** Returns true if the attributes have default values.
675 *
676 * @note The scratchpad_mode_ is not take into account */
677 bool has_default_values(skip_mask_t mask = skip_mask_t::none,
678 dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const;
679
680 /** Returns true if the attributes are fully defined. */
681 bool defined(skip_mask_t mask = skip_mask_t::none) const;
682
683 bool operator==(const dnnl_primitive_attr &rhs) const {
684 bool ret = scratchpad_mode_ == rhs.scratchpad_mode_
685 && fpmath_mode_ == rhs.fpmath_mode_
686 && output_scales_ == rhs.output_scales_
687 && scales_ == rhs.scales_ && zero_points_ == rhs.zero_points_
688 && post_ops_ == rhs.post_ops_
689 && rnn_data_qparams_ == rhs.rnn_data_qparams_
690 && rnn_weights_qparams_ == rhs.rnn_weights_qparams_
691 && rnn_weights_projection_qparams_
692 == rhs.rnn_weights_projection_qparams_
693 && rnn_tparams_ == rhs.rnn_tparams_
694 && ((gpu_attr_ && rhs.gpu_attr_
695 && gpu_attr_->is_equal(*rhs.gpu_attr_))
696 || (!gpu_attr_ && !rhs.gpu_attr_));
697 return ret;
698 }
699
700 dnnl::impl::status_t set_fpmath_mode(dnnl::impl::fpmath_mode_t fpmath_mode);
701 dnnl::impl::status_t set_scratchpad_mode(
702 dnnl::impl::scratchpad_mode_t scratchpad_mode);
703 dnnl::impl::status_t set_post_ops(const dnnl::impl::post_ops_t &post_ops);
704 dnnl::impl::status_t set_gpu_attr(
705 const dnnl::impl::primitive_attr_item_t &gpu_attr);
706 dnnl::impl::status_t set_default_formats(
707 const dnnl::impl::memory_desc_t *dst_md);
708
709 /* Auxiliary functions */
710 bool mayidownconvert(dnnl::impl::data_type_t dt_from,
711 dnnl::impl::data_type_t dt_to) const {
712 using namespace dnnl::impl;
713
714 bool is_compat = is_fpsubtype(dt_to, dt_from);
715 auto can_downconvert = [&]() {
716 switch (fpmath_mode_) {
717 case fpmath_mode::strict: return dt_from == dt_to;
718 case fpmath_mode::any: return true;
719 case fpmath_mode::bf16:
720 return is_fpsubtype(data_type::bf16, dt_to);
721 case fpmath_mode::f16:
722 return is_fpsubtype(data_type::f16, dt_to);
723 case fpmath_mode::tf32:
724 return is_fpsubtype(data_type::tf32, dt_to);
725 default: return false;
726 }
727 };
728 return is_compat && can_downconvert();
729 }
730
731 // NOTE: make sure that the types below have overloaded comparison operator
732 dnnl::impl::runtime_scales_t output_scales_;
733 dnnl::impl::arg_scales_t scales_;
734 dnnl::impl::zero_points_t zero_points_;
735 dnnl::impl::scratchpad_mode_t scratchpad_mode_;
736 dnnl::impl::fpmath_mode_t fpmath_mode_;
737 dnnl::impl::post_ops_t post_ops_;
738 dnnl::impl::rnn_data_qparams_t rnn_data_qparams_;
739 dnnl::impl::scales_t rnn_weights_qparams_;
740 dnnl::impl::scales_t rnn_weights_projection_qparams_;
741 dnnl::impl::rnn_tparams_t rnn_tparams_;
742
743 std::unique_ptr<dnnl::impl::primitive_attr_item_t> gpu_attr_;
744
745 dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete;
746};
747
748inline dnnl_primitive_attr::skip_mask_t operator|(
749 dnnl_primitive_attr::skip_mask_t lhs,
750 dnnl_primitive_attr::skip_mask_t rhs) {
751 return static_cast<dnnl_primitive_attr::skip_mask_t>(
752 static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
753}
754inline dnnl_primitive_attr::skip_mask_t operator&(
755 dnnl_primitive_attr::skip_mask_t lhs,
756 dnnl_primitive_attr::skip_mask_t rhs) {
757 return static_cast<dnnl_primitive_attr::skip_mask_t>(
758 static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs));
759}
760inline dnnl_primitive_attr::skip_mask_t &operator|=(
761 dnnl_primitive_attr::skip_mask_t &lhs,
762 dnnl_primitive_attr::skip_mask_t rhs) {
763 lhs = lhs | rhs;
764 return lhs;
765}
766inline dnnl_primitive_attr::skip_mask_t &operator&=(
767 dnnl_primitive_attr::skip_mask_t &lhs,
768 dnnl_primitive_attr::skip_mask_t rhs) {
769 lhs = lhs & rhs;
770 return lhs;
771}
772inline bool operator!=(dnnl_primitive_attr::skip_mask_t lhs,
773 dnnl_primitive_attr::skip_mask_t rhs) {
774 return (static_cast<unsigned>(lhs) != static_cast<unsigned>(rhs));
775}
776inline dnnl_primitive_attr::skip_mask_t operator~(
777 dnnl_primitive_attr::skip_mask_t rhs) {
778 return static_cast<dnnl_primitive_attr::skip_mask_t>(
779 ~static_cast<unsigned>(rhs));
780}
781
782#endif
783