1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_PD_HPP |
18 | #define GPU_JIT_GEMM_PD_HPP |
19 | |
20 | #include <vector> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "gpu/gpu_gemm_pd.hpp" |
24 | #include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | #define GEMM_MAX_PO 36 |
32 | |
33 | struct jit_gemm_pd_t : public gpu_gemm_pd_t { |
34 | using gpu_gemm_pd_t::gpu_gemm_pd_t; |
35 | |
36 | struct binary_src_t { |
37 | enum type_t { none, scales, bias, binary } type; |
38 | int index; |
39 | |
40 | binary_src_t(type_t type_, int index_) : type(type_), index(index_) {} |
41 | }; |
42 | |
43 | status_t init_post_ops() { |
44 | using namespace primitive_kind; |
45 | using namespace alg_kind; |
46 | using namespace data_type; |
47 | |
48 | const auto d = desc(); |
49 | |
50 | // Examine post-ops and remember binary srcs. |
51 | post_ops_ = attr()->post_ops_; |
52 | binary_srcs_.reserve(post_ops_.len() + 4); |
53 | |
54 | bool ok = true; |
55 | |
56 | for (int i = 0; i < post_ops_.len(); i++) { |
57 | const auto &e = post_ops_.entry_[i]; |
58 | switch (e.kind) { |
59 | case binary: |
60 | ok &= gemm_kernel_generator_t<ngen::HW::Unknown>:: |
61 | supportedBinaryOp(e.binary.alg) |
62 | && is_md_gemm_compatible_plain_format( |
63 | &e.binary.src1_desc); |
64 | binary_srcs_.push_back( |
65 | binary_src_t {binary_src_t::binary, int(i)}); |
66 | break; |
67 | case sum: |
68 | ok &= !with_sum_; |
69 | with_sum_ = true; |
70 | sum_at_begin_ = (i == 0); |
71 | binary_srcs_.push_back( |
72 | binary_src_t {binary_src_t::none, 0}); |
73 | beta_ = e.sum.scale; |
74 | break; |
75 | case eltwise: |
76 | ok &= jit_eltwise_injector_f32_is_supported(e.eltwise.alg); |
77 | binary_srcs_.push_back( |
78 | binary_src_t {binary_src_t::none, 0}); |
79 | break; |
80 | default: return status::unimplemented; |
81 | } |
82 | } |
83 | |
84 | if (!ok) return status::unimplemented; |
85 | |
86 | // If scales are present, convert them and any bias to binary post-ops. |
87 | // Also convert bias to binary post-op if dst zp are present. |
88 | const auto *wei_scales = &attr()->scales_.get(DNNL_ARG_WEIGHTS); |
89 | const auto *src_scales = &attr()->scales_.get(DNNL_ARG_SRC); |
90 | const auto *c_scales = &attr()->scales_.get(DNNL_ARG_DST); |
91 | |
92 | bias_via_binary_ = (desc()->bias_type() != data_type::undef) |
93 | && (!wei_scales->has_default_values() |
94 | || !src_scales->has_default_values() |
95 | || !attr()->zero_points_.has_default_values( |
96 | DNNL_ARG_DST)); |
97 | if (bias_via_binary_) { |
98 | auto status = post_ops_.prepend_binary(binary_add, &d->bias_desc); |
99 | if (status != status::success) return status; |
100 | binary_srcs_.insert( |
101 | binary_srcs_.begin(), binary_src_t {binary_src_t::bias, 0}); |
102 | } |
103 | |
104 | if (!wei_scales->has_default_values()) { |
105 | const auto &mask = wei_scales->mask_; |
106 | ok = ok && (mask == 0 || mask == (1 << (d->c_desc.ndims - 1))); |
107 | |
108 | dim_t dims = {(mask > 0) ? d->m() : 1}; |
109 | memory_desc_init_by_tag( |
110 | wei_scales_md, 1, &dims, f32, format_tag::a); |
111 | |
112 | auto status = post_ops_.prepend_binary(binary_mul, &wei_scales_md); |
113 | if (status != status::success) return status; |
114 | |
115 | binary_srcs_.insert(binary_srcs_.begin(), |
116 | binary_src_t {binary_src_t::scales, DNNL_ARG_WEIGHTS}); |
117 | } |
118 | if (!src_scales->has_default_values()) { |
119 | ok = ok && (src_scales->mask_ == 0); |
120 | |
121 | dim_t dims = {1}; |
122 | memory_desc_init_by_tag( |
123 | src_scales_md, 1, &dims, f32, format_tag::a); |
124 | |
125 | auto status = post_ops_.prepend_binary(binary_mul, &src_scales_md); |
126 | if (status != status::success) return status; |
127 | |
128 | binary_srcs_.insert(binary_srcs_.begin(), |
129 | binary_src_t {binary_src_t::scales, DNNL_ARG_SRC}); |
130 | } |
131 | if (!c_scales->has_default_values()) { |
132 | ok = ok && (c_scales->mask_ == 0); |
133 | |
134 | dim_t dims = {1}; |
135 | memory_desc_init_by_tag(c_scales_md, 1, &dims, f32, format_tag::a); |
136 | |
137 | auto status = post_ops_.append_binary(binary_div, &c_scales_md); |
138 | if (status != status::success) return status; |
139 | |
140 | binary_srcs_.push_back( |
141 | binary_src_t {binary_src_t::scales, DNNL_ARG_DST}); |
142 | } |
143 | |
144 | return status::success; |
145 | } |
146 | |
147 | dim_t ld_binary(int idx) const { |
148 | switch (binary_srcs_[idx].type) { |
149 | case binary_src_t::binary: { |
150 | const auto &entry = post_ops_.entry_[idx]; |
151 | assert(entry.kind == primitive_kind::binary); |
152 | return gemm_desc_t::get_ld(entry.binary.src1_desc); |
153 | } |
154 | case binary_src_t::bias: return desc()->ld_bias(); |
155 | default: return 1; |
156 | } |
157 | } |
158 | |
159 | dim_t stride_binary(int idx, int stride = 0) const { |
160 | switch (binary_srcs_[idx].type) { |
161 | case binary_src_t::binary: { |
162 | const auto &entry = post_ops_.entry_[idx]; |
163 | assert(entry.kind == primitive_kind::binary); |
164 | return gemm_desc_t::get_stride(entry.binary.src1_desc, stride); |
165 | } |
166 | default: return 0; |
167 | } |
168 | } |
169 | |
170 | const post_ops_t *post_ops() const { return &post_ops_; } |
171 | const std::vector<binary_src_t> &binary_srcs() const { |
172 | return binary_srcs_; |
173 | } |
174 | |
175 | float beta_ = 0.0f; |
176 | |
177 | bool with_sum_ = false; |
178 | bool sum_at_begin_ = false; |
179 | |
180 | bool bias_via_binary_ = false; |
181 | |
182 | post_ops_t post_ops_; |
183 | std::vector<binary_src_t> binary_srcs_; |
184 | |
185 | memory_desc_t wei_scales_md, src_scales_md, c_scales_md; |
186 | }; |
187 | |
188 | } // namespace jit |
189 | } // namespace gpu |
190 | } // namespace impl |
191 | } // namespace dnnl |
192 | |
193 | #endif |