1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "primitive_desc.hpp" |
18 | #include "type_helpers.hpp" |
19 | #include "utils.hpp" |
20 | |
21 | #include "dnnl_thread.hpp" |
22 | #include "engine.hpp" |
23 | #include "primitive_hashing.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace primitive_hashing { |
28 | |
// Builds a primitive cache key from its individual components.
//
// engine:             only engine->engine_id() is kept in the key.
// op_desc:            op_desc->kind becomes primitive_kind_; op_desc itself is
//                     stored as given (presumably a non-owning pointer, so it
//                     must outlive the key -- TODO confirm against the key_t
//                     declaration).
// attr:               stored as given (same ownership caveat as op_desc).
// pd_iterator_offset: stored as given.
// hint_mds:           used to initialize hint_mds_ (presumably a by-value
//                     copy).
//
// In addition, the key snapshots dnnl_get_max_threads() at construction time
// into impl_nthr_ and records the id of the constructing thread in thread_id_.
key_t::key_t(const engine_t *engine, const op_desc_t *op_desc,
        const primitive_attr_t *attr, int pd_iterator_offset,
        const std::vector<memory_desc_t> &hint_mds)
    : primitive_kind_(op_desc->kind)
    , op_desc_(op_desc)
    , attr_(attr)
    , pd_iterator_offset_(pd_iterator_offset)
    , impl_nthr_(dnnl_get_max_threads())
    , hint_mds_(hint_mds)
    , engine_id_(engine->engine_id())
    , thread_id_(std::this_thread::get_id()) {}
40 | |
// Builds a key for an existing primitive descriptor by delegating to the
// component constructor with pd's op descriptor, attributes, iterator offset
// and the hint memory descriptors returned by pd->hint_mds(false).
key_t::key_t(const primitive_desc_t *pd, const engine_t *engine)
    : key_t(engine, pd->op_desc(), pd->attr(), pd->pd_iterator_offset(),
            pd->hint_mds(false /* is_hint */)) {}
44 | |
45 | bool key_t::operator==(const key_t &rhs) const { |
46 | DNNL_SHORT_CIRCUIT_SELF_COMPARISON(rhs); |
47 | // clang-format off |
48 | bool ret = true |
49 | // Less expensive comparisons come first |
50 | && primitive_kind_ == rhs.primitive_kind_ |
51 | && engine_id_ == rhs.engine_id_ |
52 | && hint_mds_.size() == rhs.hint_mds_.size() |
53 | && pd_iterator_offset_ == rhs.pd_iterator_offset_ |
54 | && impl_nthr_ == rhs.impl_nthr_ |
55 | && (*attr_) == (*rhs.attr_); |
56 | |
57 | if (!ret) return false; |
58 | |
59 | #define CASE(pkind) \ |
60 | case primitive_kind::pkind: \ |
61 | ret = cast_to_desc<pkind##_desc_t>(op_desc_) \ |
62 | == cast_to_desc<pkind##_desc_t>(rhs.op_desc_); \ |
63 | break; |
64 | |
65 | switch ((int)primitive_kind_) { |
66 | CASE(batch_normalization) |
67 | CASE(binary) |
68 | CASE(concat) |
69 | CASE(convolution) |
70 | CASE(deconvolution) |
71 | CASE(eltwise) |
72 | CASE(gemm) |
73 | CASE(inner_product) |
74 | CASE(layer_normalization) |
75 | CASE(lrn) |
76 | CASE(matmul) |
77 | CASE(pooling) |
78 | CASE(prelu) |
79 | CASE(reduction) |
80 | CASE(reorder) |
81 | CASE(resampling) |
82 | CASE(rnn) |
83 | CASE(shuffle) |
84 | CASE(softmax) |
85 | CASE(sum) |
86 | CASE(zero_pad) |
87 | default: assert(!"unknown primitive kind" ); |
88 | } |
89 | #undef CASE |
90 | // clang-format on |
91 | |
92 | if (!ret) return false; |
93 | |
94 | for (size_t i = 0; i < hint_mds_.size(); ++i) |
95 | if (hint_mds_[i] != rhs.hint_mds_[i]) return false; |
96 | |
97 | return true; |
98 | } |
99 | |
100 | // Combine hash of each memory_desc_t data member |
101 | size_t get_md_hash(const memory_desc_t &md) { |
102 | size_t seed = 0; |
103 | seed = get_array_hash(seed, md.dims, md.ndims); |
104 | seed = hash_combine(seed, static_cast<size_t>(md.data_type)); |
105 | seed = get_array_hash(seed, md.padded_dims, md.ndims); |
106 | seed = get_array_hash(seed, md.padded_offsets, md.ndims); |
107 | seed = hash_combine(seed, md.offset0); |
108 | seed = hash_combine(seed, static_cast<size_t>(md.format_kind)); |
109 | // format desc |
110 | switch ((int)md.format_kind) { |
111 | case format_kind::undef: |
112 | case format_kind::any: break; |
113 | case format_kind::blocked: |
114 | for (int i = 0; i < md.ndims; i++) { |
115 | if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue; |
116 | seed = hash_combine(seed, md.format_desc.blocking.strides[i]); |
117 | } |
118 | seed = hash_combine(seed, md.format_desc.blocking.inner_nblks); |
119 | seed = get_array_hash(seed, md.format_desc.blocking.inner_blks, |
120 | md.format_desc.blocking.inner_nblks); |
121 | seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs, |
122 | md.format_desc.blocking.inner_nblks); |
123 | break; |
124 | case format_kind::wino: |
125 | seed = hash_combine(seed, |
126 | static_cast<size_t>(md.format_desc.wino_desc.wino_format)); |
127 | seed = hash_combine(seed, md.format_desc.wino_desc.r); |
128 | seed = hash_combine(seed, md.format_desc.wino_desc.alpha); |
129 | seed = hash_combine(seed, md.format_desc.wino_desc.ic); |
130 | seed = hash_combine(seed, md.format_desc.wino_desc.oc); |
131 | seed = hash_combine(seed, md.format_desc.wino_desc.ic_block); |
132 | seed = hash_combine(seed, md.format_desc.wino_desc.oc_block); |
133 | seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block); |
134 | seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block); |
135 | seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale); |
136 | seed = hash_combine(seed, md.format_desc.wino_desc.size); |
137 | break; |
138 | case format_kind::rnn_packed: |
139 | seed = hash_combine(seed, |
140 | static_cast<size_t>(md.format_desc.rnn_packed_desc.format)); |
141 | seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts); |
142 | seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n); |
143 | seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb); |
144 | { |
145 | int n_parts = md.format_desc.rnn_packed_desc.n_parts; |
146 | seed = get_array_hash( |
147 | seed, md.format_desc.rnn_packed_desc.parts, n_parts); |
148 | seed = get_array_hash(seed, |
149 | md.format_desc.rnn_packed_desc.part_pack_size, n_parts); |
150 | seed = get_array_hash(seed, |
151 | md.format_desc.rnn_packed_desc.pack_part, n_parts); |
152 | } |
153 | seed = hash_combine( |
154 | seed, md.format_desc.rnn_packed_desc.offset_compensation); |
155 | seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size); |
156 | break; |
157 | default: assert(!"unknown format_kind" ); |
158 | } |
159 | |
160 | if (md.extra.flags != dnnl_memory_extra_flag_none) { |
161 | seed = hash_combine(seed, md.extra.flags); |
162 | if ((md.extra.flags |
163 | & (dnnl_memory_extra_flag_compensation_conv_s8s8 |
164 | | dnnl_memory_extra_flag_rnn_u8s8_compensation)) |
165 | && !types::extra_flag_rnn_s8s8_compensation_is_set( |
166 | md.extra.flags)) { |
167 | seed = hash_combine(seed, md.extra.compensation_mask); |
168 | } |
169 | |
170 | if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { |
171 | seed = hash_combine(seed, md.extra.scale_adjust); |
172 | } |
173 | |
174 | if (md.extra.flags |
175 | & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { |
176 | seed = hash_combine(seed, md.extra.asymm_compensation_mask); |
177 | } |
178 | } |
179 | // Combined hash for a memory descriptor |
180 | return seed; |
181 | } |
182 | |
183 | // Combine hash of each primitive_attr_t data member |
184 | size_t get_attr_hash(const primitive_attr_t &attr) { |
185 | size_t seed = 0; |
186 | // scratchpad_mode |
187 | seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_)); |
188 | // fpmath_mode |
189 | seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_mode_)); |
190 | |
191 | if (!attr.output_scales_.has_default_values()) { |
192 | // output_scales: mask |
193 | seed = hash_combine(seed, attr.output_scales_.mask_); |
194 | } else if (!attr.scales_.has_default_values()) { |
195 | // go through scales for all arguments |
196 | for (const auto &p : attr.scales_.scales_) { |
197 | // scales: arg |
198 | seed = hash_combine(seed, p.first); |
199 | // scales: mask |
200 | seed = hash_combine(seed, p.second.mask_); |
201 | } |
202 | } |
203 | // zero_points |
204 | for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) |
205 | if (!attr.zero_points_.has_default_values(arg)) { |
206 | // zero_points: arg |
207 | seed = hash_combine(seed, arg); |
208 | int mask = 0; |
209 | attr.zero_points_.get(arg, &mask); |
210 | // zero_points: mask |
211 | seed = hash_combine(seed, mask); |
212 | } |
213 | // post_ops: entry[:] |
214 | for (int i = 0; i < attr.post_ops_.len(); i++) { |
215 | const auto &entry = attr.post_ops_.entry_[i]; |
216 | switch (entry.kind) { |
217 | case primitive_kind::eltwise: |
218 | seed = hash_combine( |
219 | seed, static_cast<size_t>(entry.eltwise.alg)); |
220 | seed = hash_combine(seed, entry.eltwise.scale); |
221 | seed = hash_combine(seed, entry.eltwise.alpha); |
222 | seed = hash_combine(seed, entry.eltwise.beta); |
223 | break; |
224 | case primitive_kind::sum: |
225 | seed = hash_combine(seed, entry.sum.scale); |
226 | seed = hash_combine(seed, entry.sum.zero_point); |
227 | seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt)); |
228 | break; |
229 | case primitive_kind::convolution: |
230 | seed = hash_combine( |
231 | seed, static_cast<size_t>(entry.depthwise_conv.kernel)); |
232 | seed = hash_combine( |
233 | seed, static_cast<size_t>(entry.depthwise_conv.stride)); |
234 | seed = hash_combine(seed, |
235 | static_cast<size_t>(entry.depthwise_conv.padding)); |
236 | seed = hash_combine( |
237 | seed, static_cast<size_t>(entry.depthwise_conv.wei_dt)); |
238 | seed = hash_combine(seed, |
239 | static_cast<size_t>(entry.depthwise_conv.bias_dt)); |
240 | seed = hash_combine( |
241 | seed, static_cast<size_t>(entry.depthwise_conv.dst_dt)); |
242 | break; |
243 | case primitive_kind::binary: |
244 | seed = hash_combine( |
245 | seed, static_cast<size_t>(entry.binary.alg)); |
246 | seed = hash_combine( |
247 | seed, get_md_hash(entry.binary.user_src1_desc)); |
248 | break; |
249 | case primitive_kind::prelu: |
250 | seed = hash_combine( |
251 | seed, static_cast<size_t>(entry.prelu.mask)); |
252 | break; |
253 | default: assert(!"unknown post_op" ); |
254 | } |
255 | } |
256 | // rnn_data_qparams: scale, shift |
257 | seed = hash_combine(seed, attr.rnn_data_qparams_.scale_); |
258 | seed = hash_combine(seed, attr.rnn_data_qparams_.shift_); |
259 | if (!attr.rnn_weights_qparams_.has_default_values()) { |
260 | // rnn_weights_qparams: mask |
261 | seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_); |
262 | // rnn_weights_qparams: count |
263 | seed = hash_combine(seed, attr.rnn_weights_qparams_.count_); |
264 | // rnn_weights_qparams: scales[:] |
265 | seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_, |
266 | attr.rnn_weights_qparams_.count_); |
267 | } |
268 | if (attr.gpu_attr_) { |
269 | seed = hash_combine(seed, attr.gpu_attr_->get_hash()); |
270 | } |
271 | // Combined hash for attributes |
272 | return seed; |
273 | } |
274 | |
275 | // Functions that compute hash for different op_descs |
276 | size_t get_desc_hash(const concat_desc_t &desc) { |
277 | size_t seed = 0; |
278 | // Kinds |
279 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
280 | // Memory descriptors |
281 | seed = hash_combine(seed, get_md_hash(*desc.dst_md)); |
282 | // N |
283 | seed = hash_combine(seed, desc.n); |
284 | // Concat dimension |
285 | seed = hash_combine(seed, desc.concat_dimension); |
286 | // Array of mds |
287 | seed = get_array_hash(seed, desc.src_mds); |
288 | // Combined hash for concat desc |
289 | return seed; |
290 | } |
291 | |
292 | size_t get_desc_hash(const batch_normalization_desc_t &desc) { |
293 | size_t seed = 0; |
294 | // Kinds |
295 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
296 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
297 | // Memory descriptors |
298 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
299 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
300 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
301 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
302 | seed = hash_combine(seed, get_md_hash(desc.scaleshift_desc)); |
303 | seed = hash_combine(seed, get_md_hash(desc.diff_scaleshift_desc)); |
304 | seed = hash_combine(seed, get_md_hash(desc.stat_desc)); |
305 | // Epsilon |
306 | seed = hash_combine(seed, desc.batch_norm_epsilon); |
307 | // Flags |
308 | seed = hash_combine(seed, desc.flags); |
309 | // Combined hash for batch normalization desc |
310 | return seed; |
311 | } |
312 | |
313 | size_t get_desc_hash(const binary_desc_t &desc) { |
314 | size_t seed = 0; |
315 | // Kinds |
316 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
317 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
318 | // Memory descriptors |
319 | seed = hash_combine(seed, get_md_hash(desc.src_desc[0])); |
320 | seed = hash_combine(seed, get_md_hash(desc.src_desc[1])); |
321 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
322 | // Combined hash for binary op desc |
323 | return seed; |
324 | } |
325 | |
326 | // (De-)Convolution |
327 | size_t get_desc_hash(const convolution_desc_t &desc) { |
328 | size_t seed = 0; |
329 | // Kinds |
330 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
331 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
332 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
333 | // Memory descriptors |
334 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
335 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
336 | seed = hash_combine(seed, get_md_hash(desc.weights_desc)); |
337 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc)); |
338 | seed = hash_combine(seed, get_md_hash(desc.bias_desc)); |
339 | seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc)); |
340 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
341 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
342 | // Strides, dilates, padding |
343 | seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS); |
344 | seed = get_array_hash(seed, desc.dilates, DNNL_MAX_NDIMS); |
345 | seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS); |
346 | seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS); |
347 | // Accumulator type |
348 | seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type)); |
349 | // Combined hash for (de-)convolution desc |
350 | return seed; |
351 | } |
352 | |
353 | // Eltwise |
354 | size_t get_desc_hash(const eltwise_desc_t &desc) { |
355 | size_t seed = 0; |
356 | // Kinds |
357 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
358 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
359 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
360 | // Memory descriptors |
361 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
362 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
363 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
364 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
365 | // Alpha, beta |
366 | seed = hash_combine(seed, desc.alpha); |
367 | seed = hash_combine(seed, desc.beta); |
368 | // Combined hash for eltwise desc |
369 | return seed; |
370 | } |
371 | |
372 | size_t get_desc_hash(const gemm_desc_t &desc) { |
373 | size_t seed = 0; |
374 | // Kinds |
375 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
376 | seed = hash_combine(seed, get_md_hash(desc.a_desc)); |
377 | seed = hash_combine(seed, get_md_hash(desc.b_desc)); |
378 | seed = hash_combine(seed, get_md_hash(desc.c_desc)); |
379 | seed = hash_combine(seed, get_md_hash(desc.bias_desc)); |
380 | // Accumulator type |
381 | seed = hash_combine(seed, static_cast<size_t>(desc.acc_type)); |
382 | seed = hash_combine(seed, static_cast<size_t>(desc.sum_ab)); |
383 | seed = hash_combine(seed, static_cast<size_t>(desc.sum_ab_type)); |
384 | // Combined hash for gemm desc |
385 | return seed; |
386 | } |
387 | |
388 | size_t get_desc_hash(const inner_product_desc_t &desc) { |
389 | size_t seed = 0; |
390 | // Kinds |
391 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
392 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
393 | // Memory descriptors |
394 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
395 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
396 | seed = hash_combine(seed, get_md_hash(desc.weights_desc)); |
397 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc)); |
398 | seed = hash_combine(seed, get_md_hash(desc.bias_desc)); |
399 | seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc)); |
400 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
401 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
402 | // Accumulator type |
403 | seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type)); |
404 | // Combined hash for inner_product desc |
405 | return seed; |
406 | } |
407 | |
408 | size_t get_desc_hash(const layer_normalization_desc_t &desc) { |
409 | size_t seed = 0; |
410 | // Kinds |
411 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
412 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
413 | // Memory descriptors |
414 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
415 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
416 | seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc)); |
417 | seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc)); |
418 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
419 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
420 | seed = hash_combine(seed, get_md_hash(desc.stat_desc)); |
421 | // Epsilon |
422 | seed = hash_combine(seed, desc.layer_norm_epsilon); |
423 | // Flags |
424 | seed = hash_combine(seed, desc.flags); |
425 | // Combined hash for layer_normalization desc |
426 | return seed; |
427 | } |
428 | |
429 | size_t get_desc_hash(const lrn_desc_t &desc) { |
430 | size_t seed = 0; |
431 | // Kinds |
432 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
433 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
434 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
435 | // Memory descriptors |
436 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
437 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
438 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
439 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
440 | // Local size |
441 | seed = hash_combine(seed, desc.local_size); |
442 | // Alpha, beta |
443 | seed = hash_combine(seed, desc.lrn_alpha); |
444 | seed = hash_combine(seed, desc.lrn_beta); |
445 | // k |
446 | seed = hash_combine(seed, desc.lrn_k); |
447 | // Combined hash for lrn desc |
448 | return seed; |
449 | } |
450 | |
451 | size_t get_desc_hash(const matmul_desc_t &desc) { |
452 | size_t seed = 0; |
453 | // Kinds |
454 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
455 | // Memory descriptors |
456 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
457 | seed = hash_combine(seed, get_md_hash(desc.weights_desc)); |
458 | seed = hash_combine(seed, get_md_hash(desc.bias_desc)); |
459 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
460 | // Accumulator type |
461 | seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type)); |
462 | // Combined hash for matmul op desc |
463 | return seed; |
464 | } |
465 | |
466 | size_t get_desc_hash(const pooling_desc_t &desc) { |
467 | size_t seed = 0; |
468 | // Kinds |
469 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
470 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
471 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
472 | // Memory descriptors |
473 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
474 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
475 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
476 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
477 | // Strides, dilates, padding |
478 | seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS); |
479 | seed = get_array_hash(seed, desc.kernel, DNNL_MAX_NDIMS); |
480 | seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS); |
481 | seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS); |
482 | seed = get_array_hash(seed, desc.dilation, DNNL_MAX_NDIMS); |
483 | // Accumulator type |
484 | seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type)); |
485 | // Combined hash for pooling desc |
486 | return seed; |
487 | } |
488 | |
489 | size_t get_desc_hash(const prelu_desc_t &desc) { |
490 | size_t seed = 0; |
491 | // Kinds |
492 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
493 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
494 | // Memory descriptors |
495 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
496 | seed = hash_combine(seed, get_md_hash(desc.weights_desc)); |
497 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
498 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
499 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc)); |
500 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
501 | // Combined hash for prelu desc |
502 | return seed; |
503 | } |
504 | |
505 | size_t get_desc_hash(const reduction_desc_t &desc) { |
506 | size_t seed = 0; |
507 | // Kinds |
508 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
509 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
510 | // Memory descriptors |
511 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
512 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
513 | // P, eps |
514 | seed = hash_combine(seed, desc.p); |
515 | seed = hash_combine(seed, desc.eps); |
516 | // Combined hash for reduction desc |
517 | return seed; |
518 | } |
519 | |
520 | size_t get_desc_hash(const reorder_desc_t &desc) { |
521 | size_t seed = 0; |
522 | // Kinds |
523 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
524 | // Memory descriptors |
525 | seed = hash_combine(seed, get_md_hash(*desc.src_md)); |
526 | seed = hash_combine(seed, get_md_hash(*desc.dst_md)); |
527 | // Kinds of source and destination engines |
528 | seed = hash_combine(seed, static_cast<size_t>(desc.src_engine_kind)); |
529 | seed = hash_combine(seed, static_cast<size_t>(desc.dst_engine_kind)); |
530 | seed = hash_combine(seed, desc.is_cross_engine); |
531 | // Combined hash for reorder desc |
532 | return seed; |
533 | } |
534 | |
535 | size_t get_desc_hash(const resampling_desc_t &desc) { |
536 | size_t seed = 0; |
537 | // Kinds |
538 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
539 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
540 | // Memory descriptors |
541 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
542 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
543 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
544 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
545 | // Factors |
546 | seed = get_array_hash(seed, desc.factors, DNNL_MAX_NDIMS); |
547 | // Combined hash for resampling op desc |
548 | return seed; |
549 | } |
550 | |
551 | size_t get_desc_hash(const rnn_desc_t &desc) { |
552 | size_t seed = 0; |
553 | // Kinds |
554 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
555 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
556 | seed = hash_combine(seed, static_cast<size_t>(desc.cell_kind)); |
557 | seed = hash_combine(seed, static_cast<size_t>(desc.direction)); |
558 | // Memory descriptors |
559 | seed = hash_combine(seed, get_md_hash(desc.src_layer_desc)); |
560 | seed = hash_combine(seed, get_md_hash(desc.src_iter_desc)); |
561 | seed = hash_combine(seed, get_md_hash(desc.src_iter_c_desc)); |
562 | seed = hash_combine(seed, get_md_hash(desc.weights_layer_desc)); |
563 | seed = hash_combine(seed, get_md_hash(desc.weights_iter_desc)); |
564 | seed = hash_combine(seed, get_md_hash(desc.bias_desc)); |
565 | seed = hash_combine(seed, get_md_hash(desc.dst_layer_desc)); |
566 | seed = hash_combine(seed, get_md_hash(desc.dst_iter_desc)); |
567 | seed = hash_combine(seed, get_md_hash(desc.dst_iter_c_desc)); |
568 | seed = hash_combine(seed, get_md_hash(desc.weights_peephole_desc)); |
569 | seed = hash_combine(seed, get_md_hash(desc.weights_projection_desc)); |
570 | seed = hash_combine(seed, get_md_hash(desc.diff_src_layer_desc)); |
571 | seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_desc)); |
572 | seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_c_desc)); |
573 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_layer_desc)); |
574 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_iter_desc)); |
575 | seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc)); |
576 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_layer_desc)); |
577 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_desc)); |
578 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_c_desc)); |
579 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_peephole_desc)); |
580 | seed = hash_combine(seed, get_md_hash(desc.diff_weights_projection_desc)); |
581 | // Flags |
582 | seed = hash_combine(seed, desc.flags); |
583 | // Activation kind |
584 | seed = hash_combine(seed, static_cast<size_t>(desc.activation_kind)); |
585 | // Alpha, beta |
586 | seed = hash_combine(seed, desc.alpha); |
587 | seed = hash_combine(seed, desc.beta); |
588 | // Combined hash for rnn desc |
589 | return seed; |
590 | } |
591 | |
592 | // Shuffle |
593 | size_t get_desc_hash(const shuffle_desc_t &desc) { |
594 | size_t seed = 0; |
595 | // Kinds |
596 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
597 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
598 | // Memory descriptors |
599 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
600 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
601 | // Axis |
602 | seed = hash_combine(seed, desc.axis); |
603 | // Groupe size |
604 | seed = hash_combine(seed, desc.group_size); |
605 | // Combined hash for shuffle desc |
606 | return seed; |
607 | } |
608 | |
609 | size_t get_desc_hash(const softmax_desc_t &desc) { |
610 | size_t seed = 0; |
611 | // Kinds |
612 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
613 | seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind)); |
614 | seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind)); |
615 | // Memory descriptors |
616 | seed = hash_combine(seed, get_md_hash(desc.src_desc)); |
617 | seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); |
618 | seed = hash_combine(seed, get_md_hash(desc.dst_desc)); |
619 | seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); |
620 | // Axis |
621 | seed = hash_combine(seed, desc.softmax_axis); |
622 | // Combined hash for softmax desc |
623 | return seed; |
624 | } |
625 | |
626 | size_t get_desc_hash(const sum_desc_t &desc) { |
627 | size_t seed = 0; |
628 | // Kinds |
629 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
630 | // Memory descriptors |
631 | seed = hash_combine(seed, get_md_hash(*desc.dst_md)); |
632 | // N |
633 | seed = hash_combine(seed, desc.n); |
634 | // Scales |
635 | if (desc.scales) { seed = get_array_hash(seed, desc.scales, desc.n); } |
636 | // Array of mds |
637 | seed = get_array_hash(seed, desc.src_mds); |
638 | // Combined hash for sum desc |
639 | return seed; |
640 | } |
641 | |
642 | size_t get_desc_hash(const zero_pad_desc_t &desc) { |
643 | size_t seed = 0; |
644 | // Kinds |
645 | seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind)); |
646 | return seed; |
647 | } |
648 | |
649 | } // namespace primitive_hashing |
650 | } // namespace impl |
651 | } // namespace dnnl |
652 | |