1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_GLOBAL_POOLING_HPP
18#define GPU_OCL_GEN9_GLOBAL_POOLING_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_pooling_pd.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/ocl/ocl_stream.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct gen9_global_pooling_fwd_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_pooling_fwd_pd_t {
38 pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
39 const pooling_fwd_pd_t *hint_fwd_pd)
40 : gpu_pooling_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
41
42 DECLARE_COMMON_PD_T("ocl:gen9_global:any", gen9_global_pooling_fwd_t);
43
44 status_t init(engine_t *engine) {
45 using namespace data_type;
46 using namespace prop_kind;
47 using namespace alg_kind;
48
49 bool ok = set_default_params() == status::success
50 && utils::one_of(desc()->prop_kind, forward_training,
51 forward_inference)
52 && utils::one_of(desc()->alg_kind, pooling_max,
53 pooling_avg_include_padding,
54 pooling_avg_exclude_padding)
55 && (utils::everyone_is(data_type::f32, src_md()->data_type,
56 dst_md()->data_type)
57 || utils::everyone_is(data_type::bf16,
58 src_md()->data_type, dst_md()->data_type))
59 && attr()->has_default_values();
60 if (!ok) return status::unimplemented;
61
62 bool is_training = desc_.prop_kind == forward_training;
63 if (desc()->alg_kind == pooling_max && is_training)
64 init_default_ws(s32);
65
66 return init_conf(engine);
67 }
68
69 status_t init_conf(engine_t *engine);
70 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
71
72 pool_conf_t conf;
73 offsets_t off;
74 };
75
76 status_t init(engine_t *engine) override {
77 compute::kernel_ctx_t kernel_ctx;
78 status_t status = pd()->init_kernel_ctx(kernel_ctx);
79 CHECK(status);
80
81 create_kernel(engine, &kernel_, "gen9_global_pooling_fwd", kernel_ctx);
82 if (!kernel_) return status::runtime_error;
83
84 return status::success;
85 }
86
87 status_t execute(const exec_ctx_t &ctx) const override {
88 return execute_forward(ctx);
89 }
90
91private:
92 status_t execute_forward(const exec_ctx_t &ctx) const;
93 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
94 compute::kernel_t kernel_;
95};
96
97struct gen9_global_pooling_bwd_t : public gpu_primitive_t {
98 using gpu_primitive_t::gpu_primitive_t;
99 struct pd_t : public gpu_pooling_bwd_pd_t {
100 pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
101 const pooling_fwd_pd_t *hint_fwd_pd)
102 : gpu_pooling_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
103
104 DECLARE_COMMON_PD_T("ocl:gen9_global:any", gen9_global_pooling_bwd_t);
105
106 status_t init(engine_t *engine) {
107 using namespace prop_kind;
108 using namespace alg_kind;
109
110 bool ok = set_default_params() == status::success
111 && utils::one_of(desc()->prop_kind, backward_data)
112 && utils::one_of(desc()->alg_kind, pooling_max,
113 pooling_avg_include_padding,
114 pooling_avg_exclude_padding)
115 && (utils::everyone_is(data_type::f32,
116 diff_dst_md()->data_type,
117 diff_src_md()->data_type)
118 || utils::everyone_is(data_type::bf16,
119 diff_dst_md()->data_type,
120 diff_src_md()->data_type))
121 && attr()->has_default_values();
122 if (!ok) return status::unimplemented;
123
124 if (desc()->alg_kind == pooling_max) {
125 init_default_ws(data_type::s32);
126 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
127 }
128
129 return init_conf(engine);
130 }
131
132 status_t init_conf(engine_t *engine);
133 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
134
135 pool_conf_t conf;
136 offsets_t off;
137 };
138
139 status_t init(engine_t *engine) override {
140 compute::kernel_ctx_t kernel_ctx;
141 status_t status = pd()->init_kernel_ctx(kernel_ctx);
142 CHECK(status);
143
144 create_kernel(engine, &kernel_, "gen9_global_pooling_bwd", kernel_ctx);
145 if (!kernel_) return status::runtime_error;
146
147 return status::success;
148 }
149
150 status_t execute(const exec_ctx_t &ctx) const override {
151 return execute_backward(ctx);
152 }
153
154private:
155 status_t execute_backward(const exec_ctx_t &ctx) const;
156 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
157 compute::kernel_t kernel_;
158};
159
160} // namespace ocl
161} // namespace gpu
162} // namespace impl
163} // namespace dnnl
164
165#endif
166