1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gemm_matmul.hpp" |
18 | |
19 | #include "gpu/gemm/gpu_gemm.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | status_t gemm_matmul_t::execute(const exec_ctx_t &ctx) const { |
27 | using namespace memory_tracking::names; |
28 | |
29 | const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC); |
30 | const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS); |
31 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST); |
32 | const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS); |
33 | |
34 | memory_storage_t *a0 |
35 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
36 | |
37 | memory_storage_t *b0 |
38 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); |
39 | |
40 | memory_storage_t *c0 |
41 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
42 | |
43 | gemm_exec_args_t gemm_args; |
44 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC); |
45 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
46 | gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST); |
47 | gemm_args.bias = &CTX_IN_STORAGE(DNNL_ARG_BIAS); |
48 | |
49 | // Note: we have to swap `a` and `b` zero-point arguments because, |
50 | // - gemm primitive is created with row major desc, |
51 | // - parameters to gemm are passed as row major |
52 | // - but gemm implementation assumes column major |
53 | gemm_args.a_zero_point = b0; |
54 | gemm_args.b_zero_point = a0; |
55 | gemm_args.c_zero_point = c0; |
56 | gemm_args.a_scales |
57 | = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
58 | gemm_args.b_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
59 | gemm_args.c_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
60 | gemm_args.exec_args = ctx.args(); |
61 | auto gemm_desc = create_gemm_desc(src_d.md_, weights_d.md_, dst_d.md_, |
62 | bia_d.md_, pd()->desc()->accum_data_type, ctx.stream()->engine()); |
63 | |
64 | gemm_exec_ctx_t gemm_ctx(ctx, gemm_args, &gemm_desc); |
65 | |
66 | nested_scratchpad_t ns(ctx, key_nested, gemm_); |
67 | gemm_ctx.set_scratchpad_grantor(ns.grantor()); |
68 | |
69 | status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx); |
70 | if (gemm_exec_status != status::success) return gemm_exec_status; |
71 | |
72 | return status::success; |
73 | } |
74 | |
75 | } // namespace ocl |
76 | } // namespace gpu |
77 | } // namespace impl |
78 | } // namespace dnnl |
79 | |