1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_HASHING_HPP |
18 | #define COMMON_PRIMITIVE_HASHING_HPP |
19 | |
#include <cassert>
#include <cstddef>
#include <functional>
#include <thread>
#include <type_traits>
#include <typeindex>
#include <vector>

#include "c_types_map.hpp"
#include "engine_id.hpp"
#include "oneapi/dnnl/dnnl.h"
#include "primitive_attr.hpp"
#include "type_helpers.hpp"
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | |
// Forward declaration: key_t only refers to primitive_desc_t through a
// pointer (its second constructor), so the full definition is not needed here.
struct primitive_desc_t;
33 | namespace primitive_hashing { |
34 | |
// Key describing a primitive: primitive kind, op descriptor, attributes,
// pd iterator offset, implementation thread count, hint memory descriptors
// and engine id. Hashed by the std::hash<key_t> specialization at the end of
// this header; equality is defined out of line (operator==).
struct key_t {
    // Builds a key from its individual components.
    // NOTE(review): op_desc and attr are kept as raw pointers (op_desc_,
    // attr_), so the key presumably does not own them - confirm the lifetime
    // contract in the out-of-line definition.
    key_t(const engine_t *engine, const op_desc_t *op_desc,
            const primitive_attr_t *attr, int pd_iterator_offset,
            const std::vector<memory_desc_t> &hint_mds);

    // Builds a key from an existing primitive descriptor and engine
    // (presumably extracting the same components from pd - see definition).
    key_t(const primitive_desc_t *pd, const engine_t *engine);

    // Key equality; defined out of line.
    bool operator==(const key_t &other) const;
    // Id of the thread that created this key (not part of the key itself).
    const std::thread::id &thread_id() const { return thread_id_; }

    primitive_kind_t primitive_kind_;
    // Make these data fields mutable to be able to update them without removing
    // and adding a key (extract is available in C++17 only).
    mutable const op_desc_t *op_desc_;
    mutable const primitive_attr_t *attr_;
    int pd_iterator_offset_;
    // Not a constructor parameter; set by the out-of-line constructors
    // (TODO confirm where the thread count comes from).
    int impl_nthr_;
    std::vector<memory_desc_t> hint_mds_;
    engine_id_t engine_id_;

private:
    // Reinterprets an untyped descriptor pointer as a reference to desc_t.
    // The caller must guarantee p actually points to a desc_t.
    template <typename desc_t>
    static const desc_t &cast_to_desc(const void *p) {
        return *(reinterpret_cast<const desc_t *>(p));
    }

    // Maps a primitive kind to another primitive kind; the mapping is defined
    // out of line (TODO document which kinds are folded together).
    static primitive_kind_t get_pkind(primitive_kind_t pkind);

    // Thread ID is not used as part of the key, it's only used to get
    // information about what thread inserted the key and the corresponding
    // primitive to handle some multithreaded scenarios.
    std::thread::id thread_id_;
};
68 | |
69 | size_t get_md_hash(const memory_desc_t &md); |
70 | size_t get_attr_hash(const primitive_attr_t &attr); |
71 | size_t get_desc_hash(const concat_desc_t &desc); |
72 | size_t get_desc_hash(const batch_normalization_desc_t &desc); |
73 | size_t get_desc_hash(const binary_desc_t &desc); |
74 | size_t get_desc_hash(const convolution_desc_t &desc); |
75 | size_t get_desc_hash(const eltwise_desc_t &desc); |
76 | size_t get_desc_hash(const gemm_desc_t &desc); |
77 | size_t get_desc_hash(const inner_product_desc_t &desc); |
78 | size_t get_desc_hash(const layer_normalization_desc_t &desc); |
79 | size_t get_desc_hash(const lrn_desc_t &desc); |
80 | size_t get_desc_hash(const matmul_desc_t &desc); |
81 | size_t get_desc_hash(const pooling_desc_t &desc); |
82 | size_t get_desc_hash(const prelu_desc_t &desc); |
83 | size_t get_desc_hash(const reduction_desc_t &desc); |
84 | size_t get_desc_hash(const reorder_desc_t &desc); |
85 | size_t get_desc_hash(const resampling_desc_t &desc); |
86 | size_t get_desc_hash(const rnn_desc_t &desc); |
87 | size_t get_desc_hash(const shuffle_desc_t &desc); |
88 | size_t get_desc_hash(const softmax_desc_t &desc); |
89 | size_t get_desc_hash(const sum_desc_t &desc); |
90 | size_t get_desc_hash(const zero_pad_desc_t &desc); |
91 | |
// Folds the first `size` elements of `v` into `seed`, in order, via
// hash_combine and returns the resulting hash. A non-positive `size` hashes
// nothing and returns `seed` unchanged.
template <typename T>
size_t get_array_hash(size_t seed, const T *v, int size) {
    // Clamp so the end pointer never precedes v (v + 0 is valid even for null).
    const T *const end = v + (size > 0 ? size : 0);
    size_t acc = seed;
    for (const T *it = v; it != end; ++it)
        acc = hash_combine(acc, *it);
    return acc;
}
99 | |
100 | template <> |
101 | inline size_t get_array_hash<memory_desc_t>( |
102 | size_t seed, const memory_desc_t *v, int size) { |
103 | for (int i = 0; i < size; i++) { |
104 | seed = hash_combine(seed, get_md_hash(v[i])); |
105 | } |
106 | return seed; |
107 | } |
108 | |
109 | inline size_t get_array_hash( |
110 | size_t seed, const std::vector<const memory_desc_t *> &mds) { |
111 | for (const auto *md : mds) |
112 | seed = hash_combine(seed, get_md_hash(*md)); |
113 | return seed; |
114 | } |
115 | |
116 | } // namespace primitive_hashing |
117 | } // namespace impl |
118 | } // namespace dnnl |
119 | |
// Inject a specialization of std::hash for key_t into namespace std so that
// key_t can be used as a key of the standard unordered containers.
121 | namespace std { |
122 | template <> |
123 | struct hash<dnnl::impl::primitive_hashing::key_t> { |
124 | using argument_type = dnnl::impl::primitive_hashing::key_t; |
125 | using result_type = std::size_t; |
126 | result_type operator()(const argument_type &key) const { |
127 | using namespace dnnl::impl; |
128 | using namespace dnnl::impl::primitive_hashing; |
129 | size_t seed = 0; |
130 | // Compute hash for primitive_kind_, attr_, impl_id_ and impl_nthr_ |
131 | seed = hash_combine(seed, |
132 | hash_combine(0, static_cast<size_t>(key.primitive_kind_))); |
133 | seed = hash_combine(seed, get_attr_hash(*key.attr_)); |
134 | seed = hash_combine(seed, hash_combine(0, key.pd_iterator_offset_)); |
135 | seed = hash_combine(seed, hash_combine(0, key.impl_nthr_)); |
136 | |
137 | seed = hash_combine(seed, key.engine_id_.hash()); |
138 | // Combine hash for op_desc with the computed hash |
139 | #define CASE(pkind) \ |
140 | case primitive_kind::pkind: \ |
141 | seed = hash_combine( \ |
142 | seed, get_desc_hash(*(pkind##_desc_t *)key.op_desc_)); \ |
143 | break; |
144 | |
145 | // clang-format off |
146 | switch ((int)key.primitive_kind_) { |
147 | CASE(batch_normalization) |
148 | CASE(binary) |
149 | CASE(concat) |
150 | CASE(convolution) |
151 | CASE(deconvolution) |
152 | CASE(eltwise) |
153 | CASE(gemm) |
154 | CASE(inner_product) |
155 | CASE(layer_normalization) |
156 | CASE(lrn) |
157 | CASE(matmul) |
158 | CASE(pooling) |
159 | CASE(prelu) |
160 | CASE(reduction) |
161 | CASE(reorder) |
162 | CASE(resampling) |
163 | CASE(rnn) |
164 | CASE(shuffle) |
165 | CASE(softmax) |
166 | CASE(sum) |
167 | CASE(zero_pad) |
168 | default: assert(!"unknown primitive_kind" ); |
169 | } |
170 | // clang-format on |
171 | #undef CASE |
172 | seed = get_array_hash( |
173 | seed, key.hint_mds_.data(), (int)key.hint_mds_.size()); |
174 | |
175 | return seed; |
176 | } |
177 | }; |
178 | |
179 | } // namespace std |
180 | |
181 | #endif |
182 | |