/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_GEMM_PD_HPP
#define GPU_JIT_GEMM_PD_HPP

#include <cassert>
#include <vector>

#include "common/c_types_map.hpp"
#include "gpu/gpu_gemm_pd.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp"
#include "gpu/jit/jit_eltwise_injector.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Maximum number of post-ops supported by the jit GEMM kernels.
#define GEMM_MAX_PO 36

struct jit_gemm_pd_t : public gpu_gemm_pd_t {
    using gpu_gemm_pd_t::gpu_gemm_pd_t;

    // Where a post-op gets its second input tensor, if any:
    //   none   - no extra input (sum, eltwise)
    //   scales - a scales buffer; index holds the DNNL_ARG_* it belongs to
    //   bias   - the GEMM bias tensor
    //   binary - a user-supplied binary post-op; index holds its original
    //            position in the attribute's post-op list
    struct binary_src_t {
        enum type_t { none, scales, bias, binary } type;
        int index;

        binary_src_t(type_t type_, int index_) : type(type_), index(index_) {}
    };

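    // Worked example (a sketch, not normative): for user post-ops
    // [eltwise(relu), binary(add)] plus non-default weights scales,
    // init_post_ops() below leaves
    //     post_ops_    = [binary_mul(wei scales), eltwise, binary_add]
    //     binary_srcs_ = [{scales, DNNL_ARG_WEIGHTS}, {none, 0}, {binary, 1}]
    // i.e. binary_srcs_[i] always describes post_ops_.entry_[i], while the
    // stored index of a user binary refers to its original position.
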
    status_t init_post_ops() {
        using namespace primitive_kind;
        using namespace alg_kind;
        using namespace data_type;

        const auto d = desc();

        // Examine post-ops and remember binary srcs.
        post_ops_ = attr()->post_ops_;
        // Reserve room for entries that may be added below: weights scales,
        // source scales, bias (prepended) and destination scales (appended).
        binary_srcs_.reserve(post_ops_.len() + 4);

        bool ok = true;

        for (int i = 0; i < post_ops_.len(); i++) {
            const auto &e = post_ops_.entry_[i];
            switch (e.kind) {
                case binary:
                    // Binary post-ops must use an algorithm and a plain
                    // layout the kernel generator can handle.
                    ok &= gemm_kernel_generator_t<ngen::HW::Unknown>::
                                    supportedBinaryOp(e.binary.alg)
                            && is_md_gemm_compatible_plain_format(
                                    &e.binary.src1_desc);
                    binary_srcs_.push_back(
                            binary_src_t {binary_src_t::binary, int(i)});
                    break;
                case sum:
                    // At most one sum post-op is supported.
                    ok &= !with_sum_;
                    with_sum_ = true;
                    sum_at_begin_ = (i == 0);
                    binary_srcs_.push_back(
                            binary_src_t {binary_src_t::none, 0});
                    beta_ = e.sum.scale;
                    break;
                case eltwise:
                    ok &= jit_eltwise_injector_f32_is_supported(e.eltwise.alg);
                    binary_srcs_.push_back(
                            binary_src_t {binary_src_t::none, 0});
                    break;
                default: return status::unimplemented;
            }
        }

        if (!ok) return status::unimplemented;

        // If scales are present, convert them and any bias to binary post-ops.
        // Also convert bias to a binary post-op if dst zero points are present.
        const auto *wei_scales = &attr()->scales_.get(DNNL_ARG_WEIGHTS);
        const auto *src_scales = &attr()->scales_.get(DNNL_ARG_SRC);
        const auto *c_scales = &attr()->scales_.get(DNNL_ARG_DST);

        bias_via_binary_ = (desc()->bias_type() != data_type::undef)
                && (!wei_scales->has_default_values()
                        || !src_scales->has_default_values()
                        || !attr()->zero_points_.has_default_values(
                                DNNL_ARG_DST));
        if (bias_via_binary_) {
            auto status = post_ops_.prepend_binary(binary_add, &d->bias_desc);
            if (status != status::success) return status;
            binary_srcs_.insert(
                    binary_srcs_.begin(), binary_src_t {binary_src_t::bias, 0});
        }

        if (!wei_scales->has_default_values()) {
            const auto &mask = wei_scales->mask_;
            // Accept only a common scale or per-channel scales along the
            // last dimension of C; bail out immediately otherwise.
            ok = ok && (mask == 0 || mask == (1 << (d->c_desc.ndims - 1)));
            if (!ok) return status::unimplemented;

            dim_t dims = {(mask > 0) ? d->m() : 1};
            memory_desc_init_by_tag(
                    wei_scales_md, 1, &dims, f32, format_tag::a);

            auto status = post_ops_.prepend_binary(binary_mul, &wei_scales_md);
            if (status != status::success) return status;

            binary_srcs_.insert(binary_srcs_.begin(),
                    binary_src_t {binary_src_t::scales, DNNL_ARG_WEIGHTS});
        }
        if (!src_scales->has_default_values()) {
            // Only a single common source scale is supported.
            ok = ok && (src_scales->mask_ == 0);
            if (!ok) return status::unimplemented;

            dim_t dims = {1};
            memory_desc_init_by_tag(
                    src_scales_md, 1, &dims, f32, format_tag::a);

            auto status = post_ops_.prepend_binary(binary_mul, &src_scales_md);
            if (status != status::success) return status;

            binary_srcs_.insert(binary_srcs_.begin(),
                    binary_src_t {binary_src_t::scales, DNNL_ARG_SRC});
        }
        if (!c_scales->has_default_values()) {
            // Only a single common destination scale is supported; it is
            // applied last, as a division.
            ok = ok && (c_scales->mask_ == 0);
            if (!ok) return status::unimplemented;

            dim_t dims = {1};
            memory_desc_init_by_tag(c_scales_md, 1, &dims, f32, format_tag::a);

            auto status = post_ops_.append_binary(binary_div, &c_scales_md);
            if (status != status::success) return status;

            binary_srcs_.push_back(
                    binary_src_t {binary_src_t::scales, DNNL_ARG_DST});
        }
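
        // Resulting chain (a sketch, assuming bias and all scales present):
        //     binary_mul(src scales) -> binary_mul(wei scales)
        //     -> binary_add(bias) -> <user post-ops> -> binary_div(dst scales)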

        return status::success;
    }

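    // Minimal usage sketch (hypothetical derived pd, for illustration only;
    // the concrete GEMM pds in this directory follow the same pattern):
    //
    //     struct my_gemm_pd_t : public jit_gemm_pd_t {
    //         using jit_gemm_pd_t::jit_gemm_pd_t;
    //         status_t init(engine_t *engine) {
    //             CHECK(init_post_ops()); // rejects unsupported attributes
    //             // ... kernel selection, etc.
    //             return status::success;
    //         }
    //     };
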
    // Leading dimension of the idx-th post-op source (same convention as
    // gemm_desc_t::get_ld()).
    dim_t ld_binary(int idx) const {
        switch (binary_srcs_[idx].type) {
            case binary_src_t::binary: {
                const auto &entry = post_ops_.entry_[idx];
                assert(entry.kind == primitive_kind::binary);
                return gemm_desc_t::get_ld(entry.binary.src1_desc);
            }
            case binary_src_t::bias: return desc()->ld_bias();
            default: return 1;
        }
    }

    // Batch stride of the idx-th post-op source; sources other than user
    // binaries are broadcast across batch dimensions (stride 0).
    dim_t stride_binary(int idx, int stride = 0) const {
        switch (binary_srcs_[idx].type) {
            case binary_src_t::binary: {
                const auto &entry = post_ops_.entry_[idx];
                assert(entry.kind == primitive_kind::binary);
                return gemm_desc_t::get_stride(entry.binary.src1_desc, stride);
            }
            default: return 0;
        }
    }
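
    // Sketch of a hypothetical caller binding post-op sources to kernel
    // arguments via these helpers (names and stride semantics assumed,
    // illustration only):
    //
    //     for (int i = 0; i < pd->post_ops()->len(); i++) {
    //         dim_t ld = pd->ld_binary(i);        // leading dimension
    //         dim_t bs = pd->stride_binary(i, 0); // outermost batch stride
    //         // ... pass ld/bs alongside the source buffer
    //     }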

    const post_ops_t *post_ops() const { return &post_ops_; }
    const std::vector<binary_src_t> &binary_srcs() const {
        return binary_srcs_;
    }

    float beta_ = 0.0f;

    bool with_sum_ = false;
    bool sum_at_begin_ = false;

    bool bias_via_binary_ = false;

    // Post-ops after scale/bias conversion; binary_srcs_[i] describes the
    // extra input (if any) of post_ops_.entry_[i].
    post_ops_t post_ops_;
    std::vector<binary_src_t> binary_srcs_;

    memory_desc_t wei_scales_md, src_scales_md, c_scales_md;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif