1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GEMM_GPU_GEMM_EXEC_TYPES_HPP
18#define GPU_GEMM_GPU_GEMM_EXEC_TYPES_HPP
19
#include <cassert>
#include <utility>

#include "common/memory_storage.hpp"
#include "common/stream.hpp"
22
// GEMM-flavoured aliases for the generic primitive argument indices:
// operand A is passed as the source, B as the weights and C as the
// destination.
#define DNNL_ARG_A DNNL_ARG_SRC
#define DNNL_ARG_B DNNL_ARG_WEIGHTS
#define DNNL_ARG_C DNNL_ARG_DST
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30
// Evaluates to the memory storage referenced by the pointer member
// `argument` of `ctx.args()`. If that pointer is null, it evaluates to
// dnnl::impl::memory_storage_t::empty_storage() instead, so callers never
// dereference a null storage.
// Requires a variable named `ctx` in the calling scope whose args() exposes
// memory_storage_t pointer members (e.g. the fields of gemm_exec_args_t).
#define GEMM_CTX_ARG_STORAGE(argument) \
    (!ctx.args().argument ? dnnl::impl::memory_storage_t::empty_storage() \
                          : *(ctx.args().argument))
34
35struct gemm_exec_args_t {
36 memory_storage_t *a = nullptr;
37 memory_storage_t *b = nullptr;
38 memory_storage_t *c = nullptr;
39 memory_storage_t *a_zero_point = nullptr;
40 memory_storage_t *b_zero_point = nullptr;
41 memory_storage_t *c_zero_point = nullptr;
42 memory_storage_t *bias = nullptr;
43 memory_storage_t *a_scales = nullptr;
44 memory_storage_t *b_scales = nullptr;
45 memory_storage_t *c_scales = nullptr;
46 memory_storage_t *sum_ab = nullptr;
47 exec_args_t exec_args;
48};
49
50struct gemm_exec_ctx_t {
51 gemm_exec_ctx_t(stream_t *stream, const gemm_exec_args_t &args,
52 const gemm_desc_t *gemm_desc = nullptr)
53 : stream_(stream), args_(args), gemm_desc_(gemm_desc) {}
54 gemm_exec_ctx_t(const exec_ctx_t &other, const gemm_exec_args_t &args,
55 const gemm_desc_t *gemm_desc = nullptr)
56 : stream_(other.stream())
57 , args_(args)
58 , gemm_desc_(gemm_desc)
59 , resource_mapper_(other.get_resource_mapper())
60 , scratchpad_grantor_(other.grantor_handle()) {}
61
62 stream_t *stream() const { return stream_; }
63 const gemm_exec_args_t &args() const { return args_; }
64 const gemm_desc_t *desc() const { return gemm_desc_; }
65
66 void set_scratchpad_grantor(
67 const memory_tracking::grantor_t *scratchpad_grantor) {
68 scratchpad_grantor_ = scratchpad_grantor;
69 }
70
71 const memory_tracking::grantor_t &get_scratchpad_grantor() const {
72 assert(scratchpad_grantor_);
73 return *scratchpad_grantor_;
74 }
75
76 const resource_mapper_t *get_resource_mapper() const {
77 assert(resource_mapper_);
78 return resource_mapper_;
79 }
80
81 void set_resource_mapper(const resource_mapper_t *resource_mapper) {
82 resource_mapper_ = resource_mapper;
83 }
84
85private:
86 stream_t *stream_;
87 gemm_exec_args_t args_;
88 const gemm_desc_t *gemm_desc_ = nullptr;
89 const resource_mapper_t *resource_mapper_ = nullptr;
90 const memory_tracking::grantor_t *scratchpad_grantor_ = nullptr;
91};
92
93} // namespace gpu
94} // namespace impl
95} // namespace dnnl
96
97#endif
98