1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_HPP
18#define COMMON_PRIMITIVE_HPP
19
20#include <assert.h>
21#include <atomic>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "c_types_map.hpp"
26#include "cache_blob.hpp"
27#include "memory_storage.hpp"
28#include "memory_tracking.hpp"
29#include "primitive_desc.hpp"
30#include "primitive_exec_types.hpp"
31#include "rw_mutex.hpp"
32#include "scratchpad.hpp"
33
34#include <future>
35#include <type_traits>
36
37namespace dnnl {
38namespace impl {
39
40struct resource_mapper_t;
41// Primitive implementation
42struct primitive_t : public c_compatible {
43 using primitive_list_t = std::vector<const primitive_t *>;
44
45 primitive_t(const primitive_desc_t *pd) : pd_(pd->clone()) {}
46 virtual ~primitive_t() = default;
47
48 virtual status_t init(engine_t *engine) { return status::success; }
49
50 status_t init(engine_t *engine, bool use_global_scratchpad,
51 const cache_blob_t &cache_blob) {
52 cache_blob_ = cache_blob;
53 CHECK(init(engine));
54 CHECK(init_cached_resource(engine));
55 use_global_scratchpad_ = use_global_scratchpad;
56 // The `cache_blob_` is no longer needed after primitive creation.
57 cache_blob_ = cache_blob_t();
58 return status::success;
59 }
60
61 const std::shared_ptr<primitive_desc_t> &pd() const { return pd_; }
62 primitive_kind_t kind() const { return pd_->kind(); }
63 virtual status_t execute(const exec_ctx_t &ctx) const = 0;
64
65 virtual status_t get_cache_blob(
66 engine_t *engine, cache_blob_t &cache_blob) const {
67 assert(!"unexpected");
68 return status::runtime_error;
69 }
70
71 virtual status_t get_cache_blob_size(size_t *size) const {
72 assert(!"unexpected");
73 return status::runtime_error;
74 }
75
76 virtual status_t create_resource(
77 engine_t *engine, resource_mapper_t &mapper) const {
78 return status::success;
79 }
80
81 // Although this function is marked as `const` it changes primitive_t state.
82 // The only place where this function should be used is in:
83 // `init(engine_t *engine, bool use_global_scratchpad)` during primitive_t
84 // creation in `create_primitive_common`.
85 // The rationale behind marking it as `const` is to simplify enabling the
86 // primitive cache mode for storing compiled GPU kernels instead of
87 // binaries and to preserve the current primitive cache implementation.
88 //
89 // The main idea is to create a resource inside the primitive_t only once
90 // and cache it as part of primitive_t.
91 // TODO: The ultimate goal is to switch completely to caching compiled
92 // GPU kernels therefore the code will be thrown out once it's done.
93 virtual status_t init_cached_resource(engine_t *engine) const {
94 return status::success;
95 }
96
97 bool use_global_scratchpad() const { return use_global_scratchpad_; }
98 cache_blob_t cache_blob() const { return cache_blob_; }
99
100protected:
101 template <typename impl_type, typename pd_t>
102 static status_t create_primitive_common(
103 std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
104 const pd_t *pd, engine_t *engine, bool use_global_scratchpad,
105 const cache_blob_t &cache_blob) {
106
107 auto &global_primitive_cache = primitive_cache();
108 primitive_hashing::key_t key(pd, engine);
109
110 std::promise<primitive_cache_t::cache_value_t> p_promise;
111 // Try to get the shared future from the cache, if it's missing then
112 // a shared future with no shared state is returned and the passed
113 // shared future is added, otherwise a valid shared future is returned
114 // and no insertion is performed.
115 auto p_future = global_primitive_cache.get_or_add(
116 key, p_promise.get_future());
117
118 bool is_from_cache = p_future.valid();
119
120 auto status = status::success;
121 std::shared_ptr<primitive_t> p;
122
123 if (is_from_cache) {
124 // The requested primitive is present in the cache or is being
125 // created by another thread.
126 p = p_future.get().primitive;
127 if (!p) return p_future.get().status;
128 } else {
129 // The requested primitive is NOT present in the cache therefore
130 // we have to create it and notify the waiting threads
131 // once the creation is done.
132 p = std::make_shared<impl_type>(pd);
133 status = p->init(engine, use_global_scratchpad, cache_blob);
134 if (status != status::success) {
135 // Communicate an error.
136 p_promise.set_value({nullptr, status});
137 // Remove the shared future from the cache because it's
138 // invalidated. An invalidated shared future is the one that
139 // stores a nullptr.
140 global_primitive_cache.remove_if_invalidated(key);
141 return status;
142 } else {
143 // Store the created primitive in the shared future and notify
144 // the waiting threads.
145 p_promise.set_value({p, status});
146
147 // The key_t contains pointers to op_desc and attr objects that
148 // reside in pd. When primitive_t is created it copies the pd
149 // and hence contains a copy.
150 // Since the created primitive_t is stored in the cache with
151 // the corresponding key, the key must contain pointers to
152 // op_desc and attr that reside in the coppied pd
153 // in the primitive_t.
154 // Therefore the pointers in the key, which has already been put
155 // into the cache, must be updated.
156 global_primitive_cache.update_entry(key, p->pd().get());
157 }
158 }
159 primitive = std::make_pair(p, is_from_cache);
160 return status;
161 }
162
163 std::shared_ptr<primitive_desc_t> pd_;
164 bool use_global_scratchpad_;
165 cache_blob_t cache_blob_;
166
167private:
168 primitive_t() = delete;
169 DNNL_DISALLOW_COPY_AND_ASSIGN(primitive_t);
170};
171
172// This is a helper class which is used for forwarding a scratchpad
173// from master primitive to the nested ones.
174struct nested_scratchpad_t {
175 nested_scratchpad_t(const exec_ctx_t &master_ctx, int key,
176 const std::shared_ptr<primitive_t> &nested_p);
177 const memory_tracking::grantor_t *grantor() const { return grantor_.get(); }
178
179 ~nested_scratchpad_t();
180
181 DNNL_DISALLOW_COPY_AND_ASSIGN(nested_scratchpad_t);
182
183private:
184 std::unique_ptr<memory_storage_t> scratchpad_mem_storage_;
185 std::unique_ptr<memory_tracking::grantor_t> grantor_;
186};
187
188} // namespace impl
189} // namespace dnnl
190
// Yields the pointee type of `t` with top-level const/volatile removed,
// e.g. ARG_TYPE(const float *) is `float`. The pointer is stripped first, so
// the qualifier on the pointee is the one that gets dropped.
#define ARG_TYPE(t) \
    typename std::remove_cv< \
            typename std::remove_pointer<t>::type>::type

// The CTX_* helpers below expect an execution context named `ctx` in scope
// and cast the result of `ctx.host_ptr(...)` for argument `arg`.

// Read-only view of input memory, typed as `const ARG_TYPE(type) *`.
#define CTX_IN_MEM(type, arg) \
    static_cast<const ARG_TYPE(type) *>(ctx.host_ptr(arg))

// Destination memory that has been zero pad initialized. Zero padding may
// fail; in that case the error is reported through the `status` lvalue,
// whose address is handed to `ctx.host_ptr`.
#define CTX_OUT_CLEAN_MEM(type, arg, status) \
    static_cast<ARG_TYPE(type) *>(ctx.host_ptr(arg, true, &status))

// Destination memory that may not have been zero pad initialized.
#define CTX_OUT_MEM(type, arg) \
    static_cast<ARG_TYPE(type) *>(ctx.host_ptr(arg))
205
206#endif
207