1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_SOFTMAX_HPP
18#define GPU_OCL_REF_SOFTMAX_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/nstl.hpp"
22#include "common/primitive.hpp"
23#include "gpu/compute/compute.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/gpu_softmax_pd.hpp"
27#include "gpu/ocl/ocl_stream.hpp"
28#include "gpu/ocl/ocl_utils.hpp"
29#include "gpu/primitive_conf.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace ocl {
35
36struct ref_softmax_fwd_t : public gpu_primitive_t {
37 using gpu_primitive_t::gpu_primitive_t;
38 struct pd_t : public gpu_softmax_fwd_pd_t {
39 using gpu_softmax_fwd_pd_t::gpu_softmax_fwd_pd_t;
40
41 DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t);
42
43 status_t init(engine_t *engine) {
44 auto *compute_engine
45 = utils::downcast<compute::compute_engine_t *>(engine);
46
47 const memory_desc_wrapper src_d(src_md());
48 const memory_desc_wrapper dst_d(dst_md());
49 const auto src_dt = src_d.data_type();
50 const auto dst_dt = dst_d.data_type();
51
52 using namespace data_type;
53 using skip_mask_t = primitive_attr_t::skip_mask_t;
54 bool ok = is_fwd() && utils::one_of(src_dt, f32, f16, bf16, u8, s8)
55 && utils::one_of(dst_dt, f32, f16, bf16, u8, s8)
56 && IMPLICATION(utils::one_of(f16, src_dt, dst_dt),
57 compute_engine->mayiuse(
58 compute::device_ext_t::khr_fp16))
59 && compute_engine->mayiuse_sub_group(subgroup_size)
60 && !memory_desc_ndims_ok(src_md(), dst_md())
61 && attr()->has_default_values(skip_mask_t::scales_runtime)
62 && attr_scales_ok()
63 && set_default_formats() == status::success;
64 if (!ok) return status::unimplemented;
65
66 gws[0] = 1;
67 gws[1] = 1;
68 gws[2] = 1;
69
70 lws[0] = 1;
71 lws[1] = 1;
72 lws[2] = 1;
73
74 block[0] = 1;
75 block[1] = 1;
76 block[2] = 1;
77
78 int nelems = axis_size(true);
79
80 if (nelems < subgroup_size) {
81 group_size = subgroup_size = 1;
82 } else if (nelems <= 100) {
83 group_size = subgroup_size * 1;
84 } else if (nelems <= 1000) {
85 group_size = subgroup_size * 2;
86 } else if (nelems <= 2000) {
87 group_size = subgroup_size * 4;
88 } else if (nelems <= 5000) {
89 group_size = subgroup_size * 8;
90 } else {
91 group_size = subgroup_size * 16;
92 }
93
94 for (int i = 0, j = 0; i < src_md()->ndims; ++i) {
95 if (i != desc()->softmax_axis) {
96 auto dim = src_md()->padded_dims[i];
97 gws[j % 3] *= dim;
98 if (j < 3) block[j % 3] = dim;
99 j++;
100 }
101 }
102
103 if (group_size != 1) {
104 lws[0] = group_size;
105 gws[0] *= group_size;
106 }
107
108 return status::success;
109 }
110
111 size_t gws[3] = {};
112 size_t lws[3] = {};
113 size_t block[3] = {};
114 size_t group_size = 0;
115 int subgroup_size = 16;
116 };
117
118 status_t init(engine_t *engine) override {
119 if (pd()->has_zero_dim_memory()) return status::success;
120
121 compute::kernel_ctx_t kernel_ctx;
122
123 const auto *desc = pd()->desc();
124 kernel_ctx.define_int("SOFTMAX_AXIS_IDX", desc->softmax_axis);
125 kernel_ctx.define_int("SOFTMAX_AXIS", pd()->axis_size(true));
126 kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
127 kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->subgroup_size);
128 kernel_ctx.define_int("IS_FWD", 1);
129 kernel_ctx.add_option("-cl-std=CL2.0");
130 kernel_ctx.define_int("LOGSOFTMAX", pd()->is_logsoftmax());
131 kernel_ctx.define_int("WITH_SRC_SCALES",
132 !pd()->attr()->scales_.get(DNNL_ARG_SRC).has_default_values());
133 kernel_ctx.define_int("WITH_DST_SCALES",
134 !pd()->attr()->scales_.get(DNNL_ARG_DST).has_default_values());
135
136 const memory_desc_wrapper dst_mdw(pd()->dst_md());
137 const memory_desc_wrapper src_mdw(pd()->src_md());
138 const auto dst_md_info = memory_desc_info_t::create(dst_mdw);
139 const auto src_md_info = memory_desc_info_t::create(src_mdw);
140 def_memory_desc_info(kernel_ctx, dst_md_info, "DST");
141 def_memory_desc_info(kernel_ctx, src_md_info, "SRC");
142 kernel_ctx.set_data_type(dst_mdw.data_type());
143 set_offsets(kernel_ctx, pd()->dst_md(), "DATA");
144
145 for (int i = 0; i < 3; i++)
146 kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);
147
148 create_kernel(engine, &kernel_, "ref_softmax_fwd_generic", kernel_ctx);
149 if (!kernel_) return status::runtime_error;
150
151 return status::success;
152 }
153
154 status_t execute(const exec_ctx_t &ctx) const override {
155 return execute_generic(ctx);
156 }
157
158protected:
159 status_t execute_generic(const exec_ctx_t &ctx) const;
160 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
161 compute::kernel_t kernel_;
162};
163
164struct ref_softmax_bwd_t : public gpu_primitive_t {
165 using gpu_primitive_t::gpu_primitive_t;
166 struct pd_t : public gpu_softmax_bwd_pd_t {
167 using gpu_softmax_bwd_pd_t::gpu_softmax_bwd_pd_t;
168
169 DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t);
170
171 status_t init(engine_t *engine) {
172 auto *compute_engine
173 = utils::downcast<compute::compute_engine_t *>(engine);
174
175 const memory_desc_wrapper diff_dst_d(diff_dst_md());
176 const memory_desc_wrapper diff_src_d(diff_src_md());
177 const memory_desc_wrapper dst_d(dst_md());
178
179 using namespace data_type;
180 bool ok = !is_fwd()
181 && utils::one_of(diff_src_d.data_type(), f32, bf16)
182 && utils::one_of(diff_dst_d.data_type(), f32, bf16)
183 && compute_engine->mayiuse_sub_group(16)
184 && !memory_desc_ndims_ok(
185 dst_md(), diff_src_md(), diff_dst_md())
186 && attr()->has_default_values()
187 && set_default_formats() == status::success
188 && diff_dst_d.data_type() == dst_d.data_type();
189 if (!ok) return status::unimplemented;
190
191 gws[0] = 1;
192 gws[1] = 1;
193 gws[2] = 1;
194
195 lws[0] = 1;
196 lws[1] = 1;
197 lws[2] = 1;
198
199 block[0] = 1;
200 block[1] = 1;
201 block[2] = 1;
202
203 for (int i = 0, j = 0; i < dst_d.ndims(); ++i) {
204 if (i != axis()) {
205 auto dim = dst_d.padded_dims()[i];
206 gws[j % 3] *= dim;
207 if (j < 3) block[j % 3] = dim;
208 j++;
209 }
210 }
211
212 int nelems = axis_size(true);
213 if (nelems <= 100) {
214 group_size = 16;
215 } else if (nelems <= 1000) {
216 group_size = 32;
217 } else if (nelems <= 2000) {
218 group_size = 64;
219 } else if (nelems <= 5000) {
220 group_size = 128;
221 } else {
222 group_size = 256;
223 }
224
225 lws[0] = group_size;
226 gws[0] *= group_size;
227
228 return status::success;
229 }
230
231 size_t lws[3] = {};
232 size_t gws[3] = {};
233 size_t block[3] = {};
234 size_t group_size = 0;
235 };
236
237 status_t init(engine_t *engine) override {
238 if (pd()->has_zero_dim_memory()) return status::success;
239
240 compute::kernel_ctx_t kernel_ctx;
241
242 kernel_ctx.define_int("SOFTMAX_AXIS_IDX", pd()->axis());
243 kernel_ctx.define_int("SOFTMAX_AXIS", pd()->axis_size(true));
244 kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
245 kernel_ctx.define_int("SUB_GROUP_SIZE", 16);
246 kernel_ctx.define_int("IS_BWD", 1);
247 kernel_ctx.add_option("-cl-std=CL2.0");
248 kernel_ctx.define_int("LOGSOFTMAX", pd()->is_logsoftmax());
249
250 const memory_desc_wrapper diff_src_mdw(pd()->diff_src_md());
251 const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md());
252 const auto diff_src_md_info = memory_desc_info_t::create(diff_src_mdw);
253 const auto diff_dst_md_info = memory_desc_info_t::create(diff_dst_mdw);
254 def_memory_desc_info(kernel_ctx, diff_src_md_info, "SRC");
255 def_memory_desc_info(kernel_ctx, diff_dst_md_info, "DST");
256 kernel_ctx.set_data_type(diff_src_mdw.data_type());
257 set_offsets(kernel_ctx, *pd()->diff_src_md(), "DATA");
258
259 for (int i = 0; i < 3; i++)
260 kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);
261
262 create_kernel(engine, &kernel_, "ref_softmax_bwd_generic", kernel_ctx);
263 if (!kernel_) return status::runtime_error;
264
265 return status::success;
266 }
267
268 status_t execute(const exec_ctx_t &ctx) const override {
269 return execute_generic(ctx);
270 }
271
272protected:
273 status_t execute_generic(const exec_ctx_t &ctx) const;
274 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
275 compute::kernel_t kernel_;
276};
277
278} // namespace ocl
279} // namespace gpu
280} // namespace impl
281} // namespace dnnl
282
283#endif
284