1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gemm_matmul.hpp"
18
19#include "gpu/gemm/gpu_gemm.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace gpu {
24namespace ocl {
25
26status_t gemm_matmul_t::execute(const exec_ctx_t &ctx) const {
27 using namespace memory_tracking::names;
28
29 const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC);
30 const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS);
31 const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST);
32 const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS);
33
34 memory_storage_t *a0
35 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
36
37 memory_storage_t *b0
38 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
39
40 memory_storage_t *c0
41 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
42
43 gemm_exec_args_t gemm_args;
44 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC);
45 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
46 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST);
47 gemm_args.bias = &CTX_IN_STORAGE(DNNL_ARG_BIAS);
48
49 // Note: we have to swap `a` and `b` zero-point arguments because,
50 // - gemm primitive is created with row major desc,
51 // - parameters to gemm are passed as row major
52 // - but gemm implementation assumes column major
53 gemm_args.a_zero_point = b0;
54 gemm_args.b_zero_point = a0;
55 gemm_args.c_zero_point = c0;
56 gemm_args.a_scales
57 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
58 gemm_args.b_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
59 gemm_args.c_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
60 gemm_args.exec_args = ctx.args();
61 auto gemm_desc = create_gemm_desc(src_d.md_, weights_d.md_, dst_d.md_,
62 bia_d.md_, pd()->desc()->accum_data_type, ctx.stream()->engine());
63
64 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args, &gemm_desc);
65
66 nested_scratchpad_t ns(ctx, key_nested, gemm_);
67 gemm_ctx.set_scratchpad_grantor(ns.grantor());
68
69 status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
70 if (gemm_exec_status != status::success) return gemm_exec_status;
71
72 return status::success;
73}
74
75} // namespace ocl
76} // namespace gpu
77} // namespace impl
78} // namespace dnnl
79