1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_DESC_HPP |
18 | #define COMMON_PRIMITIVE_DESC_HPP |
19 | |
20 | #include <typeindex> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "c_types_map.hpp" |
25 | #include "cache_blob.hpp" |
26 | #include "cache_blob_id.hpp" |
27 | #include "memory_tracking.hpp" |
28 | #include "nstl.hpp" |
29 | #include "opdesc.hpp" |
30 | #include "primitive_attr.hpp" |
31 | #include "primitive_cache.hpp" |
32 | #include "type_helpers.hpp" |
33 | #include "verbose.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | |
38 | static int po_inputs(const post_ops_t &post_ops, const primitive_kind_t kind) { |
39 | int n_inputs = 0; |
40 | for (int idx = 0; idx < post_ops.len(); ++idx) { |
41 | if (post_ops.contain(kind, idx)) n_inputs++; |
42 | } |
43 | return n_inputs; |
44 | } |
45 | |
46 | struct impl_list_item_t; |
47 | struct primitive_t; |
48 | // Primitive descriptor implementation |
49 | struct primitive_desc_t : public c_compatible { |
50 | primitive_desc_t(const primitive_attr_t *attr, primitive_kind_t kind) |
51 | : attr_(*attr), kind_(kind), pd_iterator_offset_(0) { |
52 | is_initialized_ = is_initialized_ && attr_.is_initialized(); |
53 | } |
54 | |
55 | primitive_desc_t(primitive_kind_t kind) : kind_(kind) {} |
56 | |
57 | bool is_initialized() const { return is_initialized_; } |
58 | |
59 | virtual ~primitive_desc_t() = default; |
60 | virtual primitive_desc_t *clone() const = 0; |
61 | |
62 | const primitive_attr_t *attr() const { return &attr_; } |
63 | primitive_kind_t kind() const { return kind_; } |
64 | |
65 | const char *info(engine_t *engine) const { |
66 | if (!info_.is_initialized()) info_.init(engine, this); |
67 | return info_.c_str(); |
68 | } |
69 | |
70 | memory_tracking::registry_t &scratchpad_registry() { |
71 | return scratchpad_registry_; |
72 | } |
73 | const memory_tracking::registry_t &scratchpad_registry() const { |
74 | return scratchpad_registry_; |
75 | } |
76 | |
77 | virtual const op_desc_t *op_desc() const { return nullptr; } |
78 | |
79 | const std::vector<uint8_t> &get_cache_blob_id(engine_t *engine) const { |
80 | return cache_blob_id_.get(engine, this); |
81 | } |
82 | |
83 | static bool post_op_has_proper_input(const primitive_attr_t *attr, |
84 | const primitive_kind_t prim, const int idx, const int arg, |
85 | const int src_mnemonic) { |
86 | return (attr->post_ops_.contain(prim, idx) |
87 | && arg == (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | src_mnemonic)); |
88 | } |
89 | |
90 | enum class arg_usage_t { unused, input, output }; |
91 | virtual arg_usage_t arg_usage(int arg) const { |
92 | using types::is_zero_md; |
93 | if (arg == DNNL_ARG_ATTR_OUTPUT_SCALES |
94 | && !attr()->output_scales_.defined()) |
95 | return arg_usage_t::input; |
96 | if (arg & DNNL_ARG_ATTR_ZERO_POINTS) { |
97 | int zp_arg = arg & ~DNNL_ARG_ATTR_ZERO_POINTS; |
98 | if (!attr()->zero_points_.defined(zp_arg)) |
99 | return arg_usage_t::input; |
100 | } |
101 | if (arg & DNNL_ARG_ATTR_SCALES) { |
102 | int scale_arg = arg & ~DNNL_ARG_ATTR_SCALES; |
103 | if (!attr()->scales_.get(scale_arg).defined()) |
104 | return arg_usage_t::input; |
105 | } |
106 | if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0)) |
107 | && !attr()->scales_.get(DNNL_ARG_SRC_0).defined()) |
108 | return arg_usage_t::input; |
109 | if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1)) |
110 | && !attr()->scales_.get(DNNL_ARG_SRC_1).defined()) |
111 | return arg_usage_t::input; |
112 | if (arg == DNNL_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) |
113 | return arg_usage_t::output; |
114 | for (int idx = 0; idx < attr()->post_ops_.len(); ++idx) { |
115 | using namespace primitive_kind; |
116 | if (post_op_has_proper_input( |
117 | attr(), binary, idx, arg, DNNL_ARG_SRC_1) |
118 | || post_op_has_proper_input( |
119 | attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS)) |
120 | return arg_usage_t::input; |
121 | } |
122 | |
123 | return arg_usage_t::unused; |
124 | } |
125 | |
126 | virtual const memory_desc_t *arg_md(int arg) const { |
127 | // Separate binary post-ops sections due to inability to express inside |
128 | // switch statement. |
129 | if (arg >= DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) |
130 | && arg < DNNL_ARG_ATTR_MULTIPLE_POST_OP( |
131 | post_ops_t::post_ops_limit)) { |
132 | const auto &po = attr()->post_ops_; |
133 | for (int idx = 0; idx < po.len(); ++idx) { |
134 | if (arg |
135 | != (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) |
136 | | DNNL_ARG_SRC_1)) |
137 | continue; |
138 | |
139 | return &po.entry_[idx].binary.src1_desc; |
140 | } |
141 | } |
142 | |
143 | switch (arg) { |
144 | case DNNL_ARG_WORKSPACE: return workspace_md(0); |
145 | case DNNL_ARG_SCRATCHPAD: return scratchpad_md(0); |
146 | default: return &glob_zero_md; |
147 | } |
148 | } |
149 | |
150 | #define DECLARE_MD_STUB(stub) \ |
151 | virtual const memory_desc_t *stub(int idx = 0) const { \ |
152 | return &glob_zero_md; \ |
153 | } |
154 | |
155 | DECLARE_MD_STUB(input_md); |
156 | DECLARE_MD_STUB(output_md); |
157 | DECLARE_MD_STUB(src_md); |
158 | DECLARE_MD_STUB(diff_src_md); |
159 | DECLARE_MD_STUB(dst_md); |
160 | DECLARE_MD_STUB(diff_dst_md); |
161 | DECLARE_MD_STUB(weights_md); |
162 | DECLARE_MD_STUB(diff_weights_md); |
163 | DECLARE_MD_STUB(workspace_md); |
164 | #undef DECLARE_MD_STUB |
165 | |
166 | const memory_desc_t *scratchpad_md(int idx = 0) const { |
167 | return idx == 0 ? &scratchpad_md_ : &glob_zero_md; |
168 | } |
169 | |
170 | void init_scratchpad_md() { |
171 | auto size = scratchpad_size(scratchpad_mode::user); |
172 | dims_t dims = {size}; |
173 | memory_desc_init_by_tag( |
174 | scratchpad_md_, size ? 1 : 0, dims, data_type::u8, dnnl_x); |
175 | } |
176 | |
177 | /** returns the scratchpad size for the given scratchpad mode. */ |
178 | dim_t scratchpad_size(scratchpad_mode_t mode) const { |
179 | if (mode != attr_.scratchpad_mode_) return 0; |
180 | return scratchpad_registry().size(); |
181 | } |
182 | |
183 | virtual status_t query(query_t what, int idx, void *result) const { |
184 | auto safe_ret_md = [&](const memory_desc_t *_) { |
185 | if (_ == nullptr) return status::not_required; |
186 | *(const memory_desc_t **)result = _; |
187 | return status::success; |
188 | }; |
189 | |
190 | switch (what) { |
191 | case query::primitive_kind: |
192 | *(primitive_kind_t *)result = kind(); |
193 | break; |
194 | |
195 | case query::memory_consumption_s64: |
196 | *(dim_t *)result = scratchpad_size(scratchpad_mode::library); |
197 | break; |
198 | |
199 | case query::exec_arg_md: return safe_ret_md(arg_md(idx)); |
200 | case query::src_md: return safe_ret_md(src_md(idx)); |
201 | case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); |
202 | case query::dst_md: return safe_ret_md(dst_md(idx)); |
203 | case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); |
204 | case query::weights_md: return safe_ret_md(weights_md(idx)); |
205 | case query::diff_weights_md: |
206 | return safe_ret_md(diff_weights_md(idx)); |
207 | case query::workspace_md: |
208 | if (idx != 0) return status::invalid_arguments; |
209 | return safe_ret_md(workspace_md(idx)); |
210 | case query::scratchpad_md: |
211 | if (idx != 0) return status::invalid_arguments; |
212 | return safe_ret_md(scratchpad_md(idx)); |
213 | |
214 | case query::num_of_inputs_s32: *(int *)result = n_inputs(); break; |
215 | case query::num_of_outputs_s32: *(int *)result = n_outputs(); break; |
216 | |
217 | case query::impl_info_str: *(const char **)result = name(); break; |
218 | |
219 | default: return status::unimplemented; |
220 | } |
221 | return status::success; |
222 | } |
223 | |
224 | virtual int n_inputs() const { return 0; } |
225 | virtual int n_outputs() const { return 0; } |
226 | int n_binary_po_inputs() const { |
227 | return po_inputs(attr()->post_ops_, primitive_kind::binary); |
228 | } |
229 | |
230 | int n_prelu_po_inputs() const { |
231 | return po_inputs(attr()->post_ops_, primitive_kind::prelu); |
232 | } |
233 | // The `hint_mds(bool is_hint)` returns a vector of memory descriptors |
234 | // that might affect the equality of primitive descriptors for backward pass. |
235 | // |
236 | // This function is used for creating a key to fetch primitive or primitive |
237 | // descriptor from cache. |
238 | // |
239 | // 1. When creating a primitive descriptor for backward pass there may be |
240 | // a forward primitive descriptor hint that can be used to obtain the |
241 | // memory descriptors. In this case the `is_hint` argument must be `true`. |
242 | // 2. When creating a primitive this function is called for a primitive |
243 | // descriptor that can be either forward or backward. In this case |
244 | // the `is_hint` argument must be `false`. |
245 | // - For forward it will return an empty vector. |
246 | // - For backward it will return a vector of memory descriptors if |
247 | // the implementation depends on a forward primitive descriptor. |
248 | // |
249 | // The current cases are: |
250 | // - pooling |
251 | // - shuffle |
252 | // |
253 | // Later the list of primitives can be extended. For instance, currently |
254 | // there is no convolution on the list because nthrs + op_desc |
255 | // (even with format=`any`) + attributes fully define a particular |
256 | // implementation. |
257 | virtual std::vector<memory_desc_t> hint_mds(bool is_hint) const { |
258 | UNUSED(is_hint); |
259 | return {}; |
260 | } |
261 | |
262 | virtual status_t create_primitive( |
263 | std::pair<std::shared_ptr<primitive_t>, bool> &primitive, |
264 | engine_t *engine, const cache_blob_t &cache_blob) const = 0; |
265 | |
266 | // This is a proxy interface that is used for creating nested primitives. |
267 | // It ignores the bool value that indicates whether the requested primitive |
268 | // was taken from cache. |
269 | status_t create_primitive(std::shared_ptr<primitive_t> &primitive, |
270 | engine_t *engine, |
271 | const cache_blob_t &cache_blob = cache_blob_t()) const { |
272 | std::pair<std::shared_ptr<primitive_t>, bool> p; |
273 | CHECK(create_primitive(p, engine, cache_blob)); |
274 | primitive = p.first; |
275 | return status::success; |
276 | } |
277 | |
278 | virtual const char *name() const = 0; |
279 | |
280 | int pd_iterator_offset() const { return pd_iterator_offset_; } |
281 | |
282 | protected: |
283 | primitive_attr_t attr_; |
284 | primitive_kind_t kind_; |
285 | int pd_iterator_offset_; |
286 | |
287 | memory_desc_t scratchpad_md_; |
288 | |
289 | mutable pd_info_t info_; |
290 | mutable cache_blob_id_t cache_blob_id_; |
291 | |
292 | memory_tracking::registry_t scratchpad_registry_; |
293 | |
294 | protected: |
295 | void init_pd_iterator_offset(int offset) { pd_iterator_offset_ = offset; } |
296 | |
297 | /** compares ws between fwd_pd and this (make sense to use for bwd_pd) |
298 | * Expectation: this already set workspace, and this workspace should |
299 | * exactly match the one from fwd_pd */ |
300 | bool compare_ws(const primitive_desc_t *fwd_pd) const { |
301 | if (!workspace_md()) return true; // the impl lives fine w/o workspace |
302 | return fwd_pd && fwd_pd->workspace_md() |
303 | && *fwd_pd->workspace_md() == *workspace_md(); |
304 | } |
305 | |
306 | primitive_desc_t &operator=(const primitive_desc_t &other) = delete; |
307 | |
308 | /* static magic */ |
309 | |
310 | template <typename pd_t> |
311 | static status_t create(primitive_desc_t **pd, const op_desc_t *adesc, |
312 | const primitive_attr_t *attr, engine_t *engine, |
313 | const primitive_desc_t *hint_fwd) { |
314 | using namespace dnnl::impl::status; |
315 | using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; |
316 | if (adesc->kind != pd_t::base_pkind) return invalid_arguments; |
317 | assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); |
318 | auto hint |
319 | = reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); |
320 | auto _pd = new pd_t((const pd_op_desc_t *)adesc, attr, hint); |
321 | if (_pd == nullptr) return out_of_memory; |
322 | if (!_pd->is_initialized()) { |
323 | delete _pd; |
324 | return out_of_memory; |
325 | } |
326 | if (_pd->init(engine) != success) { |
327 | delete _pd; |
328 | return unimplemented; |
329 | } |
330 | |
331 | _pd->init_scratchpad_md(); |
332 | *pd = _pd; |
333 | return success; |
334 | } |
335 | |
336 | friend struct dnnl::impl::impl_list_item_t; |
337 | }; |
338 | |
339 | } // namespace impl |
340 | } // namespace dnnl |
341 | |
// Boilerplate members for a concrete primitive descriptor class `pd_t`.
// It expands to member declarations (`override`s), so it must be placed
// inside the `pd_t` class body. It provides:
// - clone(): copy-constructs a new `pd_t`. Returns nullptr if the
//   allocation yielded null or the copy reports it is not initialized
//   (mirrors the null check in primitive_desc_t::create()).
// - create_primitive(): forwards to
//   primitive_t::create_primitive_common<impl_type, pd_t>, passing
//   `use_global_scratchpad` and the cache blob through.
// - name(): returns `impl_name`.
// - a friend declaration so primitive_desc_t::create() can construct `pd_t`
//   and call its init() even if those members are not public.
#define DECLARE_COMMON_PD_t(impl_name, impl_type, use_global_scratchpad) \
    pd_t *clone() const override { \
        auto new_pd = utils::make_unique<pd_t>(*this); \
        if (!new_pd || !new_pd->is_initialized()) return nullptr; \
        return new_pd.release(); \
    } \
    status_t create_primitive( \
            std::pair<std::shared_ptr<primitive_t>, bool> &primitive, \
            engine_t *engine, const cache_blob_t &cache_blob) const override { \
        return primitive_t::create_primitive_common<impl_type, pd_t>( \
                primitive, this, engine, use_global_scratchpad, cache_blob); \
    } \
    const char *name() const override { return impl_name; } \
    template <typename pd_t> \
    friend status_t primitive_desc_t::create(primitive_desc_t **pd, \
            const op_desc_t *adesc, const primitive_attr_t *attr, \
            engine_t *engine, const primitive_desc_t *hint_fwd);
359 | |
// Variant of DECLARE_COMMON_PD_t that passes `true` as the
// `use_global_scratchpad` argument.
#define DECLARE_COMMON_PD_T_USE_GLOBAL_SCRATCHPAD(impl_name, impl_type) \
    DECLARE_COMMON_PD_t(impl_name, impl_type, true)

// Variant of DECLARE_COMMON_PD_t that passes `false` as the
// `use_global_scratchpad` argument. DECLARE_COMMON_PD_T selects it when no
// optional third argument is given.
#define DECLARE_COMMON_PD_T_(impl_name, impl_type) \
    DECLARE_COMMON_PD_t(impl_name, impl_type, false)

// Entry point for concrete pd_t classes. The optional trailing argument is
// token-pasted onto `DECLARE_COMMON_PD_T_` to choose a variant:
//   DECLARE_COMMON_PD_T(name, type)
//       -> DECLARE_COMMON_PD_T_(name, type)
//   DECLARE_COMMON_PD_T(name, type, USE_GLOBAL_SCRATCHPAD)
//       -> DECLARE_COMMON_PD_T_USE_GLOBAL_SCRATCHPAD(name, type)
#define DECLARE_COMMON_PD_T(impl_name, impl_type, ...) \
    DECLARE_COMMON_PD_T_##__VA_ARGS__(impl_name, impl_type)
368 | |
369 | #endif |
370 | |
371 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
372 | |