1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gemm_inner_product.hpp"
18
19#include "gpu/gemm/gpu_gemm.hpp"
20#include "gpu/ocl/ocl_stream.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace ocl {
26
27status_t gemm_inner_product_fwd_t::execute_forward(
28 const exec_ctx_t &ctx) const {
29 using namespace memory_tracking::names;
30
31 gemm_exec_args_t gemm_args;
32 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC);
33 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
34 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST);
35 gemm_args.bias = &CTX_IN_STORAGE(DNNL_ARG_BIAS);
36 memory_storage_t *a0
37 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
38
39 memory_storage_t *b0
40 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
41
42 memory_storage_t *c0
43 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
44
45 gemm_args.a_zero_point = b0;
46 gemm_args.b_zero_point = a0;
47 gemm_args.c_zero_point = c0;
48 gemm_args.a_scales
49 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
50 gemm_args.b_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
51 gemm_args.c_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
52 gemm_args.exec_args = ctx.args();
53
54 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args);
55
56 nested_scratchpad_t ns(ctx, key_nested, gemm_);
57 gemm_ctx.set_scratchpad_grantor(ns.grantor());
58
59 status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
60
61 if (gemm_exec_status != status::success) return gemm_exec_status;
62
63 return status::success;
64}
65
66status_t gemm_inner_product_bwd_data_t::execute_backward_data(
67 const exec_ctx_t &ctx) const {
68 using namespace memory_tracking::names;
69
70 gemm_exec_args_t gemm_args;
71 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
72 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
73 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
74
75 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args);
76
77 nested_scratchpad_t ns(ctx, key_nested, gemm_);
78 gemm_ctx.set_scratchpad_grantor(ns.grantor());
79
80 status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
81 if (gemm_exec_status != status::success) return gemm_exec_status;
82
83 return status::success;
84}
85
86status_t gemm_inner_product_bwd_weights_t::execute_backward_weights(
87 const exec_ctx_t &ctx) const {
88 using namespace memory_tracking::names;
89
90 gemm_exec_args_t gemm_args;
91 if (pd()->wei_tr()) {
92 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC);
93 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
94 } else {
95 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
96 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_SRC);
97 }
98 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS);
99 if (!pd()->reduction_pd_)
100 gemm_args.sum_ab = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS);
101 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args);
102
103 nested_scratchpad_t ns(ctx, key_nested_multiple, gemm_);
104 gemm_ctx.set_scratchpad_grantor(ns.grantor());
105
106 status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
107 if (gemm_exec_status != status::success) return gemm_exec_status;
108
109 if (pd()->with_bias() && pd()->reduction_pd_) {
110 auto diff_dst = ctx.input(DNNL_ARG_DIFF_DST);
111 auto diff_bia = ctx.output(DNNL_ARG_DIFF_BIAS);
112 exec_args_t r_args;
113 r_args[DNNL_ARG_SRC] = memory_arg_t {diff_dst, true};
114 r_args[DNNL_ARG_DST] = memory_arg_t {diff_bia, false};
115 exec_ctx_t r_ctx(ctx, std::move(r_args));
116 nested_scratchpad_t ns(ctx, key_nested_multiple + 1, reduction_);
117 r_ctx.set_scratchpad_grantor(ns.grantor());
118 reduction_->execute(r_ctx);
119 }
120
121 return status::success;
122}
123
124} // namespace ocl
125} // namespace gpu
126} // namespace impl
127} // namespace dnnl
128