/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_pooling.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

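// Fills the pooling config and the source/destination offset tables shared by
// the forward and backward reference kernels.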
static status_t init_conf_common(pool_conf_t &conf, offsets_t &off,
        const pooling_pd_t *pd, engine_t *engine, const bool is_bwd) {
    using namespace dnnl::impl::format_tag;

    const memory_desc_wrapper src_mdw(pd->invariant_src_md());
    const memory_desc_wrapper dst_mdw(pd->invariant_dst_md());

    set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(),
            *pd->invariant_dst_md(), *pd->attr());

    set_offsets(src_mdw, off.src_off);
    set_offsets(dst_mdw, off.dst_off);

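    // The backward kernel is parallelized over the (diff) source spatial
    // dimensions, the forward kernel over the destination dimensions.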
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(
            conf.is_backward ? src_mdw.md_ : dst_mdw.md_);
    conf.dispatch.define_dim("MB", 0, conf.mb_padded);
    conf.dispatch.define_dim("OC", 1, conf.c_padded);
    int ndims = conf.ndims;
    if (is_bwd) {
        conf.dispatch.define_dim("ID", nstl::max(1, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(1, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(1, ndims - 1), conf.iw);
    } else {
        conf.dispatch.define_dim("OD", nstl::max(1, ndims - 3), conf.od);
        conf.dispatch.define_dim("OH", nstl::max(1, ndims - 2), conf.oh);
        conf.dispatch.define_dim("OW", nstl::max(1, ndims - 1), conf.ow);
    }
    conf.dispatch.generate();

    conf.attr_info = attr_info_t::create(pd->attr());

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const pool_conf_t &conf, const offsets_t &off,
        const post_ops_t &post_ops) {
    using namespace dnnl::impl::alg_kind;
    kernel_ctx.set_data_type(conf.src_dt);

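    // Problem geometry, strides, dilations, and paddings are baked into the
    // OpenCL kernel as preprocessor definitions at kernel build time.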
    kernel_ctx.define_int("SUB_GROUP_SIZE", 1);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("OC_WO_PADDING", conf.c);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("DD", conf.dd);
    kernel_ctx.define_int("DH", conf.dh);
    kernel_ctx.define_int("DW", conf.dw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("IS_TRAINING", conf.is_training);
    kernel_ctx.define_int("IS_BWD", conf.is_backward);
    kernel_ctx.define_int("IS_FWD", !conf.is_backward);

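    // Exactly one ALG_* macro is set to 1, selecting the pooling algorithm
    // specialization compiled into the kernel.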
    kernel_ctx.define_int("ALG_MAX", (conf.alg == pooling_max));
    kernel_ctx.define_int(
            "ALG_AVG_NP", (conf.alg == pooling_avg_exclude_padding));
    kernel_ctx.define_int(
            "ALG_AVG_P", (conf.alg == pooling_avg_include_padding));

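    // Emit attribute/post-op macros, logical offset helpers, memory descriptor
    // info, and the dispatch parameters consumed by the generic kernel.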
    def_attr_info(kernel_ctx, conf.attr_info, post_ops);

    def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);
    def_offsets(off.dst_off, kernel_ctx, "DST", conf.ndims);

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t ref_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine, false);
}

status_t ref_pooling_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}

status_t ref_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE);

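    // Argument order must match the kernel signature: src, ws, dst, followed
    // by any post-op arguments. The workspace is expected to be meaningful
    // only for max pooling in training mode; otherwise an empty storage is
    // bound.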
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, ws);
    arg_list.set(2, dst);
    append_post_ops_to_arg_list(ctx, arg_list, 3, pd()->attr()->post_ops_);

    auto nd_range = pd()->conf.dispatch.nd_range();

    return parallel_for(ctx, nd_range, kernel_, arg_list);
}

status_t ref_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine, true);
}

status_t ref_pooling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}

status_t ref_pooling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

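    // Same ordering contract as the forward pass: diff_src, ws, diff_dst. For
    // max pooling, the workspace produced by forward training holds the
    // indices of the maxima used to route gradients.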
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, diff_src);
    arg_list.set(1, ws);
    arg_list.set(2, diff_dst);

    auto nd_range = pd()->conf.dispatch.nd_range();

    return parallel_for(ctx, nd_range, kernel_, arg_list);
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl