/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_GEN9_POOLING_HPP
#define GPU_OCL_GEN9_POOLING_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_pooling_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
struct gen9_pooling_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_pooling_fwd_pd_t {
        pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
                const pooling_fwd_pd_t *hint_fwd_pd)
            : gpu_pooling_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T("ocl:gen9", gen9_pooling_fwd_t);

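        // Descriptor-time applicability checks: prop kind, algorithm,
        // supported data type combinations, post-ops, dilation, and the
        // subgroup extensions the kernel relies on. Returning
        // status::unimplemented lets dispatch fall through to another
        // implementation.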
        status_t init(engine_t *engine) {
            using namespace data_type;
            using namespace prop_kind;
            using namespace alg_kind;
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
            auto src_data_t = src_md()->data_type;
            auto dst_data_t = dst_md()->data_type;
            auto acc_data_t = desc()->accum_data_type;

            bool ok = set_default_params() == status::success
                    && utils::one_of(desc()->prop_kind, forward_training,
                            forward_inference)
                    && utils::one_of(desc()->alg_kind, pooling_max,
                            pooling_avg_include_padding,
                            pooling_avg_exclude_padding)
                    && (utils::everyone_is(
                                f32, src_data_t, dst_data_t, acc_data_t)
                            || utils::everyone_is(f16, src_data_t, dst_data_t)
                            || utils::everyone_is(bf16, src_data_t, dst_data_t)
                            || utils::everyone_is(u8, src_data_t, dst_data_t)
                            || utils::everyone_is(s8, src_data_t, dst_data_t))
                    && IMPLICATION(utils::one_of(src_data_t, f16, s8, u8),
                            desc()->prop_kind == forward_inference)
                    && post_ops_with_binary_ok(attr(), dst_md()->data_type)
                    && attr_.set_default_formats(dst_md(0)) == status::success
                    && !is_dilated()
                    && compute_engine->mayiuse(
                            compute::device_ext_t::intel_subgroups)
                    && IMPLICATION(src_data_t == f16,
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16)
                                    && compute_engine->mayiuse(
                                            compute::device_ext_t::
                                                    intel_subgroups_short));
            if (!ok) return status::unimplemented;

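            // Max pooling in training mode records the selected source
            // indices in an s32 workspace so the backward pass can route
            // gradients; inference and average pooling need no workspace.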
            bool is_training = desc()->prop_kind == forward_training;
            if (desc()->alg_kind == pooling_max && is_training)
                init_default_ws(s32);

            return init_conf(engine);
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;

        pool_conf_t conf;
        offsets_t off;
    };

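    // Primitive initialization builds the OpenCL kernel: the pd fills the
    // kernel context with compile-time definitions via init_kernel_ctx(),
    // and create_kernel() compiles "gen9_pooling_fwd" against it.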
    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;
        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        create_kernel(engine, &kernel_, "gen9_pooling_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

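// Backward (data) pooling primitive for the same Gen9 OpenCL kernel family.
// Supports f32 and bf16 diff tensors and, for max pooling, consumes the
// index workspace produced by the forward primitive.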
struct gen9_pooling_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_pooling_bwd_pd_t {
        pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
                const pooling_fwd_pd_t *hint_fwd_pd)
            : gpu_pooling_bwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T("ocl:gen9:any", gen9_pooling_bwd_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            using namespace alg_kind;
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            bool ok = set_default_params() == status::success
                    && utils::one_of(desc()->prop_kind, backward_data)
                    && utils::one_of(desc()->alg_kind, pooling_max,
                            pooling_avg_include_padding,
                            pooling_avg_exclude_padding)
                    && (utils::everyone_is(data_type::f32,
                                diff_dst_md()->data_type,
                                diff_src_md()->data_type)
                            || utils::everyone_is(data_type::bf16,
                                    diff_dst_md()->data_type,
                                    diff_src_md()->data_type))
                    && attr()->has_default_values() && !is_dilated()
                    && compute_engine->mayiuse(
                            compute::device_ext_t::intel_subgroups);
            if (!ok) return status::unimplemented;

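            // Max pooling backward needs the forward index workspace;
            // compare_ws() rejects this implementation if the hinted
            // forward pd produced an incompatible workspace layout.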
            if (desc()->alg_kind == pooling_max) {
                init_default_ws(data_type::s32);
                if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
            }

            return init_conf(engine);
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;

        pool_conf_t conf;
        offsets_t off;
    };

    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;
        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        create_kernel(engine, &kernel_, "gen9_pooling_bwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif