1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_CONVOLUTION_HPP
18#define GPU_OCL_GEN9_CONVOLUTION_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "gpu/compute/compute.hpp"
25#include "gpu/gpu_convolution_pd.hpp"
26#include "gpu/gpu_eltwise_pd.hpp"
27#include "gpu/gpu_primitive.hpp"
28#include "gpu/gpu_resource.hpp"
29#include "gpu/ocl/ocl_stream.hpp"
30#include "gpu/ocl/ocl_utils.hpp"
31#include "gpu/primitive_conf.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace gpu {
36namespace ocl {
37
38struct gen9_convolution_fwd_t : public gpu_primitive_t {
39 using gpu_primitive_t::gpu_primitive_t;
40 struct pd_t : public gpu_convolution_fwd_pd_t {
41 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
42 const convolution_fwd_pd_t *hint_fwd_pd)
43 : gpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
44
45 DECLARE_COMMON_PD_T("ocl:gen9:blocked", gen9_convolution_fwd_t);
46
47 status_t init(engine_t *engine) {
48 using namespace prop_kind;
49 using namespace data_type;
50 assert(engine->kind() == engine_kind::gpu);
51 auto *compute_engine
52 = utils::downcast<compute::compute_engine_t *>(engine);
53
54 auto src_data_t = this->desc()->src_desc.data_type;
55 auto dst_data_t = this->desc()->dst_desc.data_type;
56
57 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
58
59 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
60 && utils::one_of(this->desc()->prop_kind, forward_training,
61 forward_inference)
62 && this->desc()->alg_kind == alg_kind::convolution_direct
63 && utils::one_of(true,
64 expect_data_types(f32, f32, f32, f32, f32),
65 expect_data_types(f32, f32, f32, s8, f32),
66 expect_data_types(f16, f16, f16, s8, f32),
67 expect_data_types(f16, f16, f16, f16, f32))
68 && compute_engine->mayiuse(
69 compute::device_ext_t::intel_subgroups)
70 && IMPLICATION(src_data_t == f16,
71 true
72 && compute_engine->mayiuse(
73 compute::device_ext_t::khr_fp16)
74 && compute_engine->mayiuse(
75 compute::device_ext_t::
76 intel_subgroups_short))
77 && !has_zero_dim_memory()
78 && attr()->has_default_values(attr_skip_mask, dst_data_t)
79 && post_ops_with_binary_ok(attr(), dst_md()->data_type);
80 if (!ok) return status::unimplemented;
81
82 CHECK(init_conf(engine));
83
84 if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
85 return status::unimplemented;
86
87 ok = set_default_formats_common(
88 conf.src_tag, conf.wei_tag, conf.dst_tag);
89 if (!ok) return status::unimplemented;
90
91 CHECK(attr_.set_default_formats(dst_md(0)));
92
93 return status::success;
94 }
95
96 status_t init_conf(engine_t *engine);
97 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
98
99 conv_conf_t conf;
100 };
101
102 status_t init(engine_t *engine) override {
103 const char *kernel_name = nullptr;
104
105 if (pd()->conf.is_nhwc
106 && utils::one_of(pd()->conf.src_data_type, data_type::f32,
107 data_type::f16)) {
108 kernel_name = "gen9_conv_nhwc_fwd";
109
110 } else if (pd()->conf.is_depthwise) {
111 kernel_name = "gen9_conv_dw_fwd";
112 } else if (utils::one_of(pd()->desc()->src_desc.data_type,
113 data_type::f16, data_type::f32)) {
114 kernel_name = "gen9_conv_fwd";
115 } else {
116 assert(!"not expected");
117 }
118
119 compute::kernel_ctx_t kernel_ctx;
120 status_t status = pd()->init_kernel_ctx(kernel_ctx);
121 if (status != status::success) return status;
122
123 create_kernel(engine, &kernel_, kernel_name, kernel_ctx);
124 if (!kernel_) return status::runtime_error;
125
126 return status::success;
127 }
128
129 status_t execute(const exec_ctx_t &ctx) const override {
130 return execute_forward(ctx);
131 }
132
133private:
134 status_t execute_forward(const exec_ctx_t &ctx) const;
135 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
136 compute::kernel_t kernel_;
137};
138
139struct gen9_convolution_bwd_data_t : public gpu_primitive_t {
140 using gpu_primitive_t::gpu_primitive_t;
141 struct pd_t : public gpu_convolution_bwd_data_pd_t {
142 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
143 const convolution_fwd_pd_t *hint_fwd_pd)
144 : gpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}
145
146 DECLARE_COMMON_PD_T("ocl:ncsp:any", gen9_convolution_bwd_data_t);
147
148 status_t init(engine_t *engine) {
149 using namespace data_type;
150 using namespace prop_kind;
151 assert(engine->kind() == engine_kind::gpu);
152 auto *compute_engine
153 = utils::downcast<compute::compute_engine_t *>(engine);
154
155 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
156 && this->desc()->prop_kind == backward_data
157 && this->desc()->alg_kind == alg_kind::convolution_direct
158 && utils::one_of(true,
159 expect_data_types(
160 f32, f32, data_type::undef, f32, f32),
161 expect_data_types(f16, f16, data_type::undef, f16,
162 data_type::undef))
163 && IMPLICATION(this->with_bias()
164 && this->desc()->diff_dst_desc.data_type
165 != f16,
166 this->desc()->bias_desc.data_type == f32)
167 && IMPLICATION(this->with_bias()
168 && this->desc()->diff_dst_desc.data_type
169 == f16,
170 this->desc()->bias_desc.data_type == f16)
171 && compute_engine->mayiuse(
172 compute::device_ext_t::intel_subgroups)
173 && !has_zero_dim_memory() && attr()->has_default_values();
174 if (!ok) return status::unimplemented;
175
176 CHECK(init_conf(engine));
177
178 if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
179 return status::unimplemented;
180
181 ok = set_default_formats_common(
182 conf.src_tag, conf.wei_tag, conf.dst_tag);
183 return ok ? status::success : status::unimplemented;
184 }
185
186 status_t init_conf(engine_t *engine);
187 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
188
189 conv_conf_t conf;
190 };
191
192 status_t init(engine_t *engine) override {
193 const char *kernel_name = nullptr;
194 if (pd()->conf.is_depthwise) {
195 kernel_name = "gen9_conv_dw_bwd_data";
196 } else {
197 if (pd()->conf.is_nhwc)
198 kernel_name = "gen9_conv_nhwc_bwd_data";
199 else
200 kernel_name = "gen9_conv_bwd_data";
201 }
202
203 compute::kernel_ctx_t kernel_ctx;
204 status_t status = pd()->init_kernel_ctx(kernel_ctx);
205 if (status != status::success) return status;
206
207 create_kernel(engine, &kernel_, kernel_name, kernel_ctx);
208 if (!kernel_) return status::runtime_error;
209
210 return status::success;
211 }
212
213 status_t execute(const exec_ctx_t &ctx) const override {
214 return execute_backward_data(ctx);
215 }
216
217private:
218 status_t execute_backward_data(const exec_ctx_t &ctx) const;
219 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
220 compute::kernel_t kernel_;
221};
222
223struct gen9_convolution_bwd_weights_t : public gpu_primitive_t {
224 using gpu_primitive_t::gpu_primitive_t;
225 struct pd_t : public gpu_convolution_bwd_weights_pd_t {
226 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
227 const convolution_fwd_pd_t *hint_fwd_pd)
228 : gpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}
229
230 pd_t(const pd_t &rhs) = default;
231
232 DECLARE_COMMON_PD_T("ocl:ncsp:any", gen9_convolution_bwd_weights_t);
233
234 status_t init(engine_t *engine) {
235 using namespace data_type;
236 using namespace prop_kind;
237 assert(engine->kind() == engine_kind::gpu);
238 auto *compute_engine
239 = utils::downcast<compute::compute_engine_t *>(engine);
240
241 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
242 && this->desc()->prop_kind == backward_weights
243 && this->desc()->alg_kind == alg_kind::convolution_direct
244 && utils::one_of(this->desc()->diff_weights_desc.data_type,
245 f32, bf16)
246 && utils::one_of(
247 this->desc()->src_desc.data_type, f32, bf16)
248 && utils::one_of(
249 this->desc()->diff_dst_desc.data_type, f32, bf16)
250 && compute_engine->mayiuse(
251 compute::device_ext_t::intel_subgroups)
252 && compute_engine->mayiuse(
253 compute::device_ext_t::khr_int64_base_atomics)
254 && !has_zero_dim_memory() && attr()->has_default_values();
255 if (!ok) return status::unimplemented;
256
257 CHECK(init_conf(engine));
258 if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
259 return status::unimplemented;
260
261 if (!IMPLICATION(utils::one_of(bf16,
262 this->desc()->diff_weights_desc.data_type,
263 this->desc()->src_desc.data_type,
264 this->desc()->diff_dst_desc.data_type),
265 conf.ver == ver_1stconv))
266 return status::unimplemented;
267
268 init_scratchpad();
269 return status::success;
270 }
271
272 status_t init_conf(engine_t *engine);
273 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
274
275 conv_conf_t conf;
276 std::shared_ptr<primitive_desc_t> rpd_wei_;
277 std::shared_ptr<primitive_desc_t> rpd_bia_;
278
279 private:
280 status_t init_scratchpad();
281 };
282
283 status_t init(engine_t *engine) override {
284 const char *kernel_name;
285 if (pd()->conf.is_nhwc) {
286 kernel_name = "gen9_conv_nhwc_bwd_weights";
287 } else {
288 kernel_name = "gen9_conv_bwd_weights";
289 }
290 if (pd()->conf.reorder_wei) {
291 CHECK(create_nested_primitive(
292 wei_reorder_, pd()->rpd_wei_, engine));
293 }
294 if (pd()->conf.reorder_bias) {
295 CHECK(create_nested_primitive(
296 bia_reorder_, pd()->rpd_bia_, engine));
297 }
298 compute::kernel_ctx_t kernel_ctx;
299 status_t status = pd()->init_kernel_ctx(kernel_ctx);
300 if (status != status::success) return status;
301
302 create_kernel(engine, &kernel_, kernel_name, kernel_ctx);
303 if (!kernel_) return status::runtime_error;
304 return status::success;
305 }
306
307 status_t execute(const exec_ctx_t &ctx) const override {
308 return execute_backward_weights(ctx);
309 }
310
311private:
312 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
313 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
314 compute::kernel_t kernel_;
315 std::shared_ptr<primitive_t> wei_reorder_;
316 std::shared_ptr<primitive_t> bia_reorder_;
317};
318
319} // namespace ocl
320} // namespace gpu
321} // namespace impl
322} // namespace dnnl
323
324#endif
325
326// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
327