/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "primitive_desc.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "dnnl_thread.hpp"
#include "engine.hpp"
#include "primitive_hashing.hpp"

namespace dnnl {
namespace impl {
namespace primitive_hashing {

key_t::key_t(const engine_t *engine, const op_desc_t *op_desc,
        const primitive_attr_t *attr, int pd_iterator_offset,
        const std::vector<memory_desc_t> &hint_mds)
    : primitive_kind_(op_desc->kind)
    , op_desc_(op_desc)
    , attr_(attr)
    , pd_iterator_offset_(pd_iterator_offset)
    , impl_nthr_(dnnl_get_max_threads())
    , hint_mds_(hint_mds)
    , engine_id_(engine->engine_id())
    , thread_id_(std::this_thread::get_id()) {}

key_t::key_t(const primitive_desc_t *pd, const engine_t *engine)
    : key_t(engine, pd->op_desc(), pd->attr(), pd->pd_iterator_offset(),
            pd->hint_mds(false /* is_hint */)) {}
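
// Usage sketch (illustrative only; the cache plumbing lives outside this
// file): a key is expected to be built from a primitive descriptor and then
// compared against cached keys with operator== below, e.g.
//
//     key_t key(pd, engine);
//     if (key == cached_key) { /* the cached primitive can be reused */ }
//
// The comparison checks the cheap scalar fields first and inspects the
// op_desc and hint memory descriptors only when those match.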

bool key_t::operator==(const key_t &rhs) const {
    DNNL_SHORT_CIRCUIT_SELF_COMPARISON(rhs);
    // clang-format off
    bool ret = true
        // Less expensive comparisons come first
        && primitive_kind_ == rhs.primitive_kind_
        && engine_id_ == rhs.engine_id_
        && hint_mds_.size() == rhs.hint_mds_.size()
        && pd_iterator_offset_ == rhs.pd_iterator_offset_
        && impl_nthr_ == rhs.impl_nthr_
        && (*attr_) == (*rhs.attr_);

    if (!ret) return false;

#define CASE(pkind) \
    case primitive_kind::pkind: \
        ret = cast_to_desc<pkind##_desc_t>(op_desc_) \
                == cast_to_desc<pkind##_desc_t>(rhs.op_desc_); \
        break;

    switch ((int)primitive_kind_) {
        CASE(batch_normalization)
        CASE(binary)
        CASE(concat)
        CASE(convolution)
        CASE(deconvolution)
        CASE(eltwise)
        CASE(gemm)
        CASE(inner_product)
        CASE(layer_normalization)
        CASE(lrn)
        CASE(matmul)
        CASE(pooling)
        CASE(prelu)
        CASE(reduction)
        CASE(reorder)
        CASE(resampling)
        CASE(rnn)
        CASE(shuffle)
        CASE(softmax)
        CASE(sum)
        CASE(zero_pad)
        default: assert(!"unknown primitive kind");
    }
#undef CASE
    // clang-format on

    if (!ret) return false;

    for (size_t i = 0; i < hint_mds_.size(); ++i)
        if (hint_mds_[i] != rhs.hint_mds_[i]) return false;

    return true;
}

// Combine hash of each memory_desc_t data member
size_t get_md_hash(const memory_desc_t &md) {
    size_t seed = 0;
    seed = get_array_hash(seed, md.dims, md.ndims);
    seed = hash_combine(seed, static_cast<size_t>(md.data_type));
    seed = get_array_hash(seed, md.padded_dims, md.ndims);
    seed = get_array_hash(seed, md.padded_offsets, md.ndims);
    seed = hash_combine(seed, md.offset0);
    seed = hash_combine(seed, static_cast<size_t>(md.format_kind));
    // format desc
    switch ((int)md.format_kind) {
        case format_kind::undef:
        case format_kind::any: break;
        case format_kind::blocked:
            for (int i = 0; i < md.ndims; i++) {
                if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
                seed = hash_combine(seed, md.format_desc.blocking.strides[i]);
            }
            seed = hash_combine(seed, md.format_desc.blocking.inner_nblks);
            seed = get_array_hash(seed, md.format_desc.blocking.inner_blks,
                    md.format_desc.blocking.inner_nblks);
            seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs,
                    md.format_desc.blocking.inner_nblks);
            break;
        case format_kind::wino:
            seed = hash_combine(seed,
                    static_cast<size_t>(md.format_desc.wino_desc.wino_format));
            seed = hash_combine(seed, md.format_desc.wino_desc.r);
            seed = hash_combine(seed, md.format_desc.wino_desc.alpha);
            seed = hash_combine(seed, md.format_desc.wino_desc.ic);
            seed = hash_combine(seed, md.format_desc.wino_desc.oc);
            seed = hash_combine(seed, md.format_desc.wino_desc.ic_block);
            seed = hash_combine(seed, md.format_desc.wino_desc.oc_block);
            seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block);
            seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block);
            seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
            seed = hash_combine(seed, md.format_desc.wino_desc.size);
            break;
        case format_kind::rnn_packed:
            seed = hash_combine(seed,
                    static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
            seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts);
            seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
            seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
            {
                int n_parts = md.format_desc.rnn_packed_desc.n_parts;
                seed = get_array_hash(
                        seed, md.format_desc.rnn_packed_desc.parts, n_parts);
                seed = get_array_hash(seed,
                        md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
                seed = get_array_hash(seed,
                        md.format_desc.rnn_packed_desc.pack_part, n_parts);
            }
            seed = hash_combine(
                    seed, md.format_desc.rnn_packed_desc.offset_compensation);
            seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
            break;
        default: assert(!"unknown format_kind");
    }

    if (md.extra.flags != dnnl_memory_extra_flag_none) {
        seed = hash_combine(seed, md.extra.flags);
        if ((md.extra.flags
                    & (dnnl_memory_extra_flag_compensation_conv_s8s8
                            | dnnl_memory_extra_flag_rnn_u8s8_compensation))
                && !types::extra_flag_rnn_s8s8_compensation_is_set(
                        md.extra.flags)) {
            seed = hash_combine(seed, md.extra.compensation_mask);
        }

        if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
            seed = hash_combine(seed, md.extra.scale_adjust);
        }

        if (md.extra.flags
                & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
            seed = hash_combine(seed, md.extra.asymm_compensation_mask);
        }
    }
    // Combined hash for a memory descriptor
    return seed;
}
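
// Illustrative note: hash_combine() is assumed to mix a single value into a
// running seed and get_array_hash() to apply it element-wise over an array
// (both helpers come from the common utilities, not from this file). Hashes
// therefore compose by chaining seeds, which is how the descriptor hashes
// below reuse get_md_hash(), e.g. (sketch only, "some_desc" is hypothetical):
//
//     size_t seed = 0;
//     seed = hash_combine(seed, get_md_hash(some_desc.src_desc));
//     seed = hash_combine(seed, get_md_hash(some_desc.dst_desc));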

// Combine hash of each primitive_attr_t data member
size_t get_attr_hash(const primitive_attr_t &attr) {
    size_t seed = 0;
    // scratchpad_mode
    seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
    // fpmath_mode
    seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_mode_));

    if (!attr.output_scales_.has_default_values()) {
        // output_scales: mask
        seed = hash_combine(seed, attr.output_scales_.mask_);
    } else if (!attr.scales_.has_default_values()) {
        // go through scales for all arguments
        for (const auto &p : attr.scales_.scales_) {
            // scales: arg
            seed = hash_combine(seed, p.first);
            // scales: mask
            seed = hash_combine(seed, p.second.mask_);
        }
    }
    // zero_points
    for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
        if (!attr.zero_points_.has_default_values(arg)) {
            // zero_points: arg
            seed = hash_combine(seed, arg);
            int mask = 0;
            attr.zero_points_.get(arg, &mask);
            // zero_points: mask
            seed = hash_combine(seed, mask);
        }
    // post_ops: entry[:]
    for (int i = 0; i < attr.post_ops_.len(); i++) {
        const auto &entry = attr.post_ops_.entry_[i];
        switch (entry.kind) {
            case primitive_kind::eltwise:
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.eltwise.alg));
                seed = hash_combine(seed, entry.eltwise.scale);
                seed = hash_combine(seed, entry.eltwise.alpha);
                seed = hash_combine(seed, entry.eltwise.beta);
                break;
            case primitive_kind::sum:
                seed = hash_combine(seed, entry.sum.scale);
                seed = hash_combine(seed, entry.sum.zero_point);
                seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt));
                break;
            case primitive_kind::convolution:
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.depthwise_conv.kernel));
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.depthwise_conv.stride));
                seed = hash_combine(seed,
                        static_cast<size_t>(entry.depthwise_conv.padding));
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.depthwise_conv.wei_dt));
                seed = hash_combine(seed,
                        static_cast<size_t>(entry.depthwise_conv.bias_dt));
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.depthwise_conv.dst_dt));
                break;
            case primitive_kind::binary:
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.binary.alg));
                seed = hash_combine(
                        seed, get_md_hash(entry.binary.user_src1_desc));
                break;
            case primitive_kind::prelu:
                seed = hash_combine(
                        seed, static_cast<size_t>(entry.prelu.mask));
                break;
            default: assert(!"unknown post_op");
        }
    }
    // rnn_data_qparams: scale, shift
    seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
    seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
    if (!attr.rnn_weights_qparams_.has_default_values()) {
        // rnn_weights_qparams: mask
        seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_);
        // rnn_weights_qparams: count
        seed = hash_combine(seed, attr.rnn_weights_qparams_.count_);
        // rnn_weights_qparams: scales[:]
        seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_,
                attr.rnn_weights_qparams_.count_);
    }
    if (attr.gpu_attr_) {
        seed = hash_combine(seed, attr.gpu_attr_->get_hash());
    }
    // Combined hash for attributes
    return seed;
}
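
// Orientation note (an assumption about the surrounding design, not asserted
// by this file): get_attr_hash() and the get_desc_hash() overloads below are
// expected to be folded into the key_t hash declared in primitive_hashing.hpp.
// For such a scheme to stay consistent, keys that compare equal via
// key_t::operator==() above must also hash identically, so these helpers only
// touch state that the equality check covers.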

// Functions that compute hash for different op_descs
size_t get_desc_hash(const concat_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(*desc.dst_md));
    // N
    seed = hash_combine(seed, desc.n);
    // Concat dimension
    seed = hash_combine(seed, desc.concat_dimension);
    // Array of mds
    seed = get_array_hash(seed, desc.src_mds);
    // Combined hash for concat desc
    return seed;
}

size_t get_desc_hash(const batch_normalization_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.scaleshift_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_scaleshift_desc));
    seed = hash_combine(seed, get_md_hash(desc.stat_desc));
    // Epsilon
    seed = hash_combine(seed, desc.batch_norm_epsilon);
    // Flags
    seed = hash_combine(seed, desc.flags);
    // Combined hash for batch normalization desc
    return seed;
}

size_t get_desc_hash(const binary_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc[0]));
    seed = hash_combine(seed, get_md_hash(desc.src_desc[1]));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    // Combined hash for binary op desc
    return seed;
}

// (De-)Convolution
size_t get_desc_hash(const convolution_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Strides, dilates, padding
    seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.dilates, DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
    // Accumulator type
    seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
    // Combined hash for (de-)convolution desc
    return seed;
}

// Eltwise
size_t get_desc_hash(const eltwise_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Alpha, beta
    seed = hash_combine(seed, desc.alpha);
    seed = hash_combine(seed, desc.beta);
    // Combined hash for eltwise desc
    return seed;
}

size_t get_desc_hash(const gemm_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, get_md_hash(desc.a_desc));
    seed = hash_combine(seed, get_md_hash(desc.b_desc));
    seed = hash_combine(seed, get_md_hash(desc.c_desc));
    seed = hash_combine(seed, get_md_hash(desc.bias_desc));
    // Accumulator type
    seed = hash_combine(seed, static_cast<size_t>(desc.acc_type));
    seed = hash_combine(seed, static_cast<size_t>(desc.sum_ab));
    seed = hash_combine(seed, static_cast<size_t>(desc.sum_ab_type));
    // Combined hash for gemm desc
    return seed;
}

size_t get_desc_hash(const inner_product_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Accumulator type
    seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
    // Combined hash for inner_product desc
    return seed;
}

size_t get_desc_hash(const layer_normalization_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.stat_desc));
    // Epsilon
    seed = hash_combine(seed, desc.layer_norm_epsilon);
    // Flags
    seed = hash_combine(seed, desc.flags);
    // Combined hash for layer_normalization desc
    return seed;
}

size_t get_desc_hash(const lrn_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Local size
    seed = hash_combine(seed, desc.local_size);
    // Alpha, beta
    seed = hash_combine(seed, desc.lrn_alpha);
    seed = hash_combine(seed, desc.lrn_beta);
    // k
    seed = hash_combine(seed, desc.lrn_k);
    // Combined hash for lrn desc
    return seed;
}

size_t get_desc_hash(const matmul_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    // Accumulator type
    seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
    // Combined hash for matmul op desc
    return seed;
}

size_t get_desc_hash(const pooling_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Strides, kernel, padding, dilation
    seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.kernel, DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
    seed = get_array_hash(seed, desc.dilation, DNNL_MAX_NDIMS);
    // Accumulator type
    seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
    // Combined hash for pooling desc
    return seed;
}

size_t get_desc_hash(const prelu_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Combined hash for prelu desc
    return seed;
}

size_t get_desc_hash(const reduction_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    // P, eps
    seed = hash_combine(seed, desc.p);
    seed = hash_combine(seed, desc.eps);
    // Combined hash for reduction desc
    return seed;
}

size_t get_desc_hash(const reorder_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(*desc.src_md));
    seed = hash_combine(seed, get_md_hash(*desc.dst_md));
    // Kinds of source and destination engines
    seed = hash_combine(seed, static_cast<size_t>(desc.src_engine_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.dst_engine_kind));
    seed = hash_combine(seed, desc.is_cross_engine);
    // Combined hash for reorder desc
    return seed;
}

size_t get_desc_hash(const resampling_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Factors
    seed = get_array_hash(seed, desc.factors, DNNL_MAX_NDIMS);
    // Combined hash for resampling op desc
    return seed;
}

size_t get_desc_hash(const rnn_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.cell_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.direction));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.src_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.src_iter_c_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_iter_c_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_peephole_desc));
    seed = hash_combine(seed, get_md_hash(desc.weights_projection_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_c_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_layer_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_c_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_peephole_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_weights_projection_desc));
    // Flags
    seed = hash_combine(seed, desc.flags);
    // Activation kind
    seed = hash_combine(seed, static_cast<size_t>(desc.activation_kind));
    // Alpha, beta
    seed = hash_combine(seed, desc.alpha);
    seed = hash_combine(seed, desc.beta);
    // Combined hash for rnn desc
    return seed;
}

// Shuffle
size_t get_desc_hash(const shuffle_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    // Axis
    seed = hash_combine(seed, desc.axis);
    // Group size
    seed = hash_combine(seed, desc.group_size);
    // Combined hash for shuffle desc
    return seed;
}

size_t get_desc_hash(const softmax_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
    seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(desc.src_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
    seed = hash_combine(seed, get_md_hash(desc.dst_desc));
    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
    // Axis
    seed = hash_combine(seed, desc.softmax_axis);
    // Combined hash for softmax desc
    return seed;
}

size_t get_desc_hash(const sum_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    // Memory descriptors
    seed = hash_combine(seed, get_md_hash(*desc.dst_md));
    // N
    seed = hash_combine(seed, desc.n);
    // Scales
    if (desc.scales) { seed = get_array_hash(seed, desc.scales, desc.n); }
    // Array of mds
    seed = get_array_hash(seed, desc.src_mds);
    // Combined hash for sum desc
    return seed;
}

size_t get_desc_hash(const zero_pad_desc_t &desc) {
    size_t seed = 0;
    // Kinds
    seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
    return seed;
}
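
// Pattern note: each get_desc_hash() overload above starts from a zero seed,
// mixes in the primitive/prop/alg kinds, then the relevant memory descriptor
// hashes and scalar parameters. A hypothetical new primitive kind would
// follow the same template (illustrative sketch only; "foo" is not a real
// oneDNN primitive):
//
//     size_t get_desc_hash(const foo_desc_t &desc) {
//         size_t seed = 0;
//         seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
//         seed = hash_combine(seed, get_md_hash(desc.src_desc));
//         seed = hash_combine(seed, get_md_hash(desc.dst_desc));
//         return seed;
//     }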

} // namespace primitive_hashing
} // namespace impl
} // namespace dnnl