1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_IMPL_LIST_ITEM_HPP
18#define COMMON_IMPL_LIST_ITEM_HPP
19
#include <tuple>

#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"
23
24namespace dnnl {
25namespace impl {
26
// This key combines a prop_kind with the corresponding data types of src, wei
// and dst.
28struct pk_dt_impl_key_t {
29 prop_kind_t kind;
30 data_type_t src_dt, wei_dt, dst_dt;
31
32 bool operator<(const pk_dt_impl_key_t &rhs) const {
33 return value() < rhs.value();
34 }
35
36private:
37 enum { MAX_DT_NUM = 10 };
38 size_t value() const {
39 return (((size_t)kind * MAX_DT_NUM + (size_t)src_dt) * MAX_DT_NUM
40 + (size_t)wei_dt)
41 * MAX_DT_NUM
42 + (size_t)dst_dt;
43 }
44};
45
// A simpler key that uses only the prop_kind.
47struct pk_impl_key_t {
48 prop_kind_t kind;
49
50 bool operator<(const pk_impl_key_t &rhs) const {
51 return value() < rhs.value();
52 }
53
54private:
55 size_t value() const { return (size_t)kind; }
56};
57
58struct impl_list_item_t {
59 constexpr impl_list_item_t() = default;
60 constexpr impl_list_item_t(const impl_list_item_t &other) = default;
61 constexpr impl_list_item_t(impl_list_item_t &&other) = default;
62 impl_list_item_t &operator=(const impl_list_item_t &other) = default;
63 impl_list_item_t &operator=(impl_list_item_t &&other) = default;
64
65 constexpr impl_list_item_t(std::nullptr_t) {}
66
67 template <typename pd_t>
68 struct type_deduction_helper_t {
69 using type = pd_t;
70 constexpr type_deduction_helper_t() {
71 static_assert(std::is_base_of<primitive_desc_t, pd_t>::value,
72 "type_deduction_helper_t is expected to be used for "
73 "primitive descriptor classes only.");
74 }
75 };
76
77 template <typename pd_t>
78 struct concat_type_deduction_helper_t
79 : public type_deduction_helper_t<pd_t> {
80 constexpr concat_type_deduction_helper_t() = default;
81 };
82
83 template <typename pd_t>
84 struct sum_type_deduction_helper_t : public type_deduction_helper_t<pd_t> {
85 };
86
87 template <typename pd_t>
88 struct reorder_type_deduction_helper_t
89 : public type_deduction_helper_t<pd_t> {};
90
91 template <typename pd_t>
92 constexpr impl_list_item_t(type_deduction_helper_t<pd_t>)
93 : create_pd_func_(&primitive_desc_t::create<
94 typename type_deduction_helper_t<pd_t>::type>) {}
95
96 template <typename pd_t>
97 constexpr impl_list_item_t(concat_type_deduction_helper_t<pd_t>)
98 : create_concat_pd_func_(
99 concat_type_deduction_helper_t<pd_t>::type::create) {}
100
101 template <typename pd_t>
102 constexpr impl_list_item_t(sum_type_deduction_helper_t<pd_t>)
103 : create_sum_pd_func_(sum_type_deduction_helper_t<pd_t>::type::create) {
104 }
105
106 template <typename pd_t>
107 constexpr impl_list_item_t(reorder_type_deduction_helper_t<pd_t>)
108 : create_reorder_pd_func_(
109 reorder_type_deduction_helper_t<pd_t>::type::create) {}
110
111 explicit operator bool() const {
112 return !utils::everyone_is(nullptr, create_pd_func_,
113 create_concat_pd_func_, create_sum_pd_func_,
114 create_reorder_pd_func_);
115 }
116
117 // Currently, this only supports iterator friendly primitives. Can be
118 // extended to sum, concat and reorder if needed.
119 template <typename pd_t>
120 static int find(const impl_list_item_t *list) {
121 int idx = 0;
122 for (const impl_list_item_t *cur = list; *cur; cur++) {
123 if (cur->create_pd_func_ == &primitive_desc_t::create<pd_t>)
124 return idx;
125 idx++;
126 }
127 return -1;
128 }
129
130private:
131 status_t operator()(primitive_desc_t **pd, const op_desc_t *adesc,
132 const primitive_attr_t *attr, engine_t *engine,
133 const primitive_desc_t *hint_fwd, int pd_iterator_offset) const {
134 assert(create_pd_func_);
135 if (!create_pd_func_) return status::runtime_error;
136 auto status = create_pd_func_(pd, adesc, attr, engine, hint_fwd);
137 if (status == status::success)
138 (*pd)->init_pd_iterator_offset(pd_iterator_offset);
139 return status;
140 }
141
142 status_t operator()(concat_pd_t **concat_pd, engine_t *engine,
143 const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
144 int concat_dim, const memory_desc_t *const *src_mds) const {
145 assert(create_concat_pd_func_);
146 if (!create_concat_pd_func_) return status::runtime_error;
147 return create_concat_pd_func_(
148 concat_pd, engine, attr, dst_md, n, concat_dim, src_mds);
149 }
150
151 status_t operator()(sum_pd_t **sum_pd, engine_t *engine,
152 const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
153 const float *scales, const memory_desc_t *const *src_mds) const {
154 assert(create_sum_pd_func_);
155 if (!create_sum_pd_func_) return status::runtime_error;
156 return create_sum_pd_func_(
157 sum_pd, engine, attr, dst_md, n, scales, src_mds);
158 }
159
160 status_t operator()(reorder_pd_t **reorder_pd, engine_t *engine,
161 const primitive_attr_t *attr, engine_t *src_engine,
162 const memory_desc_t *src_md, engine_t *dst_engine,
163 const memory_desc_t *dst_md) const {
164 if (!create_reorder_pd_func_) return status::runtime_error;
165 return create_reorder_pd_func_(reorder_pd, engine, attr, src_engine,
166 src_md, dst_engine, dst_md);
167 }
168
169 using create_pd_func_t = status_t (*)(primitive_desc_t **,
170 const op_desc_t *, const primitive_attr_t *, engine_t *,
171 const primitive_desc_t *);
172
173 using create_concat_pd_func_t = status_t (*)(concat_pd_t **, engine_t *,
174 const primitive_attr_t *, const memory_desc_t *, int, int,
175 const memory_desc_t *const *);
176
177 using create_sum_pd_func_t = status_t (*)(sum_pd_t **, engine_t *,
178 const primitive_attr_t *, const memory_desc_t *, int, const float *,
179 const memory_desc_t *const *);
180
181 using create_reorder_pd_func_t = status_t (*)(reorder_pd_t **, engine_t *,
182 const primitive_attr_t *, engine_t *, const memory_desc_t *,
183 engine_t *, const memory_desc_t *);
184
185 create_pd_func_t create_pd_func_ = nullptr;
186 create_concat_pd_func_t create_concat_pd_func_ = nullptr;
187 create_sum_pd_func_t create_sum_pd_func_ = nullptr;
188 create_reorder_pd_func_t create_reorder_pd_func_ = nullptr;
189
190 // List of functions/classes that have permissions to create primitive
191 // descriptors.
192 friend struct primitive_desc_iterator_t;
193 friend status_t concat_primitive_desc_create(
194 std::shared_ptr<primitive_desc_t> &, engine_t *,
195 const memory_desc_t *, int, int, const memory_desc_t *const *,
196 const primitive_attr_t *);
197 friend status_t sum_primitive_desc_create(primitive_desc_iface_t **,
198 const memory_desc_t *, int, const float *,
199 const memory_desc_t *const *, const primitive_attr_t *, engine_t *);
200 friend status_t reorder_primitive_desc_create(
201 std::shared_ptr<primitive_desc_t> &, engine_t *,
202 const memory_desc_t *, engine_t *, const memory_desc_t *,
203 engine_t *, const primitive_attr_t *);
204};
205
206} // namespace impl
207} // namespace dnnl
208
209#endif
210