1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_HASHING_HPP
18#define COMMON_PRIMITIVE_HASHING_HPP
19
#include <cassert>
#include <thread>
#include <typeindex>
#include <type_traits>
#include <vector>

#include "c_types_map.hpp"
#include "engine_id.hpp"
#include "oneapi/dnnl/dnnl.h"
#include "primitive_attr.hpp"
#include "type_helpers.hpp"
28
29namespace dnnl {
30namespace impl {
31
32struct primitive_desc_t;
33namespace primitive_hashing {
34
35struct key_t {
36 key_t(const engine_t *engine, const op_desc_t *op_desc,
37 const primitive_attr_t *attr, int pd_iterator_offset,
38 const std::vector<memory_desc_t> &hint_mds);
39
40 key_t(const primitive_desc_t *pd, const engine_t *engine);
41
42 bool operator==(const key_t &other) const;
43 const std::thread::id &thread_id() const { return thread_id_; }
44
45 primitive_kind_t primitive_kind_;
46 // Make these data fields mutable to be able to update them without removing
47 // and adding a key (extract is available in C++17 only).
48 mutable const op_desc_t *op_desc_;
49 mutable const primitive_attr_t *attr_;
50 int pd_iterator_offset_;
51 int impl_nthr_;
52 std::vector<memory_desc_t> hint_mds_;
53 engine_id_t engine_id_;
54
55private:
56 template <typename desc_t>
57 static const desc_t &cast_to_desc(const void *p) {
58 return *(reinterpret_cast<const desc_t *>(p));
59 }
60
61 static primitive_kind_t get_pkind(primitive_kind_t pkind);
62
63 // Thread ID is not used as part of the key, it's only used to get
64 // information about what thread inserted the key and the corresponding
65 // primitive to handle some multithreaded scenarios.
66 std::thread::id thread_id_;
67};
68
69size_t get_md_hash(const memory_desc_t &md);
70size_t get_attr_hash(const primitive_attr_t &attr);
71size_t get_desc_hash(const concat_desc_t &desc);
72size_t get_desc_hash(const batch_normalization_desc_t &desc);
73size_t get_desc_hash(const binary_desc_t &desc);
74size_t get_desc_hash(const convolution_desc_t &desc);
75size_t get_desc_hash(const eltwise_desc_t &desc);
76size_t get_desc_hash(const gemm_desc_t &desc);
77size_t get_desc_hash(const inner_product_desc_t &desc);
78size_t get_desc_hash(const layer_normalization_desc_t &desc);
79size_t get_desc_hash(const lrn_desc_t &desc);
80size_t get_desc_hash(const matmul_desc_t &desc);
81size_t get_desc_hash(const pooling_desc_t &desc);
82size_t get_desc_hash(const prelu_desc_t &desc);
83size_t get_desc_hash(const reduction_desc_t &desc);
84size_t get_desc_hash(const reorder_desc_t &desc);
85size_t get_desc_hash(const resampling_desc_t &desc);
86size_t get_desc_hash(const rnn_desc_t &desc);
87size_t get_desc_hash(const shuffle_desc_t &desc);
88size_t get_desc_hash(const softmax_desc_t &desc);
89size_t get_desc_hash(const sum_desc_t &desc);
90size_t get_desc_hash(const zero_pad_desc_t &desc);
91
// Folds v[0], ..., v[size - 1] into `seed`, in order, via hash_combine and
// returns the result. A non-positive `size` returns `seed` unchanged.
template <typename T>
size_t get_array_hash(size_t seed, const T *v, int size) {
    int idx = 0;
    while (idx < size) {
        seed = hash_combine(seed, v[idx]);
        ++idx;
    }
    return seed;
}
99
100template <>
101inline size_t get_array_hash<memory_desc_t>(
102 size_t seed, const memory_desc_t *v, int size) {
103 for (int i = 0; i < size; i++) {
104 seed = hash_combine(seed, get_md_hash(v[i]));
105 }
106 return seed;
107}
108
109inline size_t get_array_hash(
110 size_t seed, const std::vector<const memory_desc_t *> &mds) {
111 for (const auto *md : mds)
112 seed = hash_combine(seed, get_md_hash(*md));
113 return seed;
114}
115
116} // namespace primitive_hashing
117} // namespace impl
118} // namespace dnnl
119
120// inject a specialization of std::hash for key_t in std namespace
121namespace std {
122template <>
123struct hash<dnnl::impl::primitive_hashing::key_t> {
124 using argument_type = dnnl::impl::primitive_hashing::key_t;
125 using result_type = std::size_t;
126 result_type operator()(const argument_type &key) const {
127 using namespace dnnl::impl;
128 using namespace dnnl::impl::primitive_hashing;
129 size_t seed = 0;
130 // Compute hash for primitive_kind_, attr_, impl_id_ and impl_nthr_
131 seed = hash_combine(seed,
132 hash_combine(0, static_cast<size_t>(key.primitive_kind_)));
133 seed = hash_combine(seed, get_attr_hash(*key.attr_));
134 seed = hash_combine(seed, hash_combine(0, key.pd_iterator_offset_));
135 seed = hash_combine(seed, hash_combine(0, key.impl_nthr_));
136
137 seed = hash_combine(seed, key.engine_id_.hash());
138 // Combine hash for op_desc with the computed hash
139#define CASE(pkind) \
140 case primitive_kind::pkind: \
141 seed = hash_combine( \
142 seed, get_desc_hash(*(pkind##_desc_t *)key.op_desc_)); \
143 break;
144
145 // clang-format off
146 switch ((int)key.primitive_kind_) {
147 CASE(batch_normalization)
148 CASE(binary)
149 CASE(concat)
150 CASE(convolution)
151 CASE(deconvolution)
152 CASE(eltwise)
153 CASE(gemm)
154 CASE(inner_product)
155 CASE(layer_normalization)
156 CASE(lrn)
157 CASE(matmul)
158 CASE(pooling)
159 CASE(prelu)
160 CASE(reduction)
161 CASE(reorder)
162 CASE(resampling)
163 CASE(rnn)
164 CASE(shuffle)
165 CASE(softmax)
166 CASE(sum)
167 CASE(zero_pad)
168 default: assert(!"unknown primitive_kind");
169 }
170 // clang-format on
171#undef CASE
172 seed = get_array_hash(
173 seed, key.hint_mds_.data(), (int)key.hint_mds_.size());
174
175 return seed;
176 }
177};
178
179} // namespace std
180
181#endif
182