1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_IMPL_LIST_ITEM_HPP |
18 | #define COMMON_IMPL_LIST_ITEM_HPP |
19 | |
20 | #include "c_types_map.hpp" |
21 | #include "primitive_desc.hpp" |
22 | #include "utils.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | |
27 | // This key takes prop_kind and correspondent data_type for src, wei and dst. |
28 | struct pk_dt_impl_key_t { |
29 | prop_kind_t kind; |
30 | data_type_t src_dt, wei_dt, dst_dt; |
31 | |
32 | bool operator<(const pk_dt_impl_key_t &rhs) const { |
33 | return value() < rhs.value(); |
34 | } |
35 | |
36 | private: |
37 | enum { MAX_DT_NUM = 10 }; |
38 | size_t value() const { |
39 | return (((size_t)kind * MAX_DT_NUM + (size_t)src_dt) * MAX_DT_NUM |
40 | + (size_t)wei_dt) |
41 | * MAX_DT_NUM |
42 | + (size_t)dst_dt; |
43 | } |
44 | }; |
45 | |
46 | // This is a simpler version of key to use only prop_kind. |
47 | struct pk_impl_key_t { |
48 | prop_kind_t kind; |
49 | |
50 | bool operator<(const pk_impl_key_t &rhs) const { |
51 | return value() < rhs.value(); |
52 | } |
53 | |
54 | private: |
55 | size_t value() const { return (size_t)kind; } |
56 | }; |
57 | |
58 | struct impl_list_item_t { |
59 | constexpr impl_list_item_t() = default; |
60 | constexpr impl_list_item_t(const impl_list_item_t &other) = default; |
61 | constexpr impl_list_item_t(impl_list_item_t &&other) = default; |
62 | impl_list_item_t &operator=(const impl_list_item_t &other) = default; |
63 | impl_list_item_t &operator=(impl_list_item_t &&other) = default; |
64 | |
65 | constexpr impl_list_item_t(std::nullptr_t) {} |
66 | |
67 | template <typename pd_t> |
68 | struct type_deduction_helper_t { |
69 | using type = pd_t; |
70 | constexpr type_deduction_helper_t() { |
71 | static_assert(std::is_base_of<primitive_desc_t, pd_t>::value, |
72 | "type_deduction_helper_t is expected to be used for " |
73 | "primitive descriptor classes only." ); |
74 | } |
75 | }; |
76 | |
77 | template <typename pd_t> |
78 | struct concat_type_deduction_helper_t |
79 | : public type_deduction_helper_t<pd_t> { |
80 | constexpr concat_type_deduction_helper_t() = default; |
81 | }; |
82 | |
83 | template <typename pd_t> |
84 | struct sum_type_deduction_helper_t : public type_deduction_helper_t<pd_t> { |
85 | }; |
86 | |
87 | template <typename pd_t> |
88 | struct reorder_type_deduction_helper_t |
89 | : public type_deduction_helper_t<pd_t> {}; |
90 | |
91 | template <typename pd_t> |
92 | constexpr impl_list_item_t(type_deduction_helper_t<pd_t>) |
93 | : create_pd_func_(&primitive_desc_t::create< |
94 | typename type_deduction_helper_t<pd_t>::type>) {} |
95 | |
96 | template <typename pd_t> |
97 | constexpr impl_list_item_t(concat_type_deduction_helper_t<pd_t>) |
98 | : create_concat_pd_func_( |
99 | concat_type_deduction_helper_t<pd_t>::type::create) {} |
100 | |
101 | template <typename pd_t> |
102 | constexpr impl_list_item_t(sum_type_deduction_helper_t<pd_t>) |
103 | : create_sum_pd_func_(sum_type_deduction_helper_t<pd_t>::type::create) { |
104 | } |
105 | |
106 | template <typename pd_t> |
107 | constexpr impl_list_item_t(reorder_type_deduction_helper_t<pd_t>) |
108 | : create_reorder_pd_func_( |
109 | reorder_type_deduction_helper_t<pd_t>::type::create) {} |
110 | |
111 | explicit operator bool() const { |
112 | return !utils::everyone_is(nullptr, create_pd_func_, |
113 | create_concat_pd_func_, create_sum_pd_func_, |
114 | create_reorder_pd_func_); |
115 | } |
116 | |
117 | // Currently, this only supports iterator friendly primitives. Can be |
118 | // extended to sum, concat and reorder if needed. |
119 | template <typename pd_t> |
120 | static int find(const impl_list_item_t *list) { |
121 | int idx = 0; |
122 | for (const impl_list_item_t *cur = list; *cur; cur++) { |
123 | if (cur->create_pd_func_ == &primitive_desc_t::create<pd_t>) |
124 | return idx; |
125 | idx++; |
126 | } |
127 | return -1; |
128 | } |
129 | |
130 | private: |
131 | status_t operator()(primitive_desc_t **pd, const op_desc_t *adesc, |
132 | const primitive_attr_t *attr, engine_t *engine, |
133 | const primitive_desc_t *hint_fwd, int pd_iterator_offset) const { |
134 | assert(create_pd_func_); |
135 | if (!create_pd_func_) return status::runtime_error; |
136 | auto status = create_pd_func_(pd, adesc, attr, engine, hint_fwd); |
137 | if (status == status::success) |
138 | (*pd)->init_pd_iterator_offset(pd_iterator_offset); |
139 | return status; |
140 | } |
141 | |
142 | status_t operator()(concat_pd_t **concat_pd, engine_t *engine, |
143 | const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, |
144 | int concat_dim, const memory_desc_t *const *src_mds) const { |
145 | assert(create_concat_pd_func_); |
146 | if (!create_concat_pd_func_) return status::runtime_error; |
147 | return create_concat_pd_func_( |
148 | concat_pd, engine, attr, dst_md, n, concat_dim, src_mds); |
149 | } |
150 | |
151 | status_t operator()(sum_pd_t **sum_pd, engine_t *engine, |
152 | const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, |
153 | const float *scales, const memory_desc_t *const *src_mds) const { |
154 | assert(create_sum_pd_func_); |
155 | if (!create_sum_pd_func_) return status::runtime_error; |
156 | return create_sum_pd_func_( |
157 | sum_pd, engine, attr, dst_md, n, scales, src_mds); |
158 | } |
159 | |
160 | status_t operator()(reorder_pd_t **reorder_pd, engine_t *engine, |
161 | const primitive_attr_t *attr, engine_t *src_engine, |
162 | const memory_desc_t *src_md, engine_t *dst_engine, |
163 | const memory_desc_t *dst_md) const { |
164 | if (!create_reorder_pd_func_) return status::runtime_error; |
165 | return create_reorder_pd_func_(reorder_pd, engine, attr, src_engine, |
166 | src_md, dst_engine, dst_md); |
167 | } |
168 | |
169 | using create_pd_func_t = status_t (*)(primitive_desc_t **, |
170 | const op_desc_t *, const primitive_attr_t *, engine_t *, |
171 | const primitive_desc_t *); |
172 | |
173 | using create_concat_pd_func_t = status_t (*)(concat_pd_t **, engine_t *, |
174 | const primitive_attr_t *, const memory_desc_t *, int, int, |
175 | const memory_desc_t *const *); |
176 | |
177 | using create_sum_pd_func_t = status_t (*)(sum_pd_t **, engine_t *, |
178 | const primitive_attr_t *, const memory_desc_t *, int, const float *, |
179 | const memory_desc_t *const *); |
180 | |
181 | using create_reorder_pd_func_t = status_t (*)(reorder_pd_t **, engine_t *, |
182 | const primitive_attr_t *, engine_t *, const memory_desc_t *, |
183 | engine_t *, const memory_desc_t *); |
184 | |
185 | create_pd_func_t create_pd_func_ = nullptr; |
186 | create_concat_pd_func_t create_concat_pd_func_ = nullptr; |
187 | create_sum_pd_func_t create_sum_pd_func_ = nullptr; |
188 | create_reorder_pd_func_t create_reorder_pd_func_ = nullptr; |
189 | |
190 | // List of functions/classes that have permissions to create primitive |
191 | // descriptors. |
192 | friend struct primitive_desc_iterator_t; |
193 | friend status_t concat_primitive_desc_create( |
194 | std::shared_ptr<primitive_desc_t> &, engine_t *, |
195 | const memory_desc_t *, int, int, const memory_desc_t *const *, |
196 | const primitive_attr_t *); |
197 | friend status_t sum_primitive_desc_create(primitive_desc_iface_t **, |
198 | const memory_desc_t *, int, const float *, |
199 | const memory_desc_t *const *, const primitive_attr_t *, engine_t *); |
200 | friend status_t reorder_primitive_desc_create( |
201 | std::shared_ptr<primitive_desc_t> &, engine_t *, |
202 | const memory_desc_t *, engine_t *, const memory_desc_t *, |
203 | engine_t *, const primitive_attr_t *); |
204 | }; |
205 | |
206 | } // namespace impl |
207 | } // namespace dnnl |
208 | |
209 | #endif |
210 | |