1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEN9_SOFTMAX_HPP |
18 | #define GPU_OCL_GEN9_SOFTMAX_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/primitive.hpp" |
23 | #include "gpu/compute/compute.hpp" |
24 | #include "gpu/gpu_primitive.hpp" |
25 | #include "gpu/gpu_resource.hpp" |
26 | #include "gpu/gpu_softmax_pd.hpp" |
27 | #include "gpu/ocl/ocl_stream.hpp" |
28 | #include "gpu/ocl/ocl_utils.hpp" |
29 | #include "gpu/primitive_conf.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace ocl { |
35 | |
36 | struct gen9_softmax_fwd_t : public gpu_primitive_t { |
37 | using gpu_primitive_t::gpu_primitive_t; |
38 | struct pd_t : public gpu_softmax_fwd_pd_t { |
39 | using gpu_softmax_fwd_pd_t::gpu_softmax_fwd_pd_t; |
40 | |
41 | DECLARE_COMMON_PD_T("ocl:gen9" , gen9_softmax_fwd_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | using namespace dnnl::impl::format_tag; |
45 | auto *compute_engine |
46 | = utils::downcast<compute::compute_engine_t *>(engine); |
47 | |
48 | const memory_desc_wrapper src_d(src_md()); |
49 | const memory_desc_wrapper dst_d(dst_md()); |
50 | const auto src_dt = src_d.data_type(); |
51 | const auto dst_dt = dst_d.data_type(); |
52 | |
53 | using namespace data_type; |
54 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
55 | is_nhwc = (src_d.matches_one_of_tag(nwc, nhwc, ndhwc) |
56 | != format_tag::undef); |
57 | is_blocked = (src_d.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c) |
58 | != format_tag::undef); |
59 | |
60 | bool ok = is_fwd() && axis_size() % buffer_size == 0 |
61 | && !memory_desc_ndims_ok(src_md(), dst_md()) |
62 | && axis() == src_d.ndims() - 1 |
63 | && (src_d.is_plain() || is_blocked || is_nhwc) |
64 | && utils::one_of(src_dt, f32, f16, bf16, u8, s8) |
65 | && utils::one_of(dst_dt, f32, f16, bf16, u8, s8) |
66 | && IMPLICATION(utils::one_of(f16, src_dt, dst_dt), |
67 | compute_engine->mayiuse( |
68 | compute::device_ext_t::khr_fp16)) |
69 | && attr()->has_default_values(skip_mask_t::scales_runtime) |
70 | && attr_scales_ok() |
71 | && set_default_formats() == status::success |
72 | && compute_engine->mayiuse_sub_group(subgroup_size); |
73 | if (!ok) return status::unimplemented; |
74 | |
75 | if (is_blocked && src_md()->dims[1] % subgroup_size != 0) { |
76 | return status::unimplemented; |
77 | } |
78 | |
79 | if (is_nhwc || is_blocked) { |
80 | group_size = subgroup_size * (axis_size() / buffer_size); |
81 | } else { |
82 | group_size = subgroup_size; |
83 | } |
84 | |
85 | lws[0] = group_size; |
86 | lws[1] = lws[2] = 1; |
87 | gws[0] = utils::array_product(&src_md()->dims[0], ndims() - 1) |
88 | * group_size; |
89 | gws[1] = gws[2] = 1; |
90 | |
91 | auto src_padded_dims = src_d.padded_dims(); |
92 | mb = src_padded_dims[0] * src_padded_dims[2] * src_padded_dims[3] |
93 | & src_padded_dims[4]; |
94 | return status::success; |
95 | } |
96 | |
97 | bool is_nhwc = false; |
98 | bool is_blocked = false; |
99 | size_t gws[3] = {}; |
100 | size_t lws[3] = {}; |
101 | size_t block[3] = {}; |
102 | size_t group_size = 0; |
103 | size_t mb = 0; |
104 | const int subgroup_size = 16; |
105 | // 8x16 load and store commands (Vector_Size x Sub_Group_Size) |
106 | const int buffer_size = 128; |
107 | }; |
108 | |
109 | status_t init(engine_t *engine) override { |
110 | if (pd()->has_zero_dim_memory()) return status::success; |
111 | |
112 | compute::kernel_ctx_t kernel_ctx; |
113 | |
114 | kernel_ctx.define_int("SOFTMAX_AXIS_IDX" , pd()->axis()); |
115 | kernel_ctx.define_int("SOFTMAX_AXIS_SIZE" , pd()->axis_size()); |
116 | kernel_ctx.define_int("SOFTMAX_BUF" , pd()->buffer_size); |
117 | kernel_ctx.define_int("GROUP_SIZE" , pd()->group_size); |
118 | kernel_ctx.define_int("SUB_GROUP_SIZE" , pd()->subgroup_size); |
119 | kernel_ctx.define_int("MB" , pd()->mb); |
120 | kernel_ctx.define_int("OC_PADDED" , pd()->src_md()->padded_dims[1]); |
121 | kernel_ctx.define_int("OC" , |
122 | pd()->is_blocked ? pd()->subgroup_size |
123 | : pd()->src_md(0)->padded_dims[1]); |
124 | kernel_ctx.define_int("IS_NHWC" , pd()->is_nhwc); |
125 | kernel_ctx.define_int("IS_BLOCKED" , pd()->is_blocked); |
126 | kernel_ctx.define_int("IS_FWD" , 1); |
127 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
128 | kernel_ctx.define_int("LOGSOFTMAX" , pd()->is_logsoftmax()); |
129 | kernel_ctx.define_int("WITH_SRC_SCALES" , |
130 | !pd()->attr()->scales_.get(DNNL_ARG_SRC).has_default_values()); |
131 | kernel_ctx.define_int("WITH_DST_SCALES" , |
132 | !pd()->attr()->scales_.get(DNNL_ARG_DST).has_default_values()); |
133 | |
134 | const memory_desc_wrapper dst_mdw(pd()->dst_md()); |
135 | const memory_desc_wrapper src_mdw(pd()->src_md()); |
136 | const auto dst_md_info = memory_desc_info_t::create(dst_mdw); |
137 | const auto src_md_info = memory_desc_info_t::create(src_mdw); |
138 | def_memory_desc_info(kernel_ctx, dst_md_info, "DST" ); |
139 | def_memory_desc_info(kernel_ctx, src_md_info, "SRC" ); |
140 | kernel_ctx.set_data_type(dst_mdw.data_type()); |
141 | set_offsets(kernel_ctx, pd()->dst_md(), "DATA" ); |
142 | |
143 | for (int i = 0; i < 3; ++i) |
144 | kernel_ctx.define_int(utils::format("BLOCK_%d" , i), pd()->block[i]); |
145 | |
146 | create_kernel(engine, &kernel_, "gen9_softmax_fwd" , kernel_ctx); |
147 | if (!kernel_) return status::runtime_error; |
148 | |
149 | return status::success; |
150 | } |
151 | |
152 | status_t execute(const exec_ctx_t &ctx) const override { |
153 | return execute_generic(ctx); |
154 | } |
155 | |
156 | protected: |
157 | status_t execute_generic(const exec_ctx_t &ctx) const; |
158 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
159 | compute::kernel_t kernel_; |
160 | }; |
161 | |
// Backward (and log-)softmax for gen9+ GPUs. Mirrors the forward
// implementation: the softmax axis must be innermost, reduced by a 16-wide
// sub-group, with support for plain, channels-last and 16c-blocked layouts
// (f32/bf16 only on this path).
struct gen9_softmax_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_softmax_bwd_pd_t {
        using gpu_softmax_bwd_pd_t::gpu_softmax_bwd_pd_t;

        DECLARE_COMMON_PD_T("ocl:gen9", gen9_softmax_bwd_t);

        // Checks applicability and precomputes dispatch parameters
        // (gws/lws, group size, padded batch count). Returns
        // status::unimplemented when any precondition fails.
        status_t init(engine_t *engine) {
            using namespace dnnl::impl::format_tag;

            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            const memory_desc_wrapper diff_src_d(diff_src_md());
            const memory_desc_wrapper diff_dst_d(diff_dst_md());
            const memory_desc_wrapper dst_d(dst_md());

            using namespace data_type;
            // The kernel processes the axis in buffer_size-element chunks,
            // so the axis must be an exact multiple; the axis must also be
            // the innermost dimension. No attributes are supported on the
            // backward path.
            bool ok = !is_fwd() && axis_size() % buffer_size == 0
                    && !memory_desc_ndims_ok(
                            dst_md(), diff_src_md(), diff_dst_md())
                    && axis() == diff_src_d.ndims() - 1
                    && utils::one_of(diff_src_d.data_type(), f32, bf16)
                    && utils::one_of(diff_dst_d.data_type(), f32, bf16)
                    && compute_engine->mayiuse_sub_group(subgroup_size)
                    && attr()->has_default_values()
                    && set_default_formats() == status::success
                    && diff_dst_d.data_type() == dst_d.data_type();
            if (!ok) return status::unimplemented;

            is_nhwc = (diff_src_d.matches_one_of_tag(nwc, nhwc, ndhwc)
                    != format_tag::undef);
            is_blk = (diff_src_d.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c)
                    != format_tag::undef);
            if (is_nhwc || is_blk) {
                // One sub-group per buffer_size-chunk of the axis.
                group_size = subgroup_size * (axis_size() / buffer_size);
            } else {
                group_size = subgroup_size;
            }
            lws[0] = group_size;
            lws[1] = lws[2] = 1;
            // One work-group per softmax axis: all dims except the innermost.
            gws[0] = utils::array_product(
                             &diff_src_md(0)->padded_dims[0], ndims() - 1)
                    * group_size;
            gws[1] = gws[2] = 1;
            // Padded batch count consumed as the BATCH kernel define.
            // NOTE(review): uses only padded dims [0] and [2] — presumably
            // the remaining dims are folded elsewhere by the kernel; confirm
            // against gen9_softmax_bwd kernel source.
            batches = diff_src_md(0)->padded_dims[0]
                    * diff_src_md(0)->padded_dims[2];
            return status::success;
        }

        size_t gws[3] = {}; // global work sizes
        size_t lws[3] = {}; // local work sizes
        size_t block[3] = {}; // per-dimension blocking (unused dims stay 0)
        size_t group_size = 0; // work-items cooperating on one axis
        size_t batches = 0; // padded_dims[0] * padded_dims[2] (BATCH define)
        bool is_nhwc = false; // channels-last layout (nwc/nhwc/ndhwc)
        bool is_blk = false; // 16c-blocked layout (nCw16c/...)
        const int subgroup_size = 16;
        // 8x16 load and store commands (Vector_Size x Sub_Group_Size)
        const int buffer_size = 128;
    };

    // Builds the OpenCL kernel, baking the pd_t dispatch parameters in as
    // compile-time defines.
    status_t init(engine_t *engine) override {
        // Nothing to compute (and no kernel needed) for zero-sized tensors.
        if (pd()->has_zero_dim_memory()) return status::success;

        compute::kernel_ctx_t kernel_ctx;

        kernel_ctx.define_int("SOFTMAX_AXIS_IDX", pd()->axis());
        kernel_ctx.define_int("SOFTMAX_AXIS_SIZE", pd()->axis_size());
        kernel_ctx.define_int("SOFTMAX_BUF", pd()->buffer_size);
        kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->subgroup_size);
        kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
        kernel_ctx.define_int("IS_BWD", 1);
        kernel_ctx.define_int("IS_16C", pd()->is_blk);
        kernel_ctx.define_int("BATCH", pd()->batches);
        kernel_ctx.define_int("IC_WO_PADDING", pd()->diff_src_md(0)->dims[1]);
        kernel_ctx.define_int(
                "IC_PADDED", pd()->diff_src_md(0)->padded_dims[1]);
        // For blocked layouts the kernel works on one 16-channel block at a
        // time, so IC is the sub-group size rather than the full channel dim.
        kernel_ctx.define_int("IC",
                pd()->is_blk ? pd()->subgroup_size
                             : pd()->diff_src_md(0)->padded_dims[1]);
        kernel_ctx.define_int("IS_NHWC", pd()->is_nhwc);
        kernel_ctx.add_option("-cl-std=CL2.0");
        kernel_ctx.define_int("LOGSOFTMAX", pd()->is_logsoftmax());

        const memory_desc_wrapper diff_src_mdw(pd()->diff_src_md());
        const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md());
        const auto diff_src_md_info = memory_desc_info_t::create(diff_src_mdw);
        const auto diff_dst_md_info = memory_desc_info_t::create(diff_dst_mdw);
        // Note: diff_src maps to the kernel's SRC and diff_dst to DST.
        def_memory_desc_info(kernel_ctx, diff_src_md_info, "SRC");
        def_memory_desc_info(kernel_ctx, diff_dst_md_info, "DST");
        kernel_ctx.set_data_type(pd()->diff_src_md()->data_type);
        set_offsets(kernel_ctx, *pd()->diff_src_md(), "DATA");

        for (int i = 0; i < 3; ++i)
            kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);

        create_kernel(engine, &kernel_, "gen9_softmax_bwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_generic(ctx);
    }

protected:
    status_t execute_generic(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};
274 | } // namespace ocl |
275 | } // namespace gpu |
276 | } // namespace impl |
277 | } // namespace dnnl |
278 | |
279 | #endif |
280 | |