1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GEMM_GPU_GEMM_HPP
18#define GPU_GEMM_GPU_GEMM_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/gemm/gpu_gemm_exec_types.hpp"
23#include "gpu/gpu_primitive.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28
29struct gpu_gemm_t : public gpu_primitive_t {
30 using gpu_primitive_t::gpu_primitive_t;
31 virtual status_t execute(const gemm_exec_ctx_t &ctx) const = 0;
32 status_t execute(const exec_ctx_t &ctx) const override {
33 gemm_exec_args_t gemm_args;
34 // TODO: we have to swap a and b because
35 // - gemm primitive is created with row major desc,
36 // - parameters to gemm are passed as row major
37 // - but gemm implementation assumes column major
38 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_B);
39 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_A);
40 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_C);
41 gemm_args.a_zero_point
42 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_B);
43 gemm_args.b_zero_point
44 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_A);
45 gemm_args.c_zero_point
46 = &CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_C);
47 gemm_args.a_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_B);
48 gemm_args.b_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_A);
49 gemm_args.c_scales = &CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_C);
50 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args);
51 return execute(gemm_ctx);
52 }
53};
54
55inline const gpu_gemm_t *gpu_gemm(const std::shared_ptr<primitive_t> &p) {
56 return utils::downcast<gpu_gemm_t *>(p.get());
57}
58
59} // namespace gpu
60} // namespace impl
61} // namespace dnnl
62
63#endif
64