1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_CONVOLUTION_HPP
18#define GPU_OCL_REF_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/gpu_primitive.hpp"
23
24#include "gpu/compute/compute.hpp"
25#include "gpu/gpu_convolution_pd.hpp"
26#include "gpu/gpu_resource.hpp"
27#include "gpu/ocl/ocl_stream.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct ref_convolution_fwd_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_convolution_fwd_pd_t {
38 using gpu_convolution_fwd_pd_t::gpu_convolution_fwd_pd_t;
39
40 DECLARE_COMMON_PD_T("ocl:ref:any", ref_convolution_fwd_t);
41
42 status_t init(engine_t *engine) {
43 using namespace data_type;
44
45 const auto *compute_engine
46 = utils::downcast<compute::compute_engine_t *>(engine);
47
48 const auto attr_skip_mask
49 = primitive_attr_t::skip_mask_t::scales_runtime
50 | primitive_attr_t::skip_mask_t::zero_points_runtime
51 | primitive_attr_t::skip_mask_t::post_ops
52 | primitive_attr_t::skip_mask_t::sum_dt;
53
54 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
55 && utils::one_of(desc()->prop_kind,
56 prop_kind::forward_training,
57 prop_kind::forward_inference)
58 && desc()->alg_kind == alg_kind::convolution_direct
59 && IMPLICATION(
60 utils::one_of(f16, src_md_.data_type,
61 weights_md_.data_type, dst_md_.data_type),
62 compute_engine->mayiuse(
63 compute::device_ext_t::khr_fp16))
64 && IMPLICATION(
65 utils::one_of(f64, src_md_.data_type,
66 weights_md_.data_type, dst_md_.data_type),
67 compute_engine->mayiuse(
68 compute::device_ext_t::khr_fp64)
69 && attr()->post_ops_.has_default_values())
70 && !memory_desc_ndims_ok(src_md(), weights_md(), dst_md())
71 && this->set_default_formats()
72 && attr()->has_default_values(
73 attr_skip_mask, dst_md_.data_type)
74 && attr()->post_ops_.check_sum_consistent_dt(
75 dst_md_.data_type, true)
76 && attr_.set_default_formats(dst_md(0)) == status::success
77 && post_ops_with_binary_ok(
78 attr(), dst_md()->data_type, 5, 0xffff)
79 && zero_points_ok(attr()) && arg_scales_ok()
80 && IMPLICATION(!attr()->scales_.has_default_values(),
81 utils::one_of(src_md_.data_type, s8, u8));
82 if (!ok) return status::unimplemented;
83
84 return init_conf(engine);
85 }
86
87 status_t init_conf(engine_t *engine);
88 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
89
90 conv_conf_t conf;
91
92 private:
93 bool set_default_formats() {
94 using namespace format_tag;
95 auto dat_tag = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
96 auto wei_tag = with_groups()
97 ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
98 : utils::pick(ndims() - 3, oiw, oihw, oidhw);
99 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
100 }
101 };
102
103 status_t init(engine_t *engine) override {
104 compute::kernel_ctx_t kernel_ctx;
105
106 auto status = pd()->init_kernel_ctx(kernel_ctx);
107 if (status != status::success) return status;
108
109 create_kernel(engine, &kernel_, "ref_convolution_fwd", kernel_ctx);
110 if (!kernel_) return status::runtime_error;
111
112 return status::success;
113 }
114
115 status_t execute(const exec_ctx_t &ctx) const override {
116 return execute_forward(ctx);
117 }
118
119private:
120 status_t execute_forward(const exec_ctx_t &ctx) const;
121 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
122 compute::kernel_t kernel_;
123};
124
125struct ref_convolution_bwd_data_t : public gpu_primitive_t {
126 using gpu_primitive_t::gpu_primitive_t;
127 struct pd_t : public gpu_convolution_bwd_data_pd_t {
128 using gpu_convolution_bwd_data_pd_t::gpu_convolution_bwd_data_pd_t;
129
130 DECLARE_COMMON_PD_T("ocl:ref:any", ref_convolution_bwd_data_t);
131
132 status_t init(engine_t *engine) {
133 using sm = primitive_attr_t::skip_mask_t;
134 const auto attr_skip_mask = sm::post_ops | sm::scales_runtime
135 | sm::zero_points_runtime;
136 using namespace data_type;
137 const auto *compute_engine
138 = utils::downcast<compute::compute_engine_t *>(engine);
139 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
140 && desc()->prop_kind == prop_kind::backward_data
141 && desc()->alg_kind == alg_kind::convolution_direct
142 && !memory_desc_ndims_ok(diff_src_md(), diff_dst_md())
143 && this->set_default_formats()
144 && attr()->has_default_values(attr_skip_mask)
145 && post_ops_with_binary_ok(
146 attr(), dst_md()->data_type, ndims())
147 && zero_points_ok(attr()) && arg_scales_ok()
148 && IMPLICATION(utils::one_of(f64, diff_src_md()->data_type,
149 dst_md()->data_type),
150 compute_engine->mayiuse(
151 compute::device_ext_t::khr_fp64)
152 && attr()->post_ops_.has_default_values())
153 && attr_.set_default_formats(diff_src_md(0))
154 == status::success;
155 if (!ok) return status::unimplemented;
156
157 return init_conf(engine);
158 }
159
160 status_t init_conf(engine_t *engine);
161 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
162
163 conv_conf_t conf;
164
165 private:
166 bool set_default_formats() {
167 using namespace format_tag;
168 auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
169 auto wei_tag = with_groups()
170 ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
171 : utils::pick(ndims() - 3, oiw, oihw, oidhw);
172 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
173 }
174 };
175
176 status_t init(engine_t *engine) override {
177 compute::kernel_ctx_t kernel_ctx;
178
179 auto status = pd()->init_kernel_ctx(kernel_ctx);
180 if (status != status::success) return status;
181
182 create_kernel(engine, &kernel_, "ref_convolution_bwd_data", kernel_ctx);
183 if (!kernel_) return status::runtime_error;
184
185 return status::success;
186 }
187
188 status_t execute(const exec_ctx_t &ctx) const override {
189 return execute_backward_data(ctx);
190 }
191
192private:
193 status_t execute_backward_data(const exec_ctx_t &ctx) const;
194 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
195 compute::kernel_t kernel_;
196};
197
198struct ref_convolution_bwd_weights_t : public gpu_primitive_t {
199 using gpu_primitive_t::gpu_primitive_t;
200 struct pd_t : public gpu_convolution_bwd_weights_pd_t {
201 using gpu_convolution_bwd_weights_pd_t::
202 gpu_convolution_bwd_weights_pd_t;
203
204 DECLARE_COMMON_PD_T("ocl:ref:any", ref_convolution_bwd_weights_t);
205
206 status_t init(engine_t *engine) {
207 using namespace data_type;
208 const auto *compute_engine
209 = utils::downcast<compute::compute_engine_t *>(engine);
210
211 bool ok = set_default_alg_kind(alg_kind::convolution_direct)
212 && desc()->prop_kind == prop_kind::backward_weights
213 && desc()->alg_kind == alg_kind::convolution_direct
214 && !memory_desc_ndims_ok(src_md(), diff_dst_md())
215 && utils::one_of(
216 desc()->diff_weights_desc.data_type, f32, bf16, f64)
217 && utils::one_of(desc()->src_desc.data_type, f32, bf16, f64)
218 && utils::one_of(
219 desc()->diff_dst_desc.data_type, f32, bf16, f64)
220 && this->set_default_formats()
221 && attr()->has_default_values()
222 && IMPLICATION(
223 utils::one_of(f64, desc()->src_desc.data_type,
224 desc()->diff_dst_desc.data_type),
225 compute_engine->mayiuse(
226 compute::device_ext_t::khr_fp64)
227 && attr()->post_ops_.has_default_values());
228 if (!ok) return status::unimplemented;
229
230 return init_conf(engine);
231 }
232
233 status_t init_conf(engine_t *engine);
234 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
235
236 conv_conf_t conf;
237
238 private:
239 bool set_default_formats() {
240 using namespace format_tag;
241 auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
242 auto wei_tag = with_groups()
243 ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
244 : utils::pick(ndims() - 3, oiw, oihw, oidhw);
245 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
246 }
247 };
248
249 status_t init(engine_t *engine) override {
250 compute::kernel_ctx_t kernel_ctx;
251
252 auto status = pd()->init_kernel_ctx(kernel_ctx);
253 if (status != status::success) return status;
254
255 create_kernel(
256 engine, &kernel_, "ref_convolution_bwd_weights", kernel_ctx);
257 if (!kernel_) return status::runtime_error;
258
259 return status::success;
260 }
261
262 status_t execute(const exec_ctx_t &ctx) const override {
263 return execute_backward_weights(ctx);
264 }
265
266private:
267 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
268 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
269 compute::kernel_t kernel_;
270};
271
272} // namespace ocl
273} // namespace gpu
274} // namespace impl
275} // namespace dnnl
276#endif
277