1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_POOLING_HPP
18#define GPU_OCL_REF_POOLING_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_pooling_pd.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/ocl/ocl_stream.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct ref_pooling_fwd_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_pooling_fwd_pd_t {
38 pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
39 const pooling_fwd_pd_t *hint_fwd_pd)
40 : gpu_pooling_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
41
42 DECLARE_COMMON_PD_T("ocl:ref", ref_pooling_fwd_t);
43
44 status_t init(engine_t *engine) {
45 using namespace data_type;
46 using namespace prop_kind;
47 using namespace alg_kind;
48 auto src_data_t = src_md()->data_type;
49 auto dst_data_t = dst_md()->data_type;
50 auto acc_data_t = desc()->accum_data_type;
51
52 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
53
54 bool ok = set_default_params() == status::success
55 && utils::one_of(desc()->prop_kind, forward_training,
56 forward_inference)
57 && utils::one_of(desc()->alg_kind, pooling_max,
58 pooling_avg_include_padding,
59 pooling_avg_exclude_padding)
60 && IMPLICATION(utils::one_of(src_data_t, f16, s8, u8, s32),
61 desc()->prop_kind == forward_inference)
62 && IMPLICATION(src_data_t != dst_data_t,
63 desc()->prop_kind == forward_inference)
64 && IMPLICATION(src_data_t == bf16, src_data_t == dst_data_t)
65 && IMPLICATION(utils::one_of(src_data_t, s8, u8),
66 utils::one_of(dst_data_t, s8, u8, f16, f32))
67 && IMPLICATION(src_data_t == f16,
68 utils::one_of(dst_data_t, s8, u8, f16))
69 && IMPLICATION(src_data_t == f32,
70 utils::one_of(dst_data_t, s8, u8, f32))
71 && IMPLICATION(utils::one_of(f32, src_data_t, dst_data_t),
72 acc_data_t == f32)
73 && IMPLICATION(utils::one_of(src_data_t, s8, u8)
74 && dst_data_t != f32,
75 acc_data_t == s32)
76 && attr()->has_default_values(attr_skip_mask)
77 && post_ops_with_binary_ok(attr(), dst_md()->data_type, 5)
78 && attr_.set_default_formats(dst_md(0)) == status::success;
79 if (!ok) return status::unimplemented;
80
81 bool is_training = desc_.prop_kind == forward_training;
82 if (desc()->alg_kind == pooling_max && is_training)
83 init_default_ws(s32);
84
85 return init_conf(engine);
86 }
87
88 status_t init_conf(engine_t *engine);
89 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
90
91 pool_conf_t conf;
92 offsets_t off;
93 };
94
95 status_t init(engine_t *engine) override {
96 compute::kernel_ctx_t kernel_ctx;
97 status_t status = pd()->init_kernel_ctx(kernel_ctx);
98 CHECK(status);
99
100 create_kernel(engine, &kernel_, "ref_pooling_fwd", kernel_ctx);
101 if (!kernel_) return status::runtime_error;
102
103 return status::success;
104 }
105
106 status_t execute(const exec_ctx_t &ctx) const override {
107 return execute_forward(ctx);
108 }
109
110private:
111 status_t execute_forward(const exec_ctx_t &ctx) const;
112 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
113 compute::kernel_t kernel_;
114};
115
116struct ref_pooling_bwd_t : public gpu_primitive_t {
117 using gpu_primitive_t::gpu_primitive_t;
118 struct pd_t : public gpu_pooling_bwd_pd_t {
119 pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
120 const pooling_fwd_pd_t *hint_fwd_pd)
121 : gpu_pooling_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
122
123 DECLARE_COMMON_PD_T("ocl:ref:any", ref_pooling_bwd_t);
124
125 status_t init(engine_t *engine) {
126 using namespace prop_kind;
127 using namespace alg_kind;
128
129 bool ok = set_default_params() == status::success
130 && utils::one_of(desc()->prop_kind, backward_data)
131 && utils::one_of(desc()->alg_kind, pooling_max,
132 pooling_avg_include_padding,
133 pooling_avg_exclude_padding)
134 && (utils::everyone_is(data_type::f32,
135 diff_dst_md()->data_type,
136 diff_src_md()->data_type)
137 || utils::everyone_is(data_type::bf16,
138 diff_dst_md()->data_type,
139 diff_src_md()->data_type))
140 && attr()->has_default_values();
141 if (!ok) return status::unimplemented;
142
143 if (desc()->alg_kind == pooling_max) {
144 init_default_ws(data_type::s32);
145 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
146 }
147
148 return init_conf(engine);
149 }
150
151 status_t init_conf(engine_t *engine);
152 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
153
154 pool_conf_t conf;
155 offsets_t off;
156 };
157
158 status_t init(engine_t *engine) override {
159 compute::kernel_ctx_t kernel_ctx;
160 status_t status = pd()->init_kernel_ctx(kernel_ctx);
161 CHECK(status);
162
163 create_kernel(engine, &kernel_, "ref_pooling_bwd", kernel_ctx);
164 if (!kernel_) return status::runtime_error;
165
166 return status::success;
167 }
168
169 status_t execute(const exec_ctx_t &ctx) const override {
170 return execute_backward(ctx);
171 }
172
173private:
174 status_t execute_backward(const exec_ctx_t &ctx) const;
175 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
176 compute::kernel_t kernel_;
177};
178
179} // namespace ocl
180} // namespace gpu
181} // namespace impl
182} // namespace dnnl
183
184#endif
185
186// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
187