1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gen9_global_pooling.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace ocl {
23
24int calculate_spatial_chunk(const pool_conf_t &conf, engine_t *engine) {
25 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
26 const int hw_threads = compute_engine->device_info()->hw_threads();
27 const bool is_xe_hp_plus = compute_engine->is_xe_hp()
28 || compute_engine->is_xe_hpg() || compute_engine->is_xe_hpc();
29
30 const int spatial_dim = conf.id * conf.ih * conf.iw;
31 int chunk_size = spatial_dim;
32
33 // Experimentally selected values for XeHP family
34 const int desired_wi_per_thread = is_xe_hp_plus && conf.is_plain ? 1024 : 4;
35
36 const auto get_work_items_num = [&]() {
37 return conf.c * conf.mb * utils::div_up(spatial_dim, chunk_size);
38 };
39 while (get_work_items_num() < hw_threads * desired_wi_per_thread
40 && chunk_size > 1) {
41 chunk_size = utils::div_up(chunk_size, 2);
42 }
43 return chunk_size;
44}
45
46static status_t init_conf_common(pool_conf_t &conf, offsets_t &off,
47 const pooling_pd_t *pd, engine_t *engine) {
48 using namespace dnnl::impl::format_tag;
49
50 set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(),
51 *pd->invariant_dst_md(), *pd->attr());
52
53 if (conf.id != conf.kd || conf.iw != conf.kw || conf.ih != conf.kh
54 || conf.od * conf.ow * conf.oh != 1)
55 return status::unimplemented;
56
57 const memory_desc_wrapper src_mdw(pd->invariant_src_md());
58 const memory_desc_wrapper dst_mdw(pd->invariant_dst_md());
59 const auto &padded_src_dims = src_mdw.padded_dims();
60 const auto &padded_dst_dims = dst_mdw.padded_dims();
61 if (utils::array_product(padded_src_dims + 2, conf.ndims - 2)
62 != conf.id * conf.ih * conf.iw
63 || utils::array_product(padded_dst_dims + 2, conf.ndims - 2)
64 != conf.od * conf.oh * conf.ow)
65 return status::unimplemented;
66 if (!conf.is_backward) {
67 // gen9_global_pooling_fwd doesn't support zero padding.
68 if (conf.mb != conf.mb_padded || conf.c != conf.c_padded)
69 return status::unimplemented;
70 // heuristics: for small shapes, gen9_pooling_fwd provides better perf.
71 if (conf.kd * conf.kh * conf.kw < 128) return status::unimplemented;
72 }
73
74 set_offsets(src_mdw, off.src_off);
75 set_offsets(dst_mdw, off.dst_off);
76
77 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
78
79 conf.is_plain = src_mdw.is_plain();
80 conf.global_pool_spatial_chunk = calculate_spatial_chunk(conf, engine);
81
82 const int spatial_dim_padded = utils::rnd_up(
83 conf.id * conf.ih * conf.iw, conf.global_pool_spatial_chunk);
84 conf.dispatch = compute_engine->create_dispatch(src_mdw.md_);
85 conf.dispatch.define_dim("MB", 0, conf.mb_padded);
86 conf.dispatch.define_dim("C", 1, conf.c_padded);
87 if (conf.is_backward) {
88 conf.dispatch.define_dim("SPATIAL", 2, spatial_dim_padded,
89 conf.global_pool_spatial_chunk);
90 }
91 conf.dispatch.generate();
92
93 conf.attr_info = attr_info_t::create(pd->attr());
94
95 return status::success;
96};
97
98static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
99 const pool_conf_t &conf, const offsets_t &off,
100 const post_ops_t &post_ops) {
101 using namespace dnnl::impl::alg_kind;
102 kernel_ctx.set_data_type(conf.src_dt);
103
104 kernel_ctx.define_int("NDIMS", conf.ndims);
105 kernel_ctx.define_int("MB", conf.mb);
106 kernel_ctx.define_int("C", conf.c);
107 kernel_ctx.define_int("ID", conf.id);
108 kernel_ctx.define_int("IH", conf.ih);
109 kernel_ctx.define_int("IW", conf.iw);
110 kernel_ctx.define_int("SPATIAL_DIM", conf.id * conf.ih * conf.iw);
111 kernel_ctx.define_int("SPATIAL_CHUNK", conf.global_pool_spatial_chunk);
112 kernel_ctx.define_int("IS_TRAINING", conf.is_training);
113 kernel_ctx.define_int("IS_BWD", conf.is_backward);
114 kernel_ctx.define_int("IS_FWD", !conf.is_backward);
115
116 kernel_ctx.define_int("ALG_MAX", (conf.alg == pooling_max));
117 kernel_ctx.define_int(
118 "ALG_AVG_NP", (conf.alg == pooling_avg_exclude_padding));
119 kernel_ctx.define_int(
120 "ALG_AVG_P", (conf.alg == pooling_avg_include_padding));
121 kernel_ctx.define_int("NEED_ZERO_PADDING",
122 (conf.mb != conf.mb_padded || conf.c != conf.c_padded));
123
124 def_attr_info(kernel_ctx, conf.attr_info, post_ops);
125
126 def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);
127 def_offsets(off.dst_off, kernel_ctx, "DST", conf.ndims);
128
129 def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
130 def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
131
132 def_dispatch(kernel_ctx, conf.dispatch);
133
134 return status::success;
135}
136
137status_t gen9_global_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
138 return init_conf_common(conf, off, this, engine);
139}
140
141status_t gen9_global_pooling_fwd_t::pd_t::init_kernel_ctx(
142 compute::kernel_ctx_t &kernel_ctx) const {
143 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
144}
145
146status_t gen9_global_pooling_fwd_t::execute_forward(
147 const exec_ctx_t &ctx) const {
148 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
149 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
150 auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE);
151
152 compute::kernel_arg_list_t arg_list;
153 arg_list.set(0, src);
154 arg_list.set(1, ws);
155 arg_list.set(2, dst);
156
157 auto nd_range = pd()->conf.dispatch.nd_range();
158
159 return parallel_for(ctx, nd_range, kernel_, arg_list);
160}
161
162status_t gen9_global_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
163 return init_conf_common(conf, off, this, engine);
164}
165
166status_t gen9_global_pooling_bwd_t::pd_t::init_kernel_ctx(
167 compute::kernel_ctx_t &kernel_ctx) const {
168 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
169}
170
171status_t gen9_global_pooling_bwd_t::execute_backward(
172 const exec_ctx_t &ctx) const {
173 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
174 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
175 auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);
176
177 compute::kernel_arg_list_t arg_list;
178 arg_list.set(0, diff_src);
179 arg_list.set(1, ws);
180 arg_list.set(2, diff_dst);
181
182 auto nd_range = pd()->conf.dispatch.nd_range();
183
184 return parallel_for(ctx, nd_range, kernel_, arg_list);
185}
186
187} // namespace ocl
188} // namespace gpu
189} // namespace impl
190} // namespace dnnl
191