1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gemm_inner_product.hpp" |
18 | |
19 | #include "gpu/gemm/gpu_gemm.hpp" |
20 | #include "gpu/ocl/ocl_stream.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace ocl { |
26 | |
27 | status_t gemm_inner_product_fwd_t::execute_forward( |
28 | const exec_ctx_t &ctx) const { |
29 | using namespace memory_tracking::names; |
30 | |
31 | gemm_exec_args_t gemm_args; |
32 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC); |
33 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
34 | gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST); |
35 | gemm_args.bias = &CTX_IN_STORAGE(DNNL_ARG_BIAS); |
36 | memory_storage_t *a0 |
37 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
38 | |
39 | memory_storage_t *b0 |
40 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); |
41 | |
42 | memory_storage_t *c0 |
43 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
44 | |
45 | gemm_args.a_zero_point = b0; |
46 | gemm_args.b_zero_point = a0; |
47 | gemm_args.c_zero_point = c0; |
48 | gemm_args.a_scales |
49 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
50 | gemm_args.b_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
51 | gemm_args.c_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
52 | gemm_args.exec_args = ctx.args(); |
53 | |
54 | gemm_exec_ctx_t gemm_ctx(ctx, gemm_args); |
55 | |
56 | nested_scratchpad_t ns(ctx, key_nested, gemm_); |
57 | gemm_ctx.set_scratchpad_grantor(ns.grantor()); |
58 | |
59 | status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx); |
60 | |
61 | if (gemm_exec_status != status::success) return gemm_exec_status; |
62 | |
63 | return status::success; |
64 | } |
65 | |
66 | status_t gemm_inner_product_bwd_data_t::execute_backward_data( |
67 | const exec_ctx_t &ctx) const { |
68 | using namespace memory_tracking::names; |
69 | |
70 | gemm_exec_args_t gemm_args; |
71 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
72 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
73 | gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
74 | |
75 | gemm_exec_ctx_t gemm_ctx(ctx, gemm_args); |
76 | |
77 | nested_scratchpad_t ns(ctx, key_nested, gemm_); |
78 | gemm_ctx.set_scratchpad_grantor(ns.grantor()); |
79 | |
80 | status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx); |
81 | if (gemm_exec_status != status::success) return gemm_exec_status; |
82 | |
83 | return status::success; |
84 | } |
85 | |
86 | status_t gemm_inner_product_bwd_weights_t::execute_backward_weights( |
87 | const exec_ctx_t &ctx) const { |
88 | using namespace memory_tracking::names; |
89 | |
90 | gemm_exec_args_t gemm_args; |
91 | if (pd()->wei_tr()) { |
92 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC); |
93 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
94 | } else { |
95 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
96 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_SRC); |
97 | } |
98 | gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS); |
99 | if (!pd()->reduction_pd_) |
100 | gemm_args.sum_ab = &CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS); |
101 | gemm_exec_ctx_t gemm_ctx(ctx, gemm_args); |
102 | |
103 | nested_scratchpad_t ns(ctx, key_nested_multiple, gemm_); |
104 | gemm_ctx.set_scratchpad_grantor(ns.grantor()); |
105 | |
106 | status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx); |
107 | if (gemm_exec_status != status::success) return gemm_exec_status; |
108 | |
109 | if (pd()->with_bias() && pd()->reduction_pd_) { |
110 | auto diff_dst = ctx.input(DNNL_ARG_DIFF_DST); |
111 | auto diff_bia = ctx.output(DNNL_ARG_DIFF_BIAS); |
112 | exec_args_t r_args; |
113 | r_args[DNNL_ARG_SRC] = memory_arg_t {diff_dst, true}; |
114 | r_args[DNNL_ARG_DST] = memory_arg_t {diff_bia, false}; |
115 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
116 | nested_scratchpad_t ns(ctx, key_nested_multiple + 1, reduction_); |
117 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
118 | reduction_->execute(r_ctx); |
119 | } |
120 | |
121 | return status::success; |
122 | } |
123 | |
124 | } // namespace ocl |
125 | } // namespace gpu |
126 | } // namespace impl |
127 | } // namespace dnnl |
128 | |