1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_HPP |
18 | #define COMMON_PRIMITIVE_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <atomic> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "c_types_map.hpp" |
26 | #include "cache_blob.hpp" |
27 | #include "memory_storage.hpp" |
28 | #include "memory_tracking.hpp" |
29 | #include "primitive_desc.hpp" |
30 | #include "primitive_exec_types.hpp" |
31 | #include "rw_mutex.hpp" |
32 | #include "scratchpad.hpp" |
33 | |
34 | #include <future> |
35 | #include <type_traits> |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | |
40 | struct resource_mapper_t; |
41 | // Primitive implementation |
42 | struct primitive_t : public c_compatible { |
43 | using primitive_list_t = std::vector<const primitive_t *>; |
44 | |
45 | primitive_t(const primitive_desc_t *pd) : pd_(pd->clone()) {} |
46 | virtual ~primitive_t() = default; |
47 | |
48 | virtual status_t init(engine_t *engine) { return status::success; } |
49 | |
50 | status_t init(engine_t *engine, bool use_global_scratchpad, |
51 | const cache_blob_t &cache_blob) { |
52 | cache_blob_ = cache_blob; |
53 | CHECK(init(engine)); |
54 | CHECK(init_cached_resource(engine)); |
55 | use_global_scratchpad_ = use_global_scratchpad; |
56 | // The `cache_blob_` is no longer needed after primitive creation. |
57 | cache_blob_ = cache_blob_t(); |
58 | return status::success; |
59 | } |
60 | |
61 | const std::shared_ptr<primitive_desc_t> &pd() const { return pd_; } |
62 | primitive_kind_t kind() const { return pd_->kind(); } |
63 | virtual status_t execute(const exec_ctx_t &ctx) const = 0; |
64 | |
65 | virtual status_t get_cache_blob( |
66 | engine_t *engine, cache_blob_t &cache_blob) const { |
67 | assert(!"unexpected" ); |
68 | return status::runtime_error; |
69 | } |
70 | |
71 | virtual status_t get_cache_blob_size(size_t *size) const { |
72 | assert(!"unexpected" ); |
73 | return status::runtime_error; |
74 | } |
75 | |
76 | virtual status_t create_resource( |
77 | engine_t *engine, resource_mapper_t &mapper) const { |
78 | return status::success; |
79 | } |
80 | |
81 | // Although this function is marked as `const` it changes primitive_t state. |
82 | // The only place where this function should be used is in: |
83 | // `init(engine_t *engine, bool use_global_scratchpad)` during primitive_t |
84 | // creation in `create_primitive_common`. |
85 | // The rationale behind marking it as `const` is to simplify enabling the |
86 | // primitive cache mode for storing compiled GPU kernels instead of |
87 | // binaries and to preserve the current primitive cache implementation. |
88 | // |
89 | // The main idea is to create a resource inside the primitive_t only once |
90 | // and cache it as part of primitive_t. |
91 | // TODO: The ultimate goal is to switch completely to caching compiled |
92 | // GPU kernels therefore the code will be thrown out once it's done. |
93 | virtual status_t init_cached_resource(engine_t *engine) const { |
94 | return status::success; |
95 | } |
96 | |
97 | bool use_global_scratchpad() const { return use_global_scratchpad_; } |
98 | cache_blob_t cache_blob() const { return cache_blob_; } |
99 | |
100 | protected: |
101 | template <typename impl_type, typename pd_t> |
102 | static status_t create_primitive_common( |
103 | std::pair<std::shared_ptr<primitive_t>, bool> &primitive, |
104 | const pd_t *pd, engine_t *engine, bool use_global_scratchpad, |
105 | const cache_blob_t &cache_blob) { |
106 | |
107 | auto &global_primitive_cache = primitive_cache(); |
108 | primitive_hashing::key_t key(pd, engine); |
109 | |
110 | std::promise<primitive_cache_t::cache_value_t> p_promise; |
111 | // Try to get the shared future from the cache, if it's missing then |
112 | // a shared future with no shared state is returned and the passed |
113 | // shared future is added, otherwise a valid shared future is returned |
114 | // and no insertion is performed. |
115 | auto p_future = global_primitive_cache.get_or_add( |
116 | key, p_promise.get_future()); |
117 | |
118 | bool is_from_cache = p_future.valid(); |
119 | |
120 | auto status = status::success; |
121 | std::shared_ptr<primitive_t> p; |
122 | |
123 | if (is_from_cache) { |
124 | // The requested primitive is present in the cache or is being |
125 | // created by another thread. |
126 | p = p_future.get().primitive; |
127 | if (!p) return p_future.get().status; |
128 | } else { |
129 | // The requested primitive is NOT present in the cache therefore |
130 | // we have to create it and notify the waiting threads |
131 | // once the creation is done. |
132 | p = std::make_shared<impl_type>(pd); |
133 | status = p->init(engine, use_global_scratchpad, cache_blob); |
134 | if (status != status::success) { |
135 | // Communicate an error. |
136 | p_promise.set_value({nullptr, status}); |
137 | // Remove the shared future from the cache because it's |
138 | // invalidated. An invalidated shared future is the one that |
139 | // stores a nullptr. |
140 | global_primitive_cache.remove_if_invalidated(key); |
141 | return status; |
142 | } else { |
143 | // Store the created primitive in the shared future and notify |
144 | // the waiting threads. |
145 | p_promise.set_value({p, status}); |
146 | |
147 | // The key_t contains pointers to op_desc and attr objects that |
148 | // reside in pd. When primitive_t is created it copies the pd |
149 | // and hence contains a copy. |
150 | // Since the created primitive_t is stored in the cache with |
151 | // the corresponding key, the key must contain pointers to |
152 | // op_desc and attr that reside in the coppied pd |
153 | // in the primitive_t. |
154 | // Therefore the pointers in the key, which has already been put |
155 | // into the cache, must be updated. |
156 | global_primitive_cache.update_entry(key, p->pd().get()); |
157 | } |
158 | } |
159 | primitive = std::make_pair(p, is_from_cache); |
160 | return status; |
161 | } |
162 | |
163 | std::shared_ptr<primitive_desc_t> pd_; |
164 | bool use_global_scratchpad_; |
165 | cache_blob_t cache_blob_; |
166 | |
167 | private: |
168 | primitive_t() = delete; |
169 | DNNL_DISALLOW_COPY_AND_ASSIGN(primitive_t); |
170 | }; |
171 | |
172 | // This is a helper class which is used for forwarding a scratchpad |
173 | // from master primitive to the nested ones. |
174 | struct nested_scratchpad_t { |
175 | nested_scratchpad_t(const exec_ctx_t &master_ctx, int key, |
176 | const std::shared_ptr<primitive_t> &nested_p); |
177 | const memory_tracking::grantor_t *grantor() const { return grantor_.get(); } |
178 | |
179 | ~nested_scratchpad_t(); |
180 | |
181 | DNNL_DISALLOW_COPY_AND_ASSIGN(nested_scratchpad_t); |
182 | |
183 | private: |
184 | std::unique_ptr<memory_storage_t> scratchpad_mem_storage_; |
185 | std::unique_ptr<memory_tracking::grantor_t> grantor_; |
186 | }; |
187 | |
188 | } // namespace impl |
189 | } // namespace dnnl |
190 | |
// Removes one level of pointer and then cv-qualifiers from `t`,
// e.g. `const float *` -> `float`.
#define ARG_TYPE(t) \
    typename std::remove_cv< \
            typename std::remove_pointer<t>::type>::type

// Read-only host pointer to execution argument `arg`. Expects an execution
// context named `ctx` in the calling scope.
#define CTX_IN_MEM(type, arg) \
    (static_cast<const ARG_TYPE(type) *>(ctx.host_ptr((arg))))

// Returns destination memory which may not have been zero pad initialized.
// Expects an execution context named `ctx` in the calling scope.
#define CTX_OUT_MEM(type, arg) \
    (static_cast<ARG_TYPE(type) *>(ctx.host_ptr((arg))))

// Returns destination memory which has been zero pad initialized. Zero
// padding may fail, so the outcome is reported through the lvalue `status`,
// whose address is passed to `ctx.host_ptr`. Expects an execution context
// named `ctx` in the calling scope.
#define CTX_OUT_CLEAN_MEM(type, arg, status) \
    (static_cast<ARG_TYPE(type) *>(ctx.host_ptr((arg), true, &(status))))
205 | |
206 | #endif |
207 | |