1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_global_pooling.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | int calculate_spatial_chunk(const pool_conf_t &conf, engine_t *engine) { |
25 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
26 | const int hw_threads = compute_engine->device_info()->hw_threads(); |
27 | const bool is_xe_hp_plus = compute_engine->is_xe_hp() |
28 | || compute_engine->is_xe_hpg() || compute_engine->is_xe_hpc(); |
29 | |
30 | const int spatial_dim = conf.id * conf.ih * conf.iw; |
31 | int chunk_size = spatial_dim; |
32 | |
33 | // Experimentally selected values for XeHP family |
34 | const int desired_wi_per_thread = is_xe_hp_plus && conf.is_plain ? 1024 : 4; |
35 | |
36 | const auto get_work_items_num = [&]() { |
37 | return conf.c * conf.mb * utils::div_up(spatial_dim, chunk_size); |
38 | }; |
39 | while (get_work_items_num() < hw_threads * desired_wi_per_thread |
40 | && chunk_size > 1) { |
41 | chunk_size = utils::div_up(chunk_size, 2); |
42 | } |
43 | return chunk_size; |
44 | } |
45 | |
46 | static status_t init_conf_common(pool_conf_t &conf, offsets_t &off, |
47 | const pooling_pd_t *pd, engine_t *engine) { |
48 | using namespace dnnl::impl::format_tag; |
49 | |
50 | set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(), |
51 | *pd->invariant_dst_md(), *pd->attr()); |
52 | |
53 | if (conf.id != conf.kd || conf.iw != conf.kw || conf.ih != conf.kh |
54 | || conf.od * conf.ow * conf.oh != 1) |
55 | return status::unimplemented; |
56 | |
57 | const memory_desc_wrapper src_mdw(pd->invariant_src_md()); |
58 | const memory_desc_wrapper dst_mdw(pd->invariant_dst_md()); |
59 | const auto &padded_src_dims = src_mdw.padded_dims(); |
60 | const auto &padded_dst_dims = dst_mdw.padded_dims(); |
61 | if (utils::array_product(padded_src_dims + 2, conf.ndims - 2) |
62 | != conf.id * conf.ih * conf.iw |
63 | || utils::array_product(padded_dst_dims + 2, conf.ndims - 2) |
64 | != conf.od * conf.oh * conf.ow) |
65 | return status::unimplemented; |
66 | if (!conf.is_backward) { |
67 | // gen9_global_pooling_fwd doesn't support zero padding. |
68 | if (conf.mb != conf.mb_padded || conf.c != conf.c_padded) |
69 | return status::unimplemented; |
70 | // heuristics: for small shapes, gen9_pooling_fwd provides better perf. |
71 | if (conf.kd * conf.kh * conf.kw < 128) return status::unimplemented; |
72 | } |
73 | |
74 | set_offsets(src_mdw, off.src_off); |
75 | set_offsets(dst_mdw, off.dst_off); |
76 | |
77 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
78 | |
79 | conf.is_plain = src_mdw.is_plain(); |
80 | conf.global_pool_spatial_chunk = calculate_spatial_chunk(conf, engine); |
81 | |
82 | const int spatial_dim_padded = utils::rnd_up( |
83 | conf.id * conf.ih * conf.iw, conf.global_pool_spatial_chunk); |
84 | conf.dispatch = compute_engine->create_dispatch(src_mdw.md_); |
85 | conf.dispatch.define_dim("MB" , 0, conf.mb_padded); |
86 | conf.dispatch.define_dim("C" , 1, conf.c_padded); |
87 | if (conf.is_backward) { |
88 | conf.dispatch.define_dim("SPATIAL" , 2, spatial_dim_padded, |
89 | conf.global_pool_spatial_chunk); |
90 | } |
91 | conf.dispatch.generate(); |
92 | |
93 | conf.attr_info = attr_info_t::create(pd->attr()); |
94 | |
95 | return status::success; |
96 | }; |
97 | |
98 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
99 | const pool_conf_t &conf, const offsets_t &off, |
100 | const post_ops_t &post_ops) { |
101 | using namespace dnnl::impl::alg_kind; |
102 | kernel_ctx.set_data_type(conf.src_dt); |
103 | |
104 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
105 | kernel_ctx.define_int("MB" , conf.mb); |
106 | kernel_ctx.define_int("C" , conf.c); |
107 | kernel_ctx.define_int("ID" , conf.id); |
108 | kernel_ctx.define_int("IH" , conf.ih); |
109 | kernel_ctx.define_int("IW" , conf.iw); |
110 | kernel_ctx.define_int("SPATIAL_DIM" , conf.id * conf.ih * conf.iw); |
111 | kernel_ctx.define_int("SPATIAL_CHUNK" , conf.global_pool_spatial_chunk); |
112 | kernel_ctx.define_int("IS_TRAINING" , conf.is_training); |
113 | kernel_ctx.define_int("IS_BWD" , conf.is_backward); |
114 | kernel_ctx.define_int("IS_FWD" , !conf.is_backward); |
115 | |
116 | kernel_ctx.define_int("ALG_MAX" , (conf.alg == pooling_max)); |
117 | kernel_ctx.define_int( |
118 | "ALG_AVG_NP" , (conf.alg == pooling_avg_exclude_padding)); |
119 | kernel_ctx.define_int( |
120 | "ALG_AVG_P" , (conf.alg == pooling_avg_include_padding)); |
121 | kernel_ctx.define_int("NEED_ZERO_PADDING" , |
122 | (conf.mb != conf.mb_padded || conf.c != conf.c_padded)); |
123 | |
124 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
125 | |
126 | def_offsets(off.src_off, kernel_ctx, "SRC" , conf.ndims); |
127 | def_offsets(off.dst_off, kernel_ctx, "DST" , conf.ndims); |
128 | |
129 | def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC" ); |
130 | def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" ); |
131 | |
132 | def_dispatch(kernel_ctx, conf.dispatch); |
133 | |
134 | return status::success; |
135 | } |
136 | |
// Forward pd: delegate to the shared fwd/bwd routine to fill conf/off.
status_t gen9_global_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}
140 | |
// Forward pd: publish conf/off and the attribute post-ops to the kernel
// build context.
status_t gen9_global_pooling_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
145 | |
146 | status_t gen9_global_pooling_fwd_t::execute_forward( |
147 | const exec_ctx_t &ctx) const { |
148 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
149 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
150 | auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); |
151 | |
152 | compute::kernel_arg_list_t arg_list; |
153 | arg_list.set(0, src); |
154 | arg_list.set(1, ws); |
155 | arg_list.set(2, dst); |
156 | |
157 | auto nd_range = pd()->conf.dispatch.nd_range(); |
158 | |
159 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
160 | } |
161 | |
// Backward pd: delegate to the shared fwd/bwd routine to fill conf/off.
status_t gen9_global_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}
165 | |
// Backward pd: publish conf/off and the attribute post-ops to the kernel
// build context.
status_t gen9_global_pooling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
170 | |
171 | status_t gen9_global_pooling_bwd_t::execute_backward( |
172 | const exec_ctx_t &ctx) const { |
173 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
174 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
175 | auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); |
176 | |
177 | compute::kernel_arg_list_t arg_list; |
178 | arg_list.set(0, diff_src); |
179 | arg_list.set(1, ws); |
180 | arg_list.set(2, diff_dst); |
181 | |
182 | auto nd_range = pd()->conf.dispatch.nd_range(); |
183 | |
184 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
185 | } |
186 | |
187 | } // namespace ocl |
188 | } // namespace gpu |
189 | } // namespace impl |
190 | } // namespace dnnl |
191 | |