1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_SOFTMAX_HPP |
18 | #define GPU_OCL_REF_SOFTMAX_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/primitive.hpp" |
23 | #include "gpu/compute/compute.hpp" |
24 | #include "gpu/gpu_primitive.hpp" |
25 | #include "gpu/gpu_resource.hpp" |
26 | #include "gpu/gpu_softmax_pd.hpp" |
27 | #include "gpu/ocl/ocl_stream.hpp" |
28 | #include "gpu/ocl/ocl_utils.hpp" |
29 | #include "gpu/primitive_conf.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace ocl { |
35 | |
36 | struct ref_softmax_fwd_t : public gpu_primitive_t { |
37 | using gpu_primitive_t::gpu_primitive_t; |
38 | struct pd_t : public gpu_softmax_fwd_pd_t { |
39 | using gpu_softmax_fwd_pd_t::gpu_softmax_fwd_pd_t; |
40 | |
41 | DECLARE_COMMON_PD_T("ref:any" , ref_softmax_fwd_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | auto *compute_engine |
45 | = utils::downcast<compute::compute_engine_t *>(engine); |
46 | |
47 | const memory_desc_wrapper src_d(src_md()); |
48 | const memory_desc_wrapper dst_d(dst_md()); |
49 | const auto src_dt = src_d.data_type(); |
50 | const auto dst_dt = dst_d.data_type(); |
51 | |
52 | using namespace data_type; |
53 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
54 | bool ok = is_fwd() && utils::one_of(src_dt, f32, f16, bf16, u8, s8) |
55 | && utils::one_of(dst_dt, f32, f16, bf16, u8, s8) |
56 | && IMPLICATION(utils::one_of(f16, src_dt, dst_dt), |
57 | compute_engine->mayiuse( |
58 | compute::device_ext_t::khr_fp16)) |
59 | && compute_engine->mayiuse_sub_group(subgroup_size) |
60 | && !memory_desc_ndims_ok(src_md(), dst_md()) |
61 | && attr()->has_default_values(skip_mask_t::scales_runtime) |
62 | && attr_scales_ok() |
63 | && set_default_formats() == status::success; |
64 | if (!ok) return status::unimplemented; |
65 | |
66 | gws[0] = 1; |
67 | gws[1] = 1; |
68 | gws[2] = 1; |
69 | |
70 | lws[0] = 1; |
71 | lws[1] = 1; |
72 | lws[2] = 1; |
73 | |
74 | block[0] = 1; |
75 | block[1] = 1; |
76 | block[2] = 1; |
77 | |
78 | int nelems = axis_size(true); |
79 | |
80 | if (nelems < subgroup_size) { |
81 | group_size = subgroup_size = 1; |
82 | } else if (nelems <= 100) { |
83 | group_size = subgroup_size * 1; |
84 | } else if (nelems <= 1000) { |
85 | group_size = subgroup_size * 2; |
86 | } else if (nelems <= 2000) { |
87 | group_size = subgroup_size * 4; |
88 | } else if (nelems <= 5000) { |
89 | group_size = subgroup_size * 8; |
90 | } else { |
91 | group_size = subgroup_size * 16; |
92 | } |
93 | |
94 | for (int i = 0, j = 0; i < src_md()->ndims; ++i) { |
95 | if (i != desc()->softmax_axis) { |
96 | auto dim = src_md()->padded_dims[i]; |
97 | gws[j % 3] *= dim; |
98 | if (j < 3) block[j % 3] = dim; |
99 | j++; |
100 | } |
101 | } |
102 | |
103 | if (group_size != 1) { |
104 | lws[0] = group_size; |
105 | gws[0] *= group_size; |
106 | } |
107 | |
108 | return status::success; |
109 | } |
110 | |
111 | size_t gws[3] = {}; |
112 | size_t lws[3] = {}; |
113 | size_t block[3] = {}; |
114 | size_t group_size = 0; |
115 | int subgroup_size = 16; |
116 | }; |
117 | |
118 | status_t init(engine_t *engine) override { |
119 | if (pd()->has_zero_dim_memory()) return status::success; |
120 | |
121 | compute::kernel_ctx_t kernel_ctx; |
122 | |
123 | const auto *desc = pd()->desc(); |
124 | kernel_ctx.define_int("SOFTMAX_AXIS_IDX" , desc->softmax_axis); |
125 | kernel_ctx.define_int("SOFTMAX_AXIS" , pd()->axis_size(true)); |
126 | kernel_ctx.define_int("GROUP_SIZE" , pd()->group_size); |
127 | kernel_ctx.define_int("SUB_GROUP_SIZE" , pd()->subgroup_size); |
128 | kernel_ctx.define_int("IS_FWD" , 1); |
129 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
130 | kernel_ctx.define_int("LOGSOFTMAX" , pd()->is_logsoftmax()); |
131 | kernel_ctx.define_int("WITH_SRC_SCALES" , |
132 | !pd()->attr()->scales_.get(DNNL_ARG_SRC).has_default_values()); |
133 | kernel_ctx.define_int("WITH_DST_SCALES" , |
134 | !pd()->attr()->scales_.get(DNNL_ARG_DST).has_default_values()); |
135 | |
136 | const memory_desc_wrapper dst_mdw(pd()->dst_md()); |
137 | const memory_desc_wrapper src_mdw(pd()->src_md()); |
138 | const auto dst_md_info = memory_desc_info_t::create(dst_mdw); |
139 | const auto src_md_info = memory_desc_info_t::create(src_mdw); |
140 | def_memory_desc_info(kernel_ctx, dst_md_info, "DST" ); |
141 | def_memory_desc_info(kernel_ctx, src_md_info, "SRC" ); |
142 | kernel_ctx.set_data_type(dst_mdw.data_type()); |
143 | set_offsets(kernel_ctx, pd()->dst_md(), "DATA" ); |
144 | |
145 | for (int i = 0; i < 3; i++) |
146 | kernel_ctx.define_int(utils::format("BLOCK_%d" , i), pd()->block[i]); |
147 | |
148 | create_kernel(engine, &kernel_, "ref_softmax_fwd_generic" , kernel_ctx); |
149 | if (!kernel_) return status::runtime_error; |
150 | |
151 | return status::success; |
152 | } |
153 | |
154 | status_t execute(const exec_ctx_t &ctx) const override { |
155 | return execute_generic(ctx); |
156 | } |
157 | |
158 | protected: |
159 | status_t execute_generic(const exec_ctx_t &ctx) const; |
160 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
161 | compute::kernel_t kernel_; |
162 | }; |
163 | |
164 | struct ref_softmax_bwd_t : public gpu_primitive_t { |
165 | using gpu_primitive_t::gpu_primitive_t; |
166 | struct pd_t : public gpu_softmax_bwd_pd_t { |
167 | using gpu_softmax_bwd_pd_t::gpu_softmax_bwd_pd_t; |
168 | |
169 | DECLARE_COMMON_PD_T("ref:any" , ref_softmax_bwd_t); |
170 | |
171 | status_t init(engine_t *engine) { |
172 | auto *compute_engine |
173 | = utils::downcast<compute::compute_engine_t *>(engine); |
174 | |
175 | const memory_desc_wrapper diff_dst_d(diff_dst_md()); |
176 | const memory_desc_wrapper diff_src_d(diff_src_md()); |
177 | const memory_desc_wrapper dst_d(dst_md()); |
178 | |
179 | using namespace data_type; |
180 | bool ok = !is_fwd() |
181 | && utils::one_of(diff_src_d.data_type(), f32, bf16) |
182 | && utils::one_of(diff_dst_d.data_type(), f32, bf16) |
183 | && compute_engine->mayiuse_sub_group(16) |
184 | && !memory_desc_ndims_ok( |
185 | dst_md(), diff_src_md(), diff_dst_md()) |
186 | && attr()->has_default_values() |
187 | && set_default_formats() == status::success |
188 | && diff_dst_d.data_type() == dst_d.data_type(); |
189 | if (!ok) return status::unimplemented; |
190 | |
191 | gws[0] = 1; |
192 | gws[1] = 1; |
193 | gws[2] = 1; |
194 | |
195 | lws[0] = 1; |
196 | lws[1] = 1; |
197 | lws[2] = 1; |
198 | |
199 | block[0] = 1; |
200 | block[1] = 1; |
201 | block[2] = 1; |
202 | |
203 | for (int i = 0, j = 0; i < dst_d.ndims(); ++i) { |
204 | if (i != axis()) { |
205 | auto dim = dst_d.padded_dims()[i]; |
206 | gws[j % 3] *= dim; |
207 | if (j < 3) block[j % 3] = dim; |
208 | j++; |
209 | } |
210 | } |
211 | |
212 | int nelems = axis_size(true); |
213 | if (nelems <= 100) { |
214 | group_size = 16; |
215 | } else if (nelems <= 1000) { |
216 | group_size = 32; |
217 | } else if (nelems <= 2000) { |
218 | group_size = 64; |
219 | } else if (nelems <= 5000) { |
220 | group_size = 128; |
221 | } else { |
222 | group_size = 256; |
223 | } |
224 | |
225 | lws[0] = group_size; |
226 | gws[0] *= group_size; |
227 | |
228 | return status::success; |
229 | } |
230 | |
231 | size_t lws[3] = {}; |
232 | size_t gws[3] = {}; |
233 | size_t block[3] = {}; |
234 | size_t group_size = 0; |
235 | }; |
236 | |
237 | status_t init(engine_t *engine) override { |
238 | if (pd()->has_zero_dim_memory()) return status::success; |
239 | |
240 | compute::kernel_ctx_t kernel_ctx; |
241 | |
242 | kernel_ctx.define_int("SOFTMAX_AXIS_IDX" , pd()->axis()); |
243 | kernel_ctx.define_int("SOFTMAX_AXIS" , pd()->axis_size(true)); |
244 | kernel_ctx.define_int("GROUP_SIZE" , pd()->group_size); |
245 | kernel_ctx.define_int("SUB_GROUP_SIZE" , 16); |
246 | kernel_ctx.define_int("IS_BWD" , 1); |
247 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
248 | kernel_ctx.define_int("LOGSOFTMAX" , pd()->is_logsoftmax()); |
249 | |
250 | const memory_desc_wrapper diff_src_mdw(pd()->diff_src_md()); |
251 | const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md()); |
252 | const auto diff_src_md_info = memory_desc_info_t::create(diff_src_mdw); |
253 | const auto diff_dst_md_info = memory_desc_info_t::create(diff_dst_mdw); |
254 | def_memory_desc_info(kernel_ctx, diff_src_md_info, "SRC" ); |
255 | def_memory_desc_info(kernel_ctx, diff_dst_md_info, "DST" ); |
256 | kernel_ctx.set_data_type(diff_src_mdw.data_type()); |
257 | set_offsets(kernel_ctx, *pd()->diff_src_md(), "DATA" ); |
258 | |
259 | for (int i = 0; i < 3; i++) |
260 | kernel_ctx.define_int(utils::format("BLOCK_%d" , i), pd()->block[i]); |
261 | |
262 | create_kernel(engine, &kernel_, "ref_softmax_bwd_generic" , kernel_ctx); |
263 | if (!kernel_) return status::runtime_error; |
264 | |
265 | return status::success; |
266 | } |
267 | |
268 | status_t execute(const exec_ctx_t &ctx) const override { |
269 | return execute_generic(ctx); |
270 | } |
271 | |
272 | protected: |
273 | status_t execute_generic(const exec_ctx_t &ctx) const; |
274 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
275 | compute::kernel_t kernel_; |
276 | }; |
277 | |
278 | } // namespace ocl |
279 | } // namespace gpu |
280 | } // namespace impl |
281 | } // namespace dnnl |
282 | |
283 | #endif |
284 | |