1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEMM_INNER_PRODUCT_HPP |
18 | #define GPU_OCL_GEMM_INNER_PRODUCT_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <string> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/gemm_utils.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/primitive_desc_iterator.hpp" |
27 | #include "gpu/compute/compute.hpp" |
28 | #include "gpu/gemm/gpu_gemm.hpp" |
29 | #include "gpu/gpu_inner_product_pd.hpp" |
30 | #include "gpu/gpu_primitive.hpp" |
31 | #include "gpu/gpu_primitive_attr.hpp" |
32 | #include "gpu/gpu_reduction_pd.hpp" |
33 | #include "gpu/gpu_resource.hpp" |
34 | #include "gpu/primitive_conf.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace gpu { |
39 | namespace ocl { |
40 | |
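// Forward inner product implemented via a nested GPU GEMM primitive: src and
// weights are flattened to 2D matrices so that dst = src * weights^T (+ bias),
// with scales and post-ops delegated to the GEMM.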
41 | struct gemm_inner_product_fwd_t : public gpu_primitive_t { |
42 | using gpu_primitive_t::gpu_primitive_t; |
43 | struct pd_t : public gpu_inner_product_fwd_pd_t { |
44 | pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr, |
45 | const inner_product_fwd_pd_t *hint_fwd_pd) |
46 | : gpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
47 | pd_t(const pd_t &rhs) = default; |
48 | ~pd_t() = default; |
49 | |
50 | DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_fwd_t); |
51 | |
        status_t init(engine_t *engine) {
            using namespace data_type;
            using namespace prop_kind;
            assert(engine->kind() == engine_kind::gpu);
57 | |
58 | const auto attr_skip_mask |
59 | = primitive_attr_t::skip_mask_t::scales_runtime |
60 | | primitive_attr_t::skip_mask_t::post_ops; |
61 | |
62 | bool ok = is_fwd() && set_default_params() == status::success |
63 | && !has_zero_dim_memory() |
64 | && dense_consistency_check(src_md(), weights_md(), dst_md()) |
65 | && dense_gemm_consistency_check( |
66 | src_md(), weights_md(), dst_md()) |
67 | && attr()->has_default_values(attr_skip_mask) |
68 | && post_ops_with_binary_ok( |
69 | attr(), desc()->dst_desc.data_type) |
70 | && attr_.set_default_formats(dst_md(0)) == status::success; |
71 | if (!ok) return status::unimplemented; |
72 | |
73 | attr_info_ = attr_info_t::create(attr()); |
74 | |
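            // Map the inner product onto a 2D GEMM: A = src (MB x IC'),
            // B = weights^T (IC' x OC), C = dst (MB x OC), where IC' is the
            // flattened IC * KD * KH * KW dimension. The bias (weights_md(1))
            // is passed to the GEMM so it can be fused.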
75 | memory_desc_t a_md, b_md, c_md; |
76 | init_2d_desc(&a_md, src_md()); |
77 | init_2d_desc(&b_md, weights_md(), true); |
78 | init_2d_desc(&c_md, dst_md()); |
79 | primitive_attr_t gemm_attr = *attr(); |
80 | auto wei_mask = gemm_attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
            // Per-OC weights scales (mask == 1) must be transposed to match
            // the transposed GEMM B matrix.
            if (wei_mask == 1)
82 | gemm_attr.scales_.set(DNNL_ARG_WEIGHTS, 1 << (b_md.ndims - 1)); |
83 | else if (wei_mask != 0) |
84 | return status::unimplemented; |
85 | bool gemm_ok = status::success |
86 | == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md, |
87 | weights_md(1), desc()->accum_data_type, &gemm_attr, |
88 | true); |
89 | if (!gemm_ok) return status::unimplemented; |
90 | |
91 | init_scratchpad(); |
92 | |
93 | return status::success; |
94 | } |
95 | |
96 | attr_info_t attr_info_ = {}; |
97 | std::shared_ptr<primitive_desc_t> gemm_pd_; |
98 | |
99 | private: |
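        // Book scratchpad space for the nested GEMM primitive.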
100 | void init_scratchpad() { |
101 | auto scratchpad = scratchpad_registry().registrar(); |
102 | scratchpad.book(memory_tracking::names::key_nested, |
103 | gemm_pd_->scratchpad_registry()); |
104 | } |
105 | }; |
106 | |
107 | status_t init(engine_t *engine) override { |
108 | return create_nested_primitive(gemm_, pd()->gemm_pd_, engine); |
109 | } |
110 | |
111 | status_t execute(const exec_ctx_t &ctx) const override { |
112 | return execute_forward(ctx); |
113 | } |
114 | |
115 | private: |
116 | status_t execute_forward(const exec_ctx_t &ctx) const; |
117 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
118 | |
119 | std::shared_ptr<primitive_t> gemm_; |
120 | }; |
121 | |
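// Backward-data inner product via a nested GPU GEMM primitive:
// diff_src = diff_dst * weights on the flattened 2D shapes.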
122 | struct gemm_inner_product_bwd_data_t : public gpu_primitive_t { |
123 | using gpu_primitive_t::gpu_primitive_t; |
124 | struct pd_t : public gpu_inner_product_bwd_data_pd_t { |
125 | pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr, |
126 | const inner_product_fwd_pd_t *hint_fwd_pd) |
127 | : gpu_inner_product_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {} |
128 | pd_t(const pd_t &rhs) = default; |
129 | ~pd_t() = default; |
130 | |
131 | DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_bwd_data_t); |
132 | |
133 | status_t init(engine_t *engine) { |
134 | using namespace prop_kind; |
135 | using namespace data_type; |
136 | |
137 | assert(engine->kind() == engine_kind::gpu); |
138 | |
139 | bool ok = this->desc()->prop_kind == backward_data |
140 | && set_default_params() == status::success |
141 | && !has_zero_dim_memory() |
142 | && utils::one_of(weights_md()->data_type, f32, bf16) |
143 | && utils::one_of(diff_src_md()->data_type, f32, bf16) |
144 | && utils::one_of(diff_dst_md()->data_type, f32, bf16) |
145 | && attr()->has_default_values() |
146 | && dense_consistency_check( |
147 | diff_src_md(), weights_md(), diff_dst_md()) |
148 | && dense_gemm_consistency_check( |
149 | diff_src_md(), weights_md(), diff_dst_md()); |
150 | if (!ok) return status::unimplemented; |
151 | |
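            // Map onto a 2D GEMM: A = diff_dst (MB x OC),
            // B = weights (OC x IC'), C = diff_src (MB x IC').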
152 | memory_desc_t a_md, b_md, c_md; |
153 | init_2d_desc(&a_md, diff_dst_md()); |
154 | init_2d_desc(&b_md, weights_md()); |
155 | init_2d_desc(&c_md, diff_src_md()); |
156 | |
157 | bool gemm_ok = status::success |
158 | == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md, |
159 | &glob_zero_md, desc()->accum_data_type, attr(), |
160 | true); |
161 | if (!gemm_ok) return status::unimplemented; |
162 | init_scratchpad(); |
163 | |
164 | return status::success; |
165 | } |
166 | |
167 | std::shared_ptr<primitive_desc_t> gemm_pd_; |
168 | |
169 | private: |
170 | void init_scratchpad() { |
171 | auto scratchpad = scratchpad_registry().registrar(); |
172 | scratchpad.book(memory_tracking::names::key_nested, |
173 | gemm_pd_->scratchpad_registry()); |
174 | } |
175 | }; |
176 | |
177 | status_t init(engine_t *engine) override { |
178 | return create_nested_primitive(gemm_, pd()->gemm_pd_, engine); |
179 | } |
180 | |
181 | status_t execute(const exec_ctx_t &ctx) const override { |
182 | return execute_backward_data(ctx); |
183 | } |
184 | |
185 | private: |
186 | status_t execute_backward_data(const exec_ctx_t &ctx) const; |
187 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
188 | |
189 | std::shared_ptr<primitive_t> gemm_; |
190 | }; |
191 | |
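// Backward-weights inner product via a nested GPU GEMM primitive. The bias
// gradient is fused into the GEMM as a row/column reduction when supported;
// otherwise it is computed by a separate nested reduction primitive.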
192 | struct gemm_inner_product_bwd_weights_t : public gpu_primitive_t { |
193 | using gpu_primitive_t::gpu_primitive_t; |
194 | using gpu_ip_bwd_weights_pd_t = gpu_inner_product_bwd_weights_pd_t; |
195 | struct pd_t : public gpu_ip_bwd_weights_pd_t { |
196 | pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr, |
197 | const inner_product_fwd_pd_t *hint_fwd_pd) |
198 | : gpu_ip_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {} |
199 | pd_t(const pd_t &rhs) = default; |
200 | |
201 | ~pd_t() = default; |
202 | |
203 | DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_bwd_weights_t); |
204 | |
205 | status_t init(engine_t *engine) { |
206 | using namespace prop_kind; |
207 | using namespace data_type; |
208 | |
209 | assert(engine->kind() == engine_kind::gpu); |
210 | |
211 | bool ok = this->desc()->prop_kind == backward_weights |
212 | && set_default_params() == status::success |
213 | && !has_zero_dim_memory() |
214 | && utils::one_of(diff_weights_md()->data_type, f32, bf16) |
215 | && utils::one_of(src_md()->data_type, f32, bf16) |
216 | && utils::one_of(diff_dst_md()->data_type, f32, bf16) |
217 | && attr()->has_default_values() |
218 | && dense_consistency_check( |
219 | src_md(), diff_weights_md(), diff_dst_md()) |
220 | && dense_gemm_consistency_check( |
221 | src_md(), diff_weights_md(), diff_dst_md()); |
222 | if (!ok) return status::unimplemented; |
223 | |
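            // Choose the GEMM operand order based on the diff_weights layout:
            // compute diff_weights^T = src^T * diff_dst for transposed
            // weights, diff_weights = diff_dst^T * src otherwise.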
224 | memory_desc_t a_md, b_md, c_md; |
225 | if (wei_tr()) { |
226 | init_2d_desc(&a_md, src_md(), true); |
227 | init_2d_desc(&b_md, diff_dst_md()); |
228 | init_2d_desc(&c_md, diff_weights_md(), true); |
229 | } else { |
230 | init_2d_desc(&a_md, diff_dst_md(), true); |
231 | init_2d_desc(&b_md, src_md()); |
232 | init_2d_desc(&c_md, diff_weights_md()); |
233 | } |
            auto reduce_bias = sum_ab::sum_none;
            if (with_bias())
                reduce_bias = wei_tr() ? sum_ab::sum_b_col : sum_ab::sum_a_row;
            bool gemm_ok = status::success
                    == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md,
                            &glob_zero_md, desc()->accum_data_type, attr(),
                            true, reduce_bias,
                            desc()->diff_bias_desc.data_type);
243 | |
            // Fused bias reduction is unsupported; fall back to a plain GEMM
            // plus a separate reduction kernel for diff_bias.
245 | if (with_bias() && !gemm_ok) { |
246 | gemm_ok = status::success |
247 | == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md, |
248 | &glob_zero_md, desc()->accum_data_type, attr()); |
249 | if (!gemm_ok) return status::unimplemented; |
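                // Note: reduction_dst_md is the reduction *source* (diff_dst
                // viewed as MB x OC x 1) and reduction_bias_md the
                // destination (1 x OC x 1); the sum runs over MB.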
250 | memory_desc_t reduction_dst_md, reduction_bias_md; |
                // Set ndims to 3 to explicitly specify a blocked format so
                // that the optimized reduction implementation is selected.
253 | reduction_bias_md.ndims = 3; |
254 | reduction_bias_md.dims[0] = 1; |
255 | reduction_bias_md.dims[1] = diff_bias_md_.dims[0]; |
256 | reduction_bias_md.dims[2] = 1; |
257 | bool use_blocked = OC() % 16 == 0; |
258 | CHECK(memory_desc_init_by_tag(reduction_bias_md, |
259 | reduction_bias_md.ndims, reduction_bias_md.dims, |
260 | diff_bias_md_.data_type, |
261 | use_blocked ? format_tag::aBc16b : format_tag::abc)); |
262 | reduction_dst_md = *diff_dst_md(); |
263 | reduction_dst_md.ndims = 3; |
264 | reduction_dst_md.dims[2] = 1; |
265 | CHECK(memory_desc_init_by_tag(reduction_dst_md, |
266 | reduction_dst_md.ndims, reduction_dst_md.dims, |
267 | diff_dst_md_.data_type, |
268 | use_blocked ? format_tag::aBc16b : format_tag::abc)); |
269 | reduction_desc_t reduction_d; |
270 | CHECK(reduction_desc_init(&reduction_d, |
271 | dnnl::impl::alg_kind::reduction_sum, &reduction_dst_md, |
272 | &reduction_bias_md, 0.0f, 0.0f)); |
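                // Use the GEMM's preferred threads-per-EU value for the
                // reduction so the two nested kernels run with a matching
                // execution configuration.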
273 | primitive_attr_t reduction_attr; |
274 | int threads_per_eu; |
275 | auto status |
276 | = gemm_pd_->query(query::preferred_gpu_threads_per_eu, |
277 | 0, &threads_per_eu); |
278 | if (status == status::success) |
279 | reduction_attr.set_gpu_attr( |
280 | gpu_primitive_attr_t(threads_per_eu)); |
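                // Advance to the first available reduction implementation.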
281 | primitive_desc_iterator_t it(engine, (op_desc_t *)&reduction_d, |
282 | &reduction_attr, nullptr); |
283 | if (!it.is_initialized()) return status::out_of_memory; |
284 | reduction_pd_ = *(++it); |
285 | if (!reduction_pd_) return status::unimplemented; |
286 | } |
287 | if (!gemm_ok) return status::unimplemented; |
288 | init_scratchpad(); |
289 | return status::success; |
290 | } |
291 | |
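        // True when diff_weights has a unit stride along the OC dimension,
        // i.e., the weights gradient is computed in transposed layout.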
292 | bool wei_tr() const { |
293 | const auto &wmd = *this->diff_weights_md(); |
294 | return wmd.format_desc.blocking.strides[0] == 1; |
295 | } |
296 | |
297 | std::shared_ptr<primitive_desc_t> gemm_pd_; |
298 | std::shared_ptr<primitive_desc_t> reduction_pd_; |
299 | |
300 | private: |
301 | void init_scratchpad() { |
302 | auto scratchpad = scratchpad_registry().registrar(); |
303 | scratchpad.book(memory_tracking::names::key_nested_multiple, |
304 | gemm_pd_->scratchpad_registry()); |
305 | if (with_bias() && reduction_pd_) |
306 | scratchpad.book(memory_tracking::names::key_nested_multiple + 1, |
307 | reduction_pd_->scratchpad_registry()); |
308 | } |
309 | }; |
310 | |
311 | status_t init(engine_t *engine) override { |
312 | CHECK(create_nested_primitive(gemm_, pd()->gemm_pd_, engine)); |
313 | if (pd()->with_bias() && pd()->reduction_pd_) |
314 | CHECK(create_nested_primitive( |
315 | reduction_, pd()->reduction_pd_, engine)); |
316 | return status::success; |
317 | } |
318 | |
319 | status_t execute(const exec_ctx_t &ctx) const override { |
320 | return execute_backward_weights(ctx); |
321 | } |
322 | |
323 | private: |
324 | status_t execute_backward_weights(const exec_ctx_t &ctx) const; |
325 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
326 | std::shared_ptr<primitive_t> gemm_; |
327 | std::shared_ptr<primitive_t> reduction_; |
328 | }; |
329 | |
330 | } // namespace ocl |
331 | } // namespace gpu |
332 | } // namespace impl |
333 | } // namespace dnnl |
334 | |
335 | #endif |
336 | |