1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_GEMM_GPU_GEMM_EXEC_TYPES_HPP |
18 | #define GPU_GEMM_GPU_GEMM_EXEC_TYPES_HPP |
19 | |
#include <cassert>
#include <utility>

#include "common/memory_storage.hpp"
#include "common/stream.hpp"
22 | |
// GEMM-style operand aliases: A, B and C are spelled in terms of the generic
// DNNL_ARG_SRC, DNNL_ARG_WEIGHTS and DNNL_ARG_DST argument indices, so C is
// the destination (output) argument.
#define DNNL_ARG_A DNNL_ARG_SRC
#define DNNL_ARG_B DNNL_ARG_WEIGHTS
#define DNNL_ARG_C DNNL_ARG_DST
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | |
// Yields the memory_storage_t that `ctx.args().argument` points to, or
// dnnl::impl::memory_storage_t::empty_storage() when that pointer is null,
// so callers can treat "argument not provided" uniformly.
//
// NOTE: the expansion refers to `ctx` by name; a variable called `ctx` whose
// args() exposes a member named `argument` (e.g. a gemm_exec_ctx_t) must be
// in scope wherever this macro is used.
#define GEMM_CTX_ARG_STORAGE(argument) \
    ((ctx.args().argument != nullptr) \
                    ? *(ctx.args().argument) \
                    : dnnl::impl::memory_storage_t::empty_storage())
34 | |
35 | struct gemm_exec_args_t { |
36 | memory_storage_t *a = nullptr; |
37 | memory_storage_t *b = nullptr; |
38 | memory_storage_t *c = nullptr; |
39 | memory_storage_t *a_zero_point = nullptr; |
40 | memory_storage_t *b_zero_point = nullptr; |
41 | memory_storage_t *c_zero_point = nullptr; |
42 | memory_storage_t *bias = nullptr; |
43 | memory_storage_t *a_scales = nullptr; |
44 | memory_storage_t *b_scales = nullptr; |
45 | memory_storage_t *c_scales = nullptr; |
46 | memory_storage_t *sum_ab = nullptr; |
47 | exec_args_t exec_args; |
48 | }; |
49 | |
50 | struct gemm_exec_ctx_t { |
51 | gemm_exec_ctx_t(stream_t *stream, const gemm_exec_args_t &args, |
52 | const gemm_desc_t *gemm_desc = nullptr) |
53 | : stream_(stream), args_(args), gemm_desc_(gemm_desc) {} |
54 | gemm_exec_ctx_t(const exec_ctx_t &other, const gemm_exec_args_t &args, |
55 | const gemm_desc_t *gemm_desc = nullptr) |
56 | : stream_(other.stream()) |
57 | , args_(args) |
58 | , gemm_desc_(gemm_desc) |
59 | , resource_mapper_(other.get_resource_mapper()) |
60 | , scratchpad_grantor_(other.grantor_handle()) {} |
61 | |
62 | stream_t *stream() const { return stream_; } |
63 | const gemm_exec_args_t &args() const { return args_; } |
64 | const gemm_desc_t *desc() const { return gemm_desc_; } |
65 | |
66 | void set_scratchpad_grantor( |
67 | const memory_tracking::grantor_t *scratchpad_grantor) { |
68 | scratchpad_grantor_ = scratchpad_grantor; |
69 | } |
70 | |
71 | const memory_tracking::grantor_t &get_scratchpad_grantor() const { |
72 | assert(scratchpad_grantor_); |
73 | return *scratchpad_grantor_; |
74 | } |
75 | |
76 | const resource_mapper_t *get_resource_mapper() const { |
77 | assert(resource_mapper_); |
78 | return resource_mapper_; |
79 | } |
80 | |
81 | void set_resource_mapper(const resource_mapper_t *resource_mapper) { |
82 | resource_mapper_ = resource_mapper; |
83 | } |
84 | |
85 | private: |
86 | stream_t *stream_; |
87 | gemm_exec_args_t args_; |
88 | const gemm_desc_t *gemm_desc_ = nullptr; |
89 | const resource_mapper_t *resource_mapper_ = nullptr; |
90 | const memory_tracking::grantor_t *scratchpad_grantor_ = nullptr; |
91 | }; |
92 | |
93 | } // namespace gpu |
94 | } // namespace impl |
95 | } // namespace dnnl |
96 | |
97 | #endif |
98 | |