1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_DESC_HPP
18#define COMMON_PRIMITIVE_DESC_HPP
19
20#include <typeindex>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "c_types_map.hpp"
25#include "cache_blob.hpp"
26#include "cache_blob_id.hpp"
27#include "memory_tracking.hpp"
28#include "nstl.hpp"
29#include "opdesc.hpp"
30#include "primitive_attr.hpp"
31#include "primitive_cache.hpp"
32#include "type_helpers.hpp"
33#include "verbose.hpp"
34
35namespace dnnl {
36namespace impl {
37
38static int po_inputs(const post_ops_t &post_ops, const primitive_kind_t kind) {
39 int n_inputs = 0;
40 for (int idx = 0; idx < post_ops.len(); ++idx) {
41 if (post_ops.contain(kind, idx)) n_inputs++;
42 }
43 return n_inputs;
44}
45
46struct impl_list_item_t;
47struct primitive_t;
48// Primitive descriptor implementation
49struct primitive_desc_t : public c_compatible {
50 primitive_desc_t(const primitive_attr_t *attr, primitive_kind_t kind)
51 : attr_(*attr), kind_(kind), pd_iterator_offset_(0) {
52 is_initialized_ = is_initialized_ && attr_.is_initialized();
53 }
54
55 primitive_desc_t(primitive_kind_t kind) : kind_(kind) {}
56
57 bool is_initialized() const { return is_initialized_; }
58
59 virtual ~primitive_desc_t() = default;
60 virtual primitive_desc_t *clone() const = 0;
61
62 const primitive_attr_t *attr() const { return &attr_; }
63 primitive_kind_t kind() const { return kind_; }
64
65 const char *info(engine_t *engine) const {
66 if (!info_.is_initialized()) info_.init(engine, this);
67 return info_.c_str();
68 }
69
70 memory_tracking::registry_t &scratchpad_registry() {
71 return scratchpad_registry_;
72 }
73 const memory_tracking::registry_t &scratchpad_registry() const {
74 return scratchpad_registry_;
75 }
76
77 virtual const op_desc_t *op_desc() const { return nullptr; }
78
79 const std::vector<uint8_t> &get_cache_blob_id(engine_t *engine) const {
80 return cache_blob_id_.get(engine, this);
81 }
82
83 static bool post_op_has_proper_input(const primitive_attr_t *attr,
84 const primitive_kind_t prim, const int idx, const int arg,
85 const int src_mnemonic) {
86 return (attr->post_ops_.contain(prim, idx)
87 && arg == (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | src_mnemonic));
88 }
89
90 enum class arg_usage_t { unused, input, output };
91 virtual arg_usage_t arg_usage(int arg) const {
92 using types::is_zero_md;
93 if (arg == DNNL_ARG_ATTR_OUTPUT_SCALES
94 && !attr()->output_scales_.defined())
95 return arg_usage_t::input;
96 if (arg & DNNL_ARG_ATTR_ZERO_POINTS) {
97 int zp_arg = arg & ~DNNL_ARG_ATTR_ZERO_POINTS;
98 if (!attr()->zero_points_.defined(zp_arg))
99 return arg_usage_t::input;
100 }
101 if (arg & DNNL_ARG_ATTR_SCALES) {
102 int scale_arg = arg & ~DNNL_ARG_ATTR_SCALES;
103 if (!attr()->scales_.get(scale_arg).defined())
104 return arg_usage_t::input;
105 }
106 if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0))
107 && !attr()->scales_.get(DNNL_ARG_SRC_0).defined())
108 return arg_usage_t::input;
109 if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1))
110 && !attr()->scales_.get(DNNL_ARG_SRC_1).defined())
111 return arg_usage_t::input;
112 if (arg == DNNL_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
113 return arg_usage_t::output;
114 for (int idx = 0; idx < attr()->post_ops_.len(); ++idx) {
115 using namespace primitive_kind;
116 if (post_op_has_proper_input(
117 attr(), binary, idx, arg, DNNL_ARG_SRC_1)
118 || post_op_has_proper_input(
119 attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS))
120 return arg_usage_t::input;
121 }
122
123 return arg_usage_t::unused;
124 }
125
126 virtual const memory_desc_t *arg_md(int arg) const {
127 // Separate binary post-ops sections due to inability to express inside
128 // switch statement.
129 if (arg >= DNNL_ARG_ATTR_MULTIPLE_POST_OP(0)
130 && arg < DNNL_ARG_ATTR_MULTIPLE_POST_OP(
131 post_ops_t::post_ops_limit)) {
132 const auto &po = attr()->post_ops_;
133 for (int idx = 0; idx < po.len(); ++idx) {
134 if (arg
135 != (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
136 | DNNL_ARG_SRC_1))
137 continue;
138
139 return &po.entry_[idx].binary.src1_desc;
140 }
141 }
142
143 switch (arg) {
144 case DNNL_ARG_WORKSPACE: return workspace_md(0);
145 case DNNL_ARG_SCRATCHPAD: return scratchpad_md(0);
146 default: return &glob_zero_md;
147 }
148 }
149
150#define DECLARE_MD_STUB(stub) \
151 virtual const memory_desc_t *stub(int idx = 0) const { \
152 return &glob_zero_md; \
153 }
154
155 DECLARE_MD_STUB(input_md);
156 DECLARE_MD_STUB(output_md);
157 DECLARE_MD_STUB(src_md);
158 DECLARE_MD_STUB(diff_src_md);
159 DECLARE_MD_STUB(dst_md);
160 DECLARE_MD_STUB(diff_dst_md);
161 DECLARE_MD_STUB(weights_md);
162 DECLARE_MD_STUB(diff_weights_md);
163 DECLARE_MD_STUB(workspace_md);
164#undef DECLARE_MD_STUB
165
166 const memory_desc_t *scratchpad_md(int idx = 0) const {
167 return idx == 0 ? &scratchpad_md_ : &glob_zero_md;
168 }
169
170 void init_scratchpad_md() {
171 auto size = scratchpad_size(scratchpad_mode::user);
172 dims_t dims = {size};
173 memory_desc_init_by_tag(
174 scratchpad_md_, size ? 1 : 0, dims, data_type::u8, dnnl_x);
175 }
176
177 /** returns the scratchpad size for the given scratchpad mode. */
178 dim_t scratchpad_size(scratchpad_mode_t mode) const {
179 if (mode != attr_.scratchpad_mode_) return 0;
180 return scratchpad_registry().size();
181 }
182
183 virtual status_t query(query_t what, int idx, void *result) const {
184 auto safe_ret_md = [&](const memory_desc_t *_) {
185 if (_ == nullptr) return status::not_required;
186 *(const memory_desc_t **)result = _;
187 return status::success;
188 };
189
190 switch (what) {
191 case query::primitive_kind:
192 *(primitive_kind_t *)result = kind();
193 break;
194
195 case query::memory_consumption_s64:
196 *(dim_t *)result = scratchpad_size(scratchpad_mode::library);
197 break;
198
199 case query::exec_arg_md: return safe_ret_md(arg_md(idx));
200 case query::src_md: return safe_ret_md(src_md(idx));
201 case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
202 case query::dst_md: return safe_ret_md(dst_md(idx));
203 case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
204 case query::weights_md: return safe_ret_md(weights_md(idx));
205 case query::diff_weights_md:
206 return safe_ret_md(diff_weights_md(idx));
207 case query::workspace_md:
208 if (idx != 0) return status::invalid_arguments;
209 return safe_ret_md(workspace_md(idx));
210 case query::scratchpad_md:
211 if (idx != 0) return status::invalid_arguments;
212 return safe_ret_md(scratchpad_md(idx));
213
214 case query::num_of_inputs_s32: *(int *)result = n_inputs(); break;
215 case query::num_of_outputs_s32: *(int *)result = n_outputs(); break;
216
217 case query::impl_info_str: *(const char **)result = name(); break;
218
219 default: return status::unimplemented;
220 }
221 return status::success;
222 }
223
224 virtual int n_inputs() const { return 0; }
225 virtual int n_outputs() const { return 0; }
226 int n_binary_po_inputs() const {
227 return po_inputs(attr()->post_ops_, primitive_kind::binary);
228 }
229
230 int n_prelu_po_inputs() const {
231 return po_inputs(attr()->post_ops_, primitive_kind::prelu);
232 }
233 // The `hint_mds(bool is_hint)` returns a vector of memory descriptors
234 // that might affect the equality of primitive descriptors for backward pass.
235 //
236 // This function is used for creating a key to fetch primitive or primitive
237 // descriptor from cache.
238 //
239 // 1. When creating a primitive descriptor for backward pass there may be
240 // a forward primitive descriptor hint that can be used to obtain the
241 // memory descriptors. In this case the `is_hint` argument must be `true`.
242 // 2. When creating a primitive this function is called for a primitive
243 // descriptor that can be either forward or backward. In this case
244 // the `is_hint` argument must be `false`.
245 // - For forward it will return an empty vector.
246 // - For backward it will return a vector of memory descriptors if
247 // the implementation depends on a forward primitive descriptor.
248 //
249 // The current cases are:
250 // - pooling
251 // - shuffle
252 //
253 // Later the list of primitives can be extended. For instance, currently
254 // there is no convolution on the list because nthrs + op_desc
255 // (even with format=`any`) + attributes fully define a particular
256 // implementation.
257 virtual std::vector<memory_desc_t> hint_mds(bool is_hint) const {
258 UNUSED(is_hint);
259 return {};
260 }
261
262 virtual status_t create_primitive(
263 std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
264 engine_t *engine, const cache_blob_t &cache_blob) const = 0;
265
266 // This is a proxy interface that is used for creating nested primitives.
267 // It ignores the bool value that indicates whether the requested primitive
268 // was taken from cache.
269 status_t create_primitive(std::shared_ptr<primitive_t> &primitive,
270 engine_t *engine,
271 const cache_blob_t &cache_blob = cache_blob_t()) const {
272 std::pair<std::shared_ptr<primitive_t>, bool> p;
273 CHECK(create_primitive(p, engine, cache_blob));
274 primitive = p.first;
275 return status::success;
276 }
277
278 virtual const char *name() const = 0;
279
280 int pd_iterator_offset() const { return pd_iterator_offset_; }
281
282protected:
283 primitive_attr_t attr_;
284 primitive_kind_t kind_;
285 int pd_iterator_offset_;
286
287 memory_desc_t scratchpad_md_;
288
289 mutable pd_info_t info_;
290 mutable cache_blob_id_t cache_blob_id_;
291
292 memory_tracking::registry_t scratchpad_registry_;
293
294protected:
295 void init_pd_iterator_offset(int offset) { pd_iterator_offset_ = offset; }
296
297 /** compares ws between fwd_pd and this (make sense to use for bwd_pd)
298 * Expectation: this already set workspace, and this workspace should
299 * exactly match the one from fwd_pd */
300 bool compare_ws(const primitive_desc_t *fwd_pd) const {
301 if (!workspace_md()) return true; // the impl lives fine w/o workspace
302 return fwd_pd && fwd_pd->workspace_md()
303 && *fwd_pd->workspace_md() == *workspace_md();
304 }
305
306 primitive_desc_t &operator=(const primitive_desc_t &other) = delete;
307
308 /* static magic */
309
310 template <typename pd_t>
311 static status_t create(primitive_desc_t **pd, const op_desc_t *adesc,
312 const primitive_attr_t *attr, engine_t *engine,
313 const primitive_desc_t *hint_fwd) {
314 using namespace dnnl::impl::status;
315 using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
316 if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
317 assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
318 auto hint
319 = reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
320 auto _pd = new pd_t((const pd_op_desc_t *)adesc, attr, hint);
321 if (_pd == nullptr) return out_of_memory;
322 if (!_pd->is_initialized()) {
323 delete _pd;
324 return out_of_memory;
325 }
326 if (_pd->init(engine) != success) {
327 delete _pd;
328 return unimplemented;
329 }
330
331 _pd->init_scratchpad_md();
332 *pd = _pd;
333 return success;
334 }
335
336 friend struct dnnl::impl::impl_list_item_t;
337};
338
339} // namespace impl
340} // namespace dnnl
341
// Declares the boilerplate members of a concrete primitive descriptor `pd_t`:
//   - clone(): copies *this; returns nullptr when the copy could not be
//     allocated or is not initialized (same checks primitive_desc_t::create
//     applies to a freshly constructed descriptor);
//   - create_primitive(): creates an `impl_type` primitive through
//     primitive_t::create_primitive_common, forwarding `use_global_scratchpad`
//     and the cache blob;
//   - name(): returns `impl_name`;
//   - a friend declaration granting primitive_desc_t::create access to pd_t.
#define DECLARE_COMMON_PD_t(impl_name, impl_type, use_global_scratchpad) \
    pd_t *clone() const override { \
        auto new_pd = utils::make_unique<pd_t>(*this); \
        if (!new_pd || !new_pd->is_initialized()) return nullptr; \
        return new_pd.release(); \
    } \
    status_t create_primitive( \
            std::pair<std::shared_ptr<primitive_t>, bool> &primitive, \
            engine_t *engine, const cache_blob_t &cache_blob) const override { \
        return primitive_t::create_primitive_common<impl_type, pd_t>( \
                primitive, this, engine, use_global_scratchpad, cache_blob); \
    } \
    const char *name() const override { return impl_name; } \
    template <typename pd_t> \
    friend status_t primitive_desc_t::create(primitive_desc_t **pd, \
            const op_desc_t *adesc, const primitive_attr_t *attr, \
            engine_t *engine, const primitive_desc_t *hint_fwd);
359
// Variant selected by DECLARE_COMMON_PD_T(name, type, USE_GLOBAL_SCRATCHPAD):
// expands DECLARE_COMMON_PD_t with use_global_scratchpad = true.
#define DECLARE_COMMON_PD_T_USE_GLOBAL_SCRATCHPAD(impl_name, impl_type) \
    DECLARE_COMMON_PD_t(impl_name, impl_type, true)

// Variant selected by DECLARE_COMMON_PD_T(name, type) with no extra argument:
// expands DECLARE_COMMON_PD_t with use_global_scratchpad = false.
#define DECLARE_COMMON_PD_T_(impl_name, impl_type) \
    DECLARE_COMMON_PD_t(impl_name, impl_type, false)

// Public entry point. The optional trailing argument is token-pasted onto
// `DECLARE_COMMON_PD_T_` to pick one of the variants above:
//   DECLARE_COMMON_PD_T(name, type)                        -> DECLARE_COMMON_PD_T_
//   DECLARE_COMMON_PD_T(name, type, USE_GLOBAL_SCRATCHPAD) -> DECLARE_COMMON_PD_T_USE_GLOBAL_SCRATCHPAD
// NOTE(review): calling a variadic macro with no variadic arguments is only
// formally allowed since C++20; earlier modes rely on compiler support.
#define DECLARE_COMMON_PD_T(impl_name, impl_type, ...) \
    DECLARE_COMMON_PD_T_##__VA_ARGS__(impl_name, impl_type)
368
369#endif
370
371// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
372