1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_SOFTMAX_HPP
18#define GPU_OCL_GEN9_SOFTMAX_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/nstl.hpp"
22#include "common/primitive.hpp"
23#include "gpu/compute/compute.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/gpu_softmax_pd.hpp"
27#include "gpu/ocl/ocl_stream.hpp"
28#include "gpu/ocl/ocl_utils.hpp"
29#include "gpu/primitive_conf.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace ocl {
35
36struct gen9_softmax_fwd_t : public gpu_primitive_t {
37 using gpu_primitive_t::gpu_primitive_t;
38 struct pd_t : public gpu_softmax_fwd_pd_t {
39 using gpu_softmax_fwd_pd_t::gpu_softmax_fwd_pd_t;
40
41 DECLARE_COMMON_PD_T("ocl:gen9", gen9_softmax_fwd_t);
42
43 status_t init(engine_t *engine) {
44 using namespace dnnl::impl::format_tag;
45 auto *compute_engine
46 = utils::downcast<compute::compute_engine_t *>(engine);
47
48 const memory_desc_wrapper src_d(src_md());
49 const memory_desc_wrapper dst_d(dst_md());
50 const auto src_dt = src_d.data_type();
51 const auto dst_dt = dst_d.data_type();
52
53 using namespace data_type;
54 using skip_mask_t = primitive_attr_t::skip_mask_t;
55 is_nhwc = (src_d.matches_one_of_tag(nwc, nhwc, ndhwc)
56 != format_tag::undef);
57 is_blocked = (src_d.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c)
58 != format_tag::undef);
59
60 bool ok = is_fwd() && axis_size() % buffer_size == 0
61 && !memory_desc_ndims_ok(src_md(), dst_md())
62 && axis() == src_d.ndims() - 1
63 && (src_d.is_plain() || is_blocked || is_nhwc)
64 && utils::one_of(src_dt, f32, f16, bf16, u8, s8)
65 && utils::one_of(dst_dt, f32, f16, bf16, u8, s8)
66 && IMPLICATION(utils::one_of(f16, src_dt, dst_dt),
67 compute_engine->mayiuse(
68 compute::device_ext_t::khr_fp16))
69 && attr()->has_default_values(skip_mask_t::scales_runtime)
70 && attr_scales_ok()
71 && set_default_formats() == status::success
72 && compute_engine->mayiuse_sub_group(subgroup_size);
73 if (!ok) return status::unimplemented;
74
75 if (is_blocked && src_md()->dims[1] % subgroup_size != 0) {
76 return status::unimplemented;
77 }
78
79 if (is_nhwc || is_blocked) {
80 group_size = subgroup_size * (axis_size() / buffer_size);
81 } else {
82 group_size = subgroup_size;
83 }
84
85 lws[0] = group_size;
86 lws[1] = lws[2] = 1;
87 gws[0] = utils::array_product(&src_md()->dims[0], ndims() - 1)
88 * group_size;
89 gws[1] = gws[2] = 1;
90
91 auto src_padded_dims = src_d.padded_dims();
92 mb = src_padded_dims[0] * src_padded_dims[2] * src_padded_dims[3]
93 & src_padded_dims[4];
94 return status::success;
95 }
96
97 bool is_nhwc = false;
98 bool is_blocked = false;
99 size_t gws[3] = {};
100 size_t lws[3] = {};
101 size_t block[3] = {};
102 size_t group_size = 0;
103 size_t mb = 0;
104 const int subgroup_size = 16;
105 // 8x16 load and store commands (Vector_Size x Sub_Group_Size)
106 const int buffer_size = 128;
107 };
108
109 status_t init(engine_t *engine) override {
110 if (pd()->has_zero_dim_memory()) return status::success;
111
112 compute::kernel_ctx_t kernel_ctx;
113
114 kernel_ctx.define_int("SOFTMAX_AXIS_IDX", pd()->axis());
115 kernel_ctx.define_int("SOFTMAX_AXIS_SIZE", pd()->axis_size());
116 kernel_ctx.define_int("SOFTMAX_BUF", pd()->buffer_size);
117 kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
118 kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->subgroup_size);
119 kernel_ctx.define_int("MB", pd()->mb);
120 kernel_ctx.define_int("OC_PADDED", pd()->src_md()->padded_dims[1]);
121 kernel_ctx.define_int("OC",
122 pd()->is_blocked ? pd()->subgroup_size
123 : pd()->src_md(0)->padded_dims[1]);
124 kernel_ctx.define_int("IS_NHWC", pd()->is_nhwc);
125 kernel_ctx.define_int("IS_BLOCKED", pd()->is_blocked);
126 kernel_ctx.define_int("IS_FWD", 1);
127 kernel_ctx.add_option("-cl-std=CL2.0");
128 kernel_ctx.define_int("LOGSOFTMAX", pd()->is_logsoftmax());
129 kernel_ctx.define_int("WITH_SRC_SCALES",
130 !pd()->attr()->scales_.get(DNNL_ARG_SRC).has_default_values());
131 kernel_ctx.define_int("WITH_DST_SCALES",
132 !pd()->attr()->scales_.get(DNNL_ARG_DST).has_default_values());
133
134 const memory_desc_wrapper dst_mdw(pd()->dst_md());
135 const memory_desc_wrapper src_mdw(pd()->src_md());
136 const auto dst_md_info = memory_desc_info_t::create(dst_mdw);
137 const auto src_md_info = memory_desc_info_t::create(src_mdw);
138 def_memory_desc_info(kernel_ctx, dst_md_info, "DST");
139 def_memory_desc_info(kernel_ctx, src_md_info, "SRC");
140 kernel_ctx.set_data_type(dst_mdw.data_type());
141 set_offsets(kernel_ctx, pd()->dst_md(), "DATA");
142
143 for (int i = 0; i < 3; ++i)
144 kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);
145
146 create_kernel(engine, &kernel_, "gen9_softmax_fwd", kernel_ctx);
147 if (!kernel_) return status::runtime_error;
148
149 return status::success;
150 }
151
152 status_t execute(const exec_ctx_t &ctx) const override {
153 return execute_generic(ctx);
154 }
155
156protected:
157 status_t execute_generic(const exec_ctx_t &ctx) const;
158 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
159 compute::kernel_t kernel_;
160};
161
162struct gen9_softmax_bwd_t : public gpu_primitive_t {
163 using gpu_primitive_t::gpu_primitive_t;
164 struct pd_t : public gpu_softmax_bwd_pd_t {
165 using gpu_softmax_bwd_pd_t::gpu_softmax_bwd_pd_t;
166
167 DECLARE_COMMON_PD_T("ocl:gen9", gen9_softmax_bwd_t);
168
169 status_t init(engine_t *engine) {
170 using namespace dnnl::impl::format_tag;
171
172 auto *compute_engine
173 = utils::downcast<compute::compute_engine_t *>(engine);
174
175 const memory_desc_wrapper diff_src_d(diff_src_md());
176 const memory_desc_wrapper diff_dst_d(diff_dst_md());
177 const memory_desc_wrapper dst_d(dst_md());
178
179 using namespace data_type;
180 bool ok = !is_fwd() && axis_size() % buffer_size == 0
181 && !memory_desc_ndims_ok(
182 dst_md(), diff_src_md(), diff_dst_md())
183 && axis() == diff_src_d.ndims() - 1
184 && utils::one_of(diff_src_d.data_type(), f32, bf16)
185 && utils::one_of(diff_dst_d.data_type(), f32, bf16)
186 && compute_engine->mayiuse_sub_group(subgroup_size)
187 && attr()->has_default_values()
188 && set_default_formats() == status::success
189 && diff_dst_d.data_type() == dst_d.data_type();
190 if (!ok) return status::unimplemented;
191
192 is_nhwc = (diff_src_d.matches_one_of_tag(nwc, nhwc, ndhwc)
193 != format_tag::undef);
194 is_blk = (diff_src_d.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c)
195 != format_tag::undef);
196 if (is_nhwc || is_blk) {
197 group_size = subgroup_size * (axis_size() / buffer_size);
198 } else {
199 group_size = subgroup_size;
200 }
201 lws[0] = group_size;
202 lws[1] = lws[2] = 1;
203 gws[0] = utils::array_product(
204 &diff_src_md(0)->padded_dims[0], ndims() - 1)
205 * group_size;
206 gws[1] = gws[2] = 1;
207 batches = diff_src_md(0)->padded_dims[0]
208 * diff_src_md(0)->padded_dims[2];
209 return status::success;
210 }
211
212 size_t gws[3] = {};
213 size_t lws[3] = {};
214 size_t block[3] = {};
215 size_t group_size = 0;
216 size_t batches = 0;
217 bool is_nhwc = false;
218 bool is_blk = false;
219 const int subgroup_size = 16;
220 // 8x16 load and store commands (Vector_Size x Sub_Group_Size)
221 const int buffer_size = 128;
222 };
223
224 status_t init(engine_t *engine) override {
225 if (pd()->has_zero_dim_memory()) return status::success;
226
227 compute::kernel_ctx_t kernel_ctx;
228
229 kernel_ctx.define_int("SOFTMAX_AXIS_IDX", pd()->axis());
230 kernel_ctx.define_int("SOFTMAX_AXIS_SIZE", pd()->axis_size());
231 kernel_ctx.define_int("SOFTMAX_BUF", pd()->buffer_size);
232 kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->subgroup_size);
233 kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
234 kernel_ctx.define_int("IS_BWD", 1);
235 kernel_ctx.define_int("IS_16C", pd()->is_blk);
236 kernel_ctx.define_int("BATCH", pd()->batches);
237 kernel_ctx.define_int("IC_WO_PADDING", pd()->diff_src_md(0)->dims[1]);
238 kernel_ctx.define_int(
239 "IC_PADDED", pd()->diff_src_md(0)->padded_dims[1]);
240 kernel_ctx.define_int("IC",
241 pd()->is_blk ? pd()->subgroup_size
242 : pd()->diff_src_md(0)->padded_dims[1]);
243 kernel_ctx.define_int("IS_NHWC", pd()->is_nhwc);
244 kernel_ctx.add_option("-cl-std=CL2.0");
245 kernel_ctx.define_int("LOGSOFTMAX", pd()->is_logsoftmax());
246
247 const memory_desc_wrapper diff_src_mdw(pd()->diff_src_md());
248 const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md());
249 const auto diff_src_md_info = memory_desc_info_t::create(diff_src_mdw);
250 const auto diff_dst_md_info = memory_desc_info_t::create(diff_dst_mdw);
251 def_memory_desc_info(kernel_ctx, diff_src_md_info, "SRC");
252 def_memory_desc_info(kernel_ctx, diff_dst_md_info, "DST");
253 kernel_ctx.set_data_type(pd()->diff_src_md()->data_type);
254 set_offsets(kernel_ctx, *pd()->diff_src_md(), "DATA");
255
256 for (int i = 0; i < 3; ++i)
257 kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);
258
259 create_kernel(engine, &kernel_, "gen9_softmax_bwd", kernel_ctx);
260 if (!kernel_) return status::runtime_error;
261
262 return status::success;
263 }
264
265 status_t execute(const exec_ctx_t &ctx) const override {
266 return execute_generic(ctx);
267 }
268
269protected:
270 status_t execute_generic(const exec_ctx_t &ctx) const;
271 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
272 compute::kernel_t kernel_;
273};
274} // namespace ocl
275} // namespace gpu
276} // namespace impl
277} // namespace dnnl
278
279#endif
280