1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_ATTR_HPP |
18 | #define COMMON_PRIMITIVE_ATTR_HPP |
19 | |
20 | #include <map> |
21 | #include <initializer_list> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "c_types_map.hpp" |
26 | #include "nstl.hpp" |
27 | #include "type_helpers.hpp" |
28 | #include "utils.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | |
33 | const primitive_attr_t &default_attr(); |
34 | |
35 | struct rnn_data_qparams_t : public c_compatible { |
36 | rnn_data_qparams_t() : scale_(1.), shift_(0.) {} |
37 | bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } |
38 | bool defined() const { |
39 | return !is_runtime_value(scale_) && !is_runtime_value(shift_); |
40 | } |
41 | |
42 | status_t set(float scale, float shift) { |
43 | scale_ = scale; |
44 | shift_ = shift; |
45 | return status::success; |
46 | } |
47 | |
48 | bool operator==(const rnn_data_qparams_t &rhs) const { |
49 | using namespace utils; |
50 | return equal_with_nan(scale_, rhs.scale_) |
51 | && equal_with_nan(shift_, rhs.shift_); |
52 | } |
53 | |
54 | float scale_; |
55 | float shift_; |
56 | }; |
57 | |
58 | struct rnn_tparams_t : public c_compatible { |
59 | rnn_tparams_t() |
60 | : test_mode_(false), scales_(nullptr), ngates_(0), cscale_(0.0f) {} |
61 | |
62 | ~rnn_tparams_t() { |
63 | test_mode_ = false; |
64 | if (scales_ != nullptr) impl::free(scales_); |
65 | scales_ = nullptr; |
66 | ngates_ = 0; |
67 | cscale_ = 0.0f; |
68 | } |
69 | |
70 | bool operator==(const rnn_tparams_t &rhs) const { |
71 | using namespace utils; |
72 | |
73 | bool ret = test_mode_ == rhs.test_mode_ && ngates_ == rhs.ngates_ |
74 | && equal_with_nan(cscale_, rhs.cscale_); |
75 | |
76 | if (!ret) return ret; |
77 | |
78 | if (scales_) { |
79 | if (std::memcmp(scales_, rhs.scales_, sizeof(float) * ngates_)) |
80 | return false; |
81 | } |
82 | return true; |
83 | } |
84 | |
85 | bool has_default_values() const { |
86 | return (test_mode_ == false && scales_ == nullptr && ngates_ == 0 |
87 | && cscale_ == 0.0f); |
88 | } |
89 | |
90 | status_t set(bool mode, dim_t ngates, const float *scales, float cscale) { |
91 | test_mode_ = mode; |
92 | ngates_ = ngates; |
93 | scales_ = nullptr; |
94 | if (scales != nullptr) { |
95 | scales_ = (float *)impl::malloc(ngates_ * sizeof(*scales_), 64); |
96 | if (scales_ == nullptr) return status::out_of_memory; |
97 | utils::array_copy(scales_, scales, ngates_); |
98 | } |
99 | |
100 | cscale_ = cscale; |
101 | |
102 | return status::success; |
103 | } |
104 | |
105 | // copy_from() functions are used for each attribute member instead of |
106 | // operator= in order to return a status. |
107 | // TODO: consider replacing copy_from() functions with copy-constructors and |
108 | // std::move, since there are only a few places in the library that actually |
109 | // use them. |
110 | status_t copy_from(const rnn_tparams_t &other) { |
111 | return set( |
112 | other.test_mode_, other.ngates_, other.scales_, other.cscale_); |
113 | } |
114 | |
115 | bool test_mode_; /* we could also use scale_ == nullptr as a test to check test_mode*/ |
116 | float *scales_; |
117 | dim_t ngates_; /* ngates is equel to the number of scales */ |
118 | float cscale_; /* =0.0f if no c state */ |
119 | |
120 | private: |
121 | DNNL_DISALLOW_COPY_AND_ASSIGN(rnn_tparams_t); |
122 | }; |
123 | |
124 | // Note: keep for RNN quantization |
125 | struct scales_t : public c_compatible { |
126 | scales_t() : count_(1), mask_(0), scales_(scales_buf_) { set(1.); } |
127 | scales_t(dim_t count, int mask, const float *scales) |
128 | : scales_(scales_buf_) { |
129 | set(count, mask, scales); |
130 | } |
131 | |
132 | ~scales_t() { cleanup(); } |
133 | |
134 | bool operator==(const scales_t &rhs) const { |
135 | bool ret = count_ == rhs.count_ && mask_ == rhs.mask_ |
136 | && !utils::any_null(scales_, rhs.scales_) |
137 | && defined() == rhs.defined() |
138 | && IMPLICATION(defined(), |
139 | !std::memcmp( |
140 | scales_, rhs.scales_, sizeof(float) * count_)); |
141 | return ret; |
142 | } |
143 | |
144 | bool has_default_values() const { |
145 | for (dim_t c = 0; c < count_; ++c) { |
146 | if (scales_[c] != 1.) return false; |
147 | } |
148 | return true; |
149 | } |
150 | |
151 | bool defined() const { return !is_runtime_value(scales_[0]); } |
152 | |
153 | status_t set(dim_t count, int mask, const float *scales); |
154 | status_t set(float single_scale) { return this->set(1, 0, &single_scale); } |
155 | |
156 | status_t copy_from(const scales_t &other) { |
157 | return set(other.count_, other.mask_, other.scales_); |
158 | } |
159 | |
160 | dim_t count_; |
161 | int mask_; |
162 | float *scales_; |
163 | |
164 | private: |
165 | enum { scales_buf_size = 16 }; |
166 | float scales_buf_[scales_buf_size]; |
167 | |
168 | void cleanup() { |
169 | if (scales_ != scales_buf_ && scales_ != nullptr) impl::free(scales_); |
170 | |
171 | count_ = 1; |
172 | mask_ = 0; |
173 | scales_ = scales_buf_; |
174 | } |
175 | |
176 | DNNL_DISALLOW_COPY_AND_ASSIGN(scales_t); |
177 | }; |
178 | |
179 | struct runtime_scales_t : public c_compatible { |
180 | // Clang-3.8.1 raises an error for a default initialization of a const |
181 | // object. Const runtime_scales_t object is used as default_scales. |
182 | // runtime_scales_t() = default; |
183 | runtime_scales_t() {} |
184 | |
185 | status_t set(int mask) { |
186 | mask_ = mask; |
187 | is_set_ = true; |
188 | return status::success; |
189 | } |
190 | |
191 | bool operator==(const runtime_scales_t &rhs) const { |
192 | return mask_ == rhs.mask_ && is_set_ == rhs.is_set_; |
193 | } |
194 | |
195 | bool has_default_values() const { return !is_set_; } |
196 | |
197 | bool defined() const { return has_default_values(); } |
198 | |
199 | void reset() { |
200 | mask_ = 0; |
201 | is_set_ = false; |
202 | } |
203 | |
204 | // TODO: replace with `-1` to remove `is_set_`. |
205 | // Hide `mask_` under `private:` to force interface usage. |
206 | int mask_ = 0; |
207 | bool is_set_ = false; |
208 | }; |
209 | |
210 | struct arg_scales_t : public c_compatible { |
211 | arg_scales_t() = default; |
212 | |
213 | const runtime_scales_t &get(int arg) const { |
214 | static const runtime_scales_t default_scales; |
215 | const auto it = scales_.find(arg); |
216 | if (it == scales_.end()) return default_scales; |
217 | return it->second; |
218 | } |
219 | |
220 | bool operator==(const arg_scales_t &rhs) const { |
221 | return scales_ == rhs.scales_; |
222 | } |
223 | |
224 | bool has_default_values(const std::vector<int> &skip_args = {}) const { |
225 | for (const auto &s : scales_) { |
226 | if (!s.second.has_default_values()) { |
227 | bool skip = false; |
228 | for (const auto &skip_a : skip_args) |
229 | if (s.first == skip_a) { |
230 | skip = true; |
231 | break; |
232 | } |
233 | if (skip) continue; |
234 | return false; |
235 | } |
236 | } |
237 | return true; |
238 | } |
239 | |
240 | status_t set(int arg, int mask) { |
241 | if (!check_arg(arg)) return status::invalid_arguments; |
242 | return scales_[arg].set(mask); |
243 | } |
244 | |
245 | status_t get(int arg, int *mask, bool *is_set) const { |
246 | if (!check_arg(arg)) return status::invalid_arguments; |
247 | const auto &s = get(arg); |
248 | if (mask) *mask = s.mask_; |
249 | if (is_set) *is_set = s.is_set_; |
250 | return status::success; |
251 | } |
252 | |
253 | status_t reset(int arg) { |
254 | if (!check_arg(arg)) return status::invalid_arguments; |
255 | const auto it = scales_.find(arg); |
256 | if (it != scales_.end()) scales_.erase(it); |
257 | return status::success; |
258 | } |
259 | |
260 | bool defined() const { return has_default_values(); } |
261 | |
262 | status_t copy_from(const arg_scales_t &other) { |
263 | for (auto it = other.scales_.begin(); it != other.scales_.end(); ++it) { |
264 | // Find an entry that can match the arguments without constructing a |
265 | // new object. |
266 | if (scales_.count(it->first) == 1) { |
267 | auto &entry = scales_[it->first]; |
268 | bool exists = entry.mask_ == it->second.mask_; |
269 | if (exists) continue; |
270 | } |
271 | |
272 | CHECK(set(it->first, it->second.mask_)); |
273 | } |
274 | return status::success; |
275 | } |
276 | |
277 | std::map<int, runtime_scales_t> scales_; |
278 | |
279 | private: |
280 | bool check_arg(int arg) const { |
281 | // binary |
282 | for (const auto &sa : {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}) { |
283 | if (arg == sa) return true; |
284 | } |
285 | // concat |
286 | if (arg & DNNL_ARG_MULTIPLE_SRC) return true; |
287 | // convolution |
288 | for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
289 | if (arg == sa) return true; |
290 | } |
291 | // depth-wise convolution post op |
292 | for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
293 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | sa)) return true; |
294 | } |
295 | return false; |
296 | } |
297 | }; |
298 | |
299 | struct zero_points_t : public c_compatible { |
300 | bool operator==(const zero_points_t &rhs) const { |
301 | return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei |
302 | && mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src |
303 | && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst; |
304 | } |
305 | |
306 | // arg-specific checks |
307 | bool common(int arg) const { return get_mask(arg) == 0; } |
308 | bool defined(int arg) const { return has_default_values(arg); } |
309 | bool has_default_values(int arg) const { return is_set(arg) == false; } |
310 | |
311 | // same checks but for all supported arguments at once |
312 | bool common() const { return check_all(&zero_points_t::common); } |
313 | bool defined() const { return has_default_values(); } |
314 | bool has_default_values() const { |
315 | return check_all(&zero_points_t::has_default_values); |
316 | } |
317 | |
318 | status_t get(int arg, int *mask) const; |
319 | |
320 | status_t set(int arg, int mask); |
321 | status_t set(int arg) { return set(arg, 0); } |
322 | |
323 | private: |
324 | bool is_set_src = false, is_set_wei = false, is_set_dst = false; |
325 | int mask_src = 0, mask_wei = 0, mask_dst = 0; |
326 | |
327 | int get_mask(int arg) const { |
328 | int mask = 0; |
329 | switch (arg) { |
330 | case DNNL_ARG_SRC: mask = mask_src; break; |
331 | case DNNL_ARG_WEIGHTS: mask = mask_wei; break; |
332 | case DNNL_ARG_DST: mask = mask_dst; break; |
333 | default: mask = 0; |
334 | } |
335 | return mask; |
336 | } |
337 | |
338 | bool is_set(int arg) const { |
339 | bool arg_is_set = false; |
340 | switch (arg) { |
341 | case DNNL_ARG_SRC: arg_is_set = is_set_src; break; |
342 | case DNNL_ARG_WEIGHTS: arg_is_set = is_set_wei; break; |
343 | case DNNL_ARG_DST: arg_is_set = is_set_dst; break; |
344 | default: arg_is_set = 0; |
345 | } |
346 | return arg_is_set; |
347 | } |
348 | |
349 | bool check_all(bool (zero_points_t::*f)(int) const) const { |
350 | for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) |
351 | if (!(this->*f)(arg)) return false; |
352 | return true; |
353 | } |
354 | }; |
355 | |
356 | struct serialization_stream_t; |
357 | |
// Interface for engine-specific attribute payloads (e.g. GPU attributes)
// stored inside dnnl_primitive_attr. Implementations must support deep
// copying, equality, hashing, and serialization so attributes remain
// usable as primitive-cache keys.
struct primitive_attr_item_t {
    // Returns a deep copy of the item.
    virtual std::unique_ptr<primitive_attr_item_t> clone() const = 0;
    // True when the item carries no user-visible settings.
    virtual bool has_default_values() const = 0;
    // Deep equality; used by dnnl_primitive_attr::operator==.
    virtual bool is_equal(const primitive_attr_item_t &other) const = 0;
    virtual size_t get_hash() const = 0;
    virtual void serialize(serialization_stream_t &stream) const = 0;
    virtual ~primitive_attr_item_t() = default;
};
366 | |
367 | } // namespace impl |
368 | } // namespace dnnl |
369 | |
// Ordered list of operations fused after a primitive's main computation
// (sum, eltwise, depthwise convolution, binary, prelu). Exposed through the
// C API as dnnl_post_ops_t.
struct dnnl_post_ops : public dnnl::impl::c_compatible {
    // One post-op. The payload lives in a union keyed by `kind`; all union
    // members are trivially copyable, which set() below relies on.
    struct entry_t {
        entry_t() : kind(dnnl::impl::primitive_kind::undefined) {}
        entry_t(const entry_t &other) { copy_from(other); }

        dnnl::impl::status_t copy_from(const entry_t &other) {
            return set(other);
        }

        // TODO: This operator has to be deleted, and its usage has to be
        // replaced with copy_from() or copy/move constructors in order to
        // extract a status.
        entry_t &operator=(const entry_t &other) {
            DNNL_SHORT_CIRCUIT_SELF_ASSIGN(other);
            set(other);
            return *this;
        }

        struct eltwise_t {
            dnnl::impl::alg_kind_t alg;
            float scale, alpha, beta;
        };

        struct depthwise_conv_t {
            dnnl::impl::dim_t kernel;
            dnnl::impl::dim_t stride;
            dnnl::impl::dim_t padding;
            dnnl::impl::data_type_t wei_dt;
            dnnl::impl::data_type_t bias_dt;
            dnnl::impl::data_type_t dst_dt;
        };

        struct binary_t {
            dnnl::impl::alg_kind_t alg;
            // This is an unmodifiable user copy of attributes which is used in
            // caching mechanism. Not to be used internally.
            dnnl::impl::memory_desc_t user_src1_desc;
            // This is a modifiable copy of memory desc. It changes format kind
            // and tag of md in case user passed format_kind::any. To be used
            // everywhere internally.
            dnnl::impl::memory_desc_t src1_desc;
        };

        struct prelu_t {
            int mask;
        };

        // Discriminant for the union below.
        dnnl::impl::primitive_kind_t kind
                = dnnl::impl::primitive_kind::undefined;
        union {
            struct {
                float scale;
                int32_t zero_point;
                dnnl::impl::data_type_t dt;
            } sum;
            eltwise_t eltwise;
            depthwise_conv_t depthwise_conv;
            binary_t binary;
            prelu_t prelu;
        };

        bool is_eltwise(bool require_scale_one = false) const {
            using namespace dnnl::impl;
            return kind == primitive_kind::eltwise
                    && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
        }

        bool is_relu(bool require_scale_one = true,
                bool require_nslope_zero = true) const {
            using namespace dnnl::impl;
            return is_eltwise(require_scale_one)
                    && eltwise.alg == alg_kind::eltwise_relu
                    && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
        }

        bool is_sum(bool require_scale_one = true,
                bool require_zp_zero = true) const {
            using namespace dnnl::impl;
            return kind == primitive_kind::sum
                    && IMPLICATION(require_scale_one, sum.scale == 1.f)
                    && IMPLICATION(require_zp_zero, sum.zero_point == 0);
        }

        bool is_convolution() const {
            using namespace dnnl::impl;
            return kind == primitive_kind::convolution;
        }

        bool is_binary() const {
            return kind == dnnl::impl::primitive_kind::binary;
        }

        bool is_prelu() const {
            return kind == dnnl::impl::primitive_kind::prelu;
        }

        dnnl::impl::status_t set_depthwise_scales(const float *scales);

        // Compares the payload selected by `kind`. Note that for binary
        // post-ops only the user-provided descriptor participates, so
        // entries stay comparable for cache purposes regardless of any
        // internal format resolution applied to src1_desc.
        bool operator==(const entry_t &rhs) const {
            using namespace dnnl::impl;
            using namespace dnnl::impl::utils;
            if (kind != rhs.kind) { return false; }
            bool ret = true;
            switch (kind) {
                case primitive_kind::eltwise:
                    ret = eltwise.alg == rhs.eltwise.alg
                            && equal_with_nan(eltwise.scale, rhs.eltwise.scale)
                            && equal_with_nan(eltwise.alpha, rhs.eltwise.alpha)
                            && equal_with_nan(eltwise.beta, rhs.eltwise.beta);
                    break;
                case primitive_kind::sum:
                    ret = equal_with_nan(sum.scale, rhs.sum.scale)
                            && sum.zero_point == rhs.sum.zero_point
                            && sum.dt == rhs.sum.dt;
                    break;
                case primitive_kind::convolution:
                    // Depthwise Only
                    ret = depthwise_conv.kernel == rhs.depthwise_conv.kernel
                            && depthwise_conv.stride
                                    == rhs.depthwise_conv.stride
                            && depthwise_conv.padding
                                    == rhs.depthwise_conv.padding
                            && depthwise_conv.wei_dt
                                    == rhs.depthwise_conv.wei_dt
                            && depthwise_conv.bias_dt
                                    == rhs.depthwise_conv.bias_dt
                            && depthwise_conv.dst_dt
                                    == rhs.depthwise_conv.dst_dt;
                    break;
                case primitive_kind::binary:
                    ret = binary.alg == rhs.binary.alg
                            && binary.user_src1_desc
                                    == rhs.binary.user_src1_desc;
                    break;
                case primitive_kind::prelu:
                    ret = prelu.mask == rhs.prelu.mask;
                    break;
                default: assert(!"unsupported post_op" );
            }
            return ret;
        }

        bool operator!=(const entry_t &rhs) const {
            return !this->operator==(rhs);
        }

    private:
        dnnl::impl::status_t set(const entry_t &other) {
            // Copying by if (is_convolution()) {} else if(is_sum()) {}
            // else if(is_relu()) {} seems to be unreliable. memcpying for now.
            dnnl::impl::utils::array_copy(
                    (char *)this, (char *)&other, sizeof(*this));
            return dnnl::impl::status::success;
        }
    };

    dnnl_post_ops() : entry_() {}

    dnnl_post_ops(const dnnl_post_ops &other) {
        // On copy failure the object is flagged, not thrown: callers must
        // check is_initialized().
        if (copy_from(other) != dnnl::impl::status::success)
            is_initialized_ = false;
    }

    dnnl::impl::status_t append_sum(float scale, int32_t zero_point = 0,
            dnnl::impl::data_type_t dt = dnnl_data_type_undef);
    dnnl::impl::status_t append_eltwise(
            float scale, dnnl::impl::alg_kind_t alg, float alpha, float beta);
    dnnl::impl::status_t append_dw(dnnl::impl::data_type_t wei_dt,
            dnnl::impl::data_type_t bias_dt, dnnl::impl::data_type_t dst_dt,
            dnnl::impl::dim_t kernel_size, dnnl::impl::dim_t stride_size,
            dnnl::impl::dim_t padding_l_size);
    dnnl::impl::status_t append_binary(dnnl::impl::alg_kind_t alg,
            const dnnl::impl::memory_desc_t *user_src1_desc);
    dnnl::impl::status_t append_prelu(int mask);

    dnnl::impl::status_t prepend_binary(dnnl::impl::alg_kind_t alg,
            const dnnl::impl::memory_desc_t *user_src1_desc);

    // Index of the first entry of `kind` in [start, stop), or -1.
    // stop == -1 means "search to the end".
    int find(dnnl::impl::primitive_kind_t kind, int start = 0,
            int stop = -1) const {
        if (stop == -1) stop = len();
        stop = dnnl::impl::nstl::min(stop, len());
        for (int idx = start; idx < stop; ++idx)
            if (entry_[idx].kind == kind) return idx;
        return -1;
    }

    // Data type of the first sum post-op; falls back to dst_dt when there
    // is no sum post-op or its data type is undefined.
    dnnl::impl::data_type_t get_sum_dt(
            const dnnl::impl::data_type_t dst_dt) const {
        const int sum_ind = find(dnnl::impl::primitive_kind::sum);
        if (sum_ind == -1) return dst_dt;
        const auto sum_dt = entry_[sum_ind].sum.dt;
        if (sum_dt != dnnl::impl::data_type::undef) return sum_dt;
        return dst_dt;
    }

    bool defined() const;
    int len() const { return (int)entry_.size(); }
    bool has_default_values() const { return len() == 0; }

    dnnl::impl::status_t set_default_formats(
            const dnnl::impl::memory_desc_t *dst_md);

    bool check_sum_consistent_dt(const dnnl::impl::data_type_t dst_dt,
            const bool diverse_sum_dt_allowed = false) const;

    // True when the first sum post-op (if any) either has no explicit data
    // type or matches dst_dt.
    bool sum_with_default_dt(
            dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const {
        int sum_ind = find(dnnl::impl::primitive_kind::sum);
        return sum_ind == -1 || entry_[sum_ind].sum.dt == dnnl_data_type_undef
                || entry_[sum_ind].sum.dt == dst_dt;
    }

    // True when the entry at `index` exists and has the given kind.
    bool contain(dnnl::impl::primitive_kind_t kind, int index) const {
        return find(kind, index, index + 1) == index;
    }

    bool operator==(const dnnl_post_ops &rhs) const {
        bool ret = len() == rhs.len();
        for (int i = 0; i < len(); ++i)
            ret = ret && entry_[i] == rhs.entry_[i];
        return ret;
    }

    dnnl::impl::status_t copy_from(const dnnl_post_ops &other) {
        using namespace dnnl::impl;

        // Reuses entries that already compare equal; extra trailing entries
        // in *this (when len() > other.len()) are left untouched.
        for (int idx = 0; idx < other.len(); ++idx) {
            if (len() > idx) {
                if (entry_[idx] == other.entry_[idx]) continue;
            } else {
                entry_.emplace_back();
            }
            CHECK(entry_[idx].copy_from(other.entry_[idx]));
        }

        return status::success;
    }

    bool is_initialized() const { return is_initialized_; }

    std::vector<entry_t> entry_;

    // Since binary post op accepts no more than 32 memory arguments by
    // design, we limit the amount of post-ops to 32.
    static constexpr int post_ops_limit = 32;

private:
    dnnl::impl::status_t validate_binary(dnnl::impl::alg_kind_t alg,
            const dnnl::impl::memory_desc_t *user_src1_desc) const;
};
621 | |
// Primitive attributes: scales, zero points, post-ops, scratchpad and
// fp-math modes, RNN quantization parameters, and an optional
// engine-specific payload. Exposed through the C API as
// dnnl_primitive_attr_t.
struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
    dnnl_primitive_attr()
        : scratchpad_mode_(dnnl::impl::scratchpad_mode::library)
        , fpmath_mode_(dnnl::impl::get_fpmath_mode()) {}

    // Deep copy; caller owns the returned pointer.
    dnnl_primitive_attr *clone() const {
        return new dnnl_primitive_attr(*this);
    }

    dnnl_primitive_attr(const dnnl_primitive_attr &other) {
        // On copy failure the object is flagged, not thrown: callers must
        // check is_initialized().
        if (copy_from(other) != dnnl::impl::status::success)
            is_initialized_ = false;
    }

    dnnl::impl::status_t copy_from(const dnnl_primitive_attr &other) {
        using namespace dnnl::impl;

        output_scales_ = other.output_scales_;
        scales_ = other.scales_;
        zero_points_ = other.zero_points_;
        scratchpad_mode_ = other.scratchpad_mode_;
        fpmath_mode_ = other.fpmath_mode_;
        CHECK(post_ops_.copy_from(other.post_ops_));
        rnn_data_qparams_ = other.rnn_data_qparams_;
        CHECK(rnn_weights_qparams_.copy_from(other.rnn_weights_qparams_));
        CHECK(rnn_weights_projection_qparams_.copy_from(
                other.rnn_weights_projection_qparams_));
        CHECK(rnn_tparams_.copy_from(other.rnn_tparams_));
        // NOTE(review): when other.gpu_attr_ is null, an existing gpu_attr_
        // on *this is kept rather than cleared. The only caller in this file
        // is the copy constructor (fresh object), so this looks benign here,
        // but confirm before reusing copy_from() on a non-fresh object.
        if (other.gpu_attr_) gpu_attr_ = other.gpu_attr_->clone();

        return status::success;
    }

    bool is_initialized() const { return is_initialized_; }

    // Bit flags telling has_default_values()/defined() which attribute
    // categories to ignore during the check.
    enum class skip_mask_t : unsigned {
        none = 0,
        oscale = 1u << 0,
        oscale_runtime = 1u << 1,
        scales = 1u << 2,
        scales_runtime = (unsigned)scales | (1u << 3),
        zero_points = 1u << 4,
        zero_points_runtime = (unsigned)zero_points | (1u << 5),
        post_ops = 1u << 6,
        rnn_data_qparams = 1u << 7,
        rnn_weights_qparams = 1u << 8,
        rnn_tparams = 1u << 9,
        sum_dt = 1u << 10,
        rnn_weights_projection_qparams = 1u << 11,
        gpu_attr = 1u << 12
    };

    /** Returns true if the attributes have default values.
     *
     * @note The scratchpad_mode_ is not take into account */
    bool has_default_values(skip_mask_t mask = skip_mask_t::none,
            dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const;

    /** Returns true if the attributes are fully defined. */
    bool defined(skip_mask_t mask = skip_mask_t::none) const;

    bool operator==(const dnnl_primitive_attr &rhs) const {
        // gpu_attr_ pointers are compared by deep equality; two attrs are
        // equal when both have no gpu_attr_ or both have equal ones.
        bool ret = scratchpad_mode_ == rhs.scratchpad_mode_
                && fpmath_mode_ == rhs.fpmath_mode_
                && output_scales_ == rhs.output_scales_
                && scales_ == rhs.scales_ && zero_points_ == rhs.zero_points_
                && post_ops_ == rhs.post_ops_
                && rnn_data_qparams_ == rhs.rnn_data_qparams_
                && rnn_weights_qparams_ == rhs.rnn_weights_qparams_
                && rnn_weights_projection_qparams_
                        == rhs.rnn_weights_projection_qparams_
                && rnn_tparams_ == rhs.rnn_tparams_
                && ((gpu_attr_ && rhs.gpu_attr_
                            && gpu_attr_->is_equal(*rhs.gpu_attr_))
                        || (!gpu_attr_ && !rhs.gpu_attr_));
        return ret;
    }

    dnnl::impl::status_t set_fpmath_mode(dnnl::impl::fpmath_mode_t fpmath_mode);
    dnnl::impl::status_t set_scratchpad_mode(
            dnnl::impl::scratchpad_mode_t scratchpad_mode);
    dnnl::impl::status_t set_post_ops(const dnnl::impl::post_ops_t &post_ops);
    dnnl::impl::status_t set_gpu_attr(
            const dnnl::impl::primitive_attr_item_t &gpu_attr);
    dnnl::impl::status_t set_default_formats(
            const dnnl::impl::memory_desc_t *dst_md);

    /* Auxiliary functions */
    // True when the fpmath mode permits implicitly down-converting dt_from
    // to dt_to (dt_to must also be an fp subtype of dt_from).
    bool mayidownconvert(dnnl::impl::data_type_t dt_from,
            dnnl::impl::data_type_t dt_to) const {
        using namespace dnnl::impl;

        bool is_compat = is_fpsubtype(dt_to, dt_from);
        auto can_downconvert = [&]() {
            switch (fpmath_mode_) {
                case fpmath_mode::strict: return dt_from == dt_to;
                case fpmath_mode::any: return true;
                case fpmath_mode::bf16:
                    return is_fpsubtype(data_type::bf16, dt_to);
                case fpmath_mode::f16:
                    return is_fpsubtype(data_type::f16, dt_to);
                case fpmath_mode::tf32:
                    return is_fpsubtype(data_type::tf32, dt_to);
                default: return false;
            }
        };
        return is_compat && can_downconvert();
    }

    // NOTE: make sure that the types below have overloaded comparison operator
    dnnl::impl::runtime_scales_t output_scales_;
    dnnl::impl::arg_scales_t scales_;
    dnnl::impl::zero_points_t zero_points_;
    dnnl::impl::scratchpad_mode_t scratchpad_mode_;
    dnnl::impl::fpmath_mode_t fpmath_mode_;
    dnnl::impl::post_ops_t post_ops_;
    dnnl::impl::rnn_data_qparams_t rnn_data_qparams_;
    dnnl::impl::scales_t rnn_weights_qparams_;
    dnnl::impl::scales_t rnn_weights_projection_qparams_;
    dnnl::impl::rnn_tparams_t rnn_tparams_;

    std::unique_ptr<dnnl::impl::primitive_attr_item_t> gpu_attr_;

    dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete;
};
747 | |
748 | inline dnnl_primitive_attr::skip_mask_t operator|( |
749 | dnnl_primitive_attr::skip_mask_t lhs, |
750 | dnnl_primitive_attr::skip_mask_t rhs) { |
751 | return static_cast<dnnl_primitive_attr::skip_mask_t>( |
752 | static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); |
753 | } |
754 | inline dnnl_primitive_attr::skip_mask_t operator&( |
755 | dnnl_primitive_attr::skip_mask_t lhs, |
756 | dnnl_primitive_attr::skip_mask_t rhs) { |
757 | return static_cast<dnnl_primitive_attr::skip_mask_t>( |
758 | static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); |
759 | } |
760 | inline dnnl_primitive_attr::skip_mask_t &operator|=( |
761 | dnnl_primitive_attr::skip_mask_t &lhs, |
762 | dnnl_primitive_attr::skip_mask_t rhs) { |
763 | lhs = lhs | rhs; |
764 | return lhs; |
765 | } |
766 | inline dnnl_primitive_attr::skip_mask_t &operator&=( |
767 | dnnl_primitive_attr::skip_mask_t &lhs, |
768 | dnnl_primitive_attr::skip_mask_t rhs) { |
769 | lhs = lhs & rhs; |
770 | return lhs; |
771 | } |
772 | inline bool operator!=(dnnl_primitive_attr::skip_mask_t lhs, |
773 | dnnl_primitive_attr::skip_mask_t rhs) { |
774 | return (static_cast<unsigned>(lhs) != static_cast<unsigned>(rhs)); |
775 | } |
776 | inline dnnl_primitive_attr::skip_mask_t operator~( |
777 | dnnl_primitive_attr::skip_mask_t rhs) { |
778 | return static_cast<dnnl_primitive_attr::skip_mask_t>( |
779 | ~static_cast<unsigned>(rhs)); |
780 | } |
781 | |
782 | #endif |
783 | |