/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_GEMM_INNER_PRODUCT_HPP
#define GPU_OCL_GEMM_INNER_PRODUCT_HPP

#include <assert.h>
#include <string>

#include "common/c_types_map.hpp"
#include "common/gemm_utils.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/gpu_inner_product_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_primitive_attr.hpp"
#include "gpu/gpu_reduction_pd.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

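// Forward inner product implemented via a single GEMM call: the N-D source,
// weights, and destination tensors are viewed as 2D matrices and
// dst = src x wei^T (plus bias) is delegated to a nested GPU GEMM primitive,
// which also handles the attribute scales and post-ops.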
struct gemm_inner_product_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_inner_product_fwd_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : gpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
        pd_t(const pd_t &rhs) = default;
        ~pd_t() = default;

        DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            using namespace prop_kind;
            assert(engine->kind() == engine_kind::gpu);

            const auto attr_skip_mask
                    = primitive_attr_t::skip_mask_t::scales_runtime
                    | primitive_attr_t::skip_mask_t::post_ops;

            bool ok = is_fwd() && set_default_params() == status::success
                    && !has_zero_dim_memory()
                    && dense_consistency_check(src_md(), weights_md(), dst_md())
                    && dense_gemm_consistency_check(
                            src_md(), weights_md(), dst_md())
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(
                            attr(), desc()->dst_desc.data_type)
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            attr_info_ = attr_info_t::create(attr());

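            // Collapse the N-D tensors to 2D GEMM operands:
            // A = src (MB x K), B = wei^T (K x OC), C = dst (MB x OC),
            // where K flattens the input channel and spatial dimensions.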
            memory_desc_t a_md, b_md, c_md;
            init_2d_desc(&a_md, src_md());
            init_2d_desc(&b_md, weights_md(), true);
            init_2d_desc(&c_md, dst_md());
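            // Per-OC weights scales carry mask 1 (the OC dimension of the IP
            // weights); after the transpose above OC becomes the last
            // dimension of B, so the mask is remapped before passing the
            // attributes to the GEMM.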
            primitive_attr_t gemm_attr = *attr();
            auto wei_mask = gemm_attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
            if (wei_mask == 1) // transpose mask for gemm
                gemm_attr.scales_.set(DNNL_ARG_WEIGHTS, 1 << (b_md.ndims - 1));
            else if (wei_mask != 0)
                return status::unimplemented;
            bool gemm_ok = status::success
                    == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md,
                            weights_md(1), desc()->accum_data_type, &gemm_attr,
                            true);
            if (!gemm_ok) return status::unimplemented;

            init_scratchpad();

            return status::success;
        }

        attr_info_t attr_info_ = {};
        std::shared_ptr<primitive_desc_t> gemm_pd_;

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    gemm_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        return create_nested_primitive(gemm_, pd()->gemm_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::shared_ptr<primitive_t> gemm_;
};

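// Backward data implemented via GEMM: diff_src = diff_dst x wei, using 2D
// views of the tensors. Only f32 and bf16 data types and default attributes
// are accepted.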
struct gemm_inner_product_bwd_data_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_inner_product_bwd_data_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : gpu_inner_product_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}
        pd_t(const pd_t &rhs) = default;
        ~pd_t() = default;

        DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_bwd_data_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            using namespace data_type;

            assert(engine->kind() == engine_kind::gpu);

            bool ok = this->desc()->prop_kind == backward_data
                    && set_default_params() == status::success
                    && !has_zero_dim_memory()
                    && utils::one_of(weights_md()->data_type, f32, bf16)
                    && utils::one_of(diff_src_md()->data_type, f32, bf16)
                    && utils::one_of(diff_dst_md()->data_type, f32, bf16)
                    && attr()->has_default_values()
                    && dense_consistency_check(
                            diff_src_md(), weights_md(), diff_dst_md())
                    && dense_gemm_consistency_check(
                            diff_src_md(), weights_md(), diff_dst_md());
            if (!ok) return status::unimplemented;

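            // 2D GEMM operands: A = diff_dst (MB x OC), B = wei (OC x K),
            // C = diff_src (MB x K).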
            memory_desc_t a_md, b_md, c_md;
            init_2d_desc(&a_md, diff_dst_md());
            init_2d_desc(&b_md, weights_md());
            init_2d_desc(&c_md, diff_src_md());

            bool gemm_ok = status::success
                    == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md,
                            &glob_zero_md, desc()->accum_data_type, attr(),
                            true);
            if (!gemm_ok) return status::unimplemented;
            init_scratchpad();

            return status::success;
        }

        std::shared_ptr<primitive_desc_t> gemm_pd_;

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    gemm_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        return create_nested_primitive(gemm_, pd()->gemm_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward_data(ctx);
    }

private:
    status_t execute_backward_data(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::shared_ptr<primitive_t> gemm_;
};

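// Backward weights implemented via GEMM: diff_wei = diff_dst^T x src. The bias
// gradient is either fused into the GEMM as a row/column sum of diff_dst or,
// when that is not supported, computed by a separate nested reduction
// primitive.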
struct gemm_inner_product_bwd_weights_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    using gpu_ip_bwd_weights_pd_t = gpu_inner_product_bwd_weights_pd_t;
    struct pd_t : public gpu_ip_bwd_weights_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : gpu_ip_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}
        pd_t(const pd_t &rhs) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_inner_product_bwd_weights_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            using namespace data_type;

            assert(engine->kind() == engine_kind::gpu);

            bool ok = this->desc()->prop_kind == backward_weights
                    && set_default_params() == status::success
                    && !has_zero_dim_memory()
                    && utils::one_of(diff_weights_md()->data_type, f32, bf16)
                    && utils::one_of(src_md()->data_type, f32, bf16)
                    && utils::one_of(diff_dst_md()->data_type, f32, bf16)
                    && attr()->has_default_values()
                    && dense_consistency_check(
                            src_md(), diff_weights_md(), diff_dst_md())
                    && dense_gemm_consistency_check(
                            src_md(), diff_weights_md(), diff_dst_md());
            if (!ok) return status::unimplemented;

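            // diff_wei = diff_dst^T x src. The operand order and transposes
            // are chosen from the diff_weights memory layout so the GEMM can
            // write C directly into diff_weights.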
            memory_desc_t a_md, b_md, c_md;
            if (wei_tr()) {
                init_2d_desc(&a_md, src_md(), true);
                init_2d_desc(&b_md, diff_dst_md());
                init_2d_desc(&c_md, diff_weights_md(), true);
            } else {
                init_2d_desc(&a_md, diff_dst_md(), true);
                init_2d_desc(&b_md, src_md());
                init_2d_desc(&c_md, diff_weights_md());
            }
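            // First try a GEMM that also reduces diff_dst into the bias
            // gradient (a sum over the minibatch dimension), so no separate
            // reduction kernel is needed.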
            bool gemm_ok = false;
            auto reduce_bias = sum_ab::sum_none;
            if (with_bias())
                reduce_bias = wei_tr() ? sum_ab::sum_b_col : sum_ab::sum_a_row;
            gemm_ok = status::success
                    == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md,
                            &glob_zero_md, desc()->accum_data_type, attr(),
                            true, reduce_bias,
                            desc()->diff_bias_desc.data_type);

            // If the fused bias reduction is not supported, compute the bias
            // gradient with a separate reduction kernel.
            if (with_bias() && !gemm_ok) {
                gemm_ok = status::success
                        == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md,
                                &c_md, &glob_zero_md, desc()->accum_data_type,
                                attr());
                if (!gemm_ok) return status::unimplemented;
                memory_desc_t reduction_dst_md, reduction_bias_md;
                // Set ndims to 3 in order to explicitly specify a blocked
                // format so that the optimized reduction implementation is
                // used.
                reduction_bias_md.ndims = 3;
                reduction_bias_md.dims[0] = 1;
                reduction_bias_md.dims[1] = diff_bias_md_.dims[0];
                reduction_bias_md.dims[2] = 1;
                bool use_blocked = OC() % 16 == 0;
                CHECK(memory_desc_init_by_tag(reduction_bias_md,
                        reduction_bias_md.ndims, reduction_bias_md.dims,
                        diff_bias_md_.data_type,
                        use_blocked ? format_tag::aBc16b : format_tag::abc));
                reduction_dst_md = *diff_dst_md();
                reduction_dst_md.ndims = 3;
                reduction_dst_md.dims[2] = 1;
                CHECK(memory_desc_init_by_tag(reduction_dst_md,
                        reduction_dst_md.ndims, reduction_dst_md.dims,
                        diff_dst_md_.data_type,
                        use_blocked ? format_tag::aBc16b : format_tag::abc));
                reduction_desc_t reduction_d;
                CHECK(reduction_desc_init(&reduction_d,
                        dnnl::impl::alg_kind::reduction_sum, &reduction_dst_md,
                        &reduction_bias_md, 0.0f, 0.0f));
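                // Reuse the GEMM's preferred threads-per-EU setting for the
                // reduction primitive when the GEMM reports one.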
                primitive_attr_t reduction_attr;
                int threads_per_eu;
                auto status
                        = gemm_pd_->query(query::preferred_gpu_threads_per_eu,
                                0, &threads_per_eu);
                if (status == status::success)
                    reduction_attr.set_gpu_attr(
                            gpu_primitive_attr_t(threads_per_eu));
                primitive_desc_iterator_t it(engine, (op_desc_t *)&reduction_d,
                        &reduction_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;
                reduction_pd_ = *(++it);
                if (!reduction_pd_) return status::unimplemented;
            }
            if (!gemm_ok) return status::unimplemented;
            init_scratchpad();
            return status::success;
        }

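        // Returns true when the leading (OC) dimension of diff_weights has
        // unit stride, i.e., the weights are stored transposed relative to
        // the plain "oi"-style layout.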
        bool wei_tr() const {
            const auto &wmd = *this->diff_weights_md();
            return wmd.format_desc.blocking.strides[0] == 1;
        }

        std::shared_ptr<primitive_desc_t> gemm_pd_;
        std::shared_ptr<primitive_desc_t> reduction_pd_;

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested_multiple,
                    gemm_pd_->scratchpad_registry());
            if (with_bias() && reduction_pd_)
                scratchpad.book(
                        memory_tracking::names::key_nested_multiple + 1,
                        reduction_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        CHECK(create_nested_primitive(gemm_, pd()->gemm_pd_, engine));
        if (pd()->with_bias() && pd()->reduction_pd_)
            CHECK(create_nested_primitive(
                    reduction_, pd()->reduction_pd_, engine));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward_weights(ctx);
    }

private:
    status_t execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> gemm_;
    std::shared_ptr<primitive_t> reduction_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif