/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #include "gpu/ocl/ref_pooling.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | static status_t init_conf_common(pool_conf_t &conf, offsets_t &off, |
25 | const pooling_pd_t *pd, engine_t *engine, const bool is_bwd) { |
26 | using namespace dnnl::impl::format_tag; |
27 | |
28 | const memory_desc_wrapper src_mdw(pd->invariant_src_md()); |
29 | const memory_desc_wrapper dst_mdw(pd->invariant_dst_md()); |
30 | |
31 | set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(), |
32 | *pd->invariant_dst_md(), *pd->attr()); |
33 | |
34 | set_offsets(src_mdw, off.src_off); |
35 | set_offsets(dst_mdw, off.dst_off); |
36 | |
37 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
38 | conf.dispatch = compute_engine->create_dispatch( |
39 | conf.is_backward ? src_mdw.md_ : dst_mdw.md_); |
40 | conf.dispatch.define_dim("MB" , 0, conf.mb_padded); |
41 | conf.dispatch.define_dim("OC" , 1, conf.c_padded); |
42 | int ndims = conf.ndims; |
43 | if (is_bwd) { |
44 | conf.dispatch.define_dim("ID" , nstl::max(1, ndims - 3), conf.id); |
45 | conf.dispatch.define_dim("IH" , nstl::max(1, ndims - 2), conf.ih); |
46 | conf.dispatch.define_dim("IW" , nstl::max(1, ndims - 1), conf.iw); |
47 | } else { |
48 | conf.dispatch.define_dim("OD" , nstl::max(1, ndims - 3), conf.od); |
49 | conf.dispatch.define_dim("OH" , nstl::max(1, ndims - 2), conf.oh); |
50 | conf.dispatch.define_dim("OW" , nstl::max(1, ndims - 1), conf.ow); |
51 | } |
52 | conf.dispatch.generate(); |
53 | |
54 | conf.attr_info = attr_info_t::create(pd->attr()); |
55 | |
56 | return status::success; |
57 | }; |
58 | |
59 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
60 | const pool_conf_t &conf, const offsets_t &off, |
61 | const post_ops_t &post_ops) { |
62 | using namespace dnnl::impl::alg_kind; |
63 | kernel_ctx.set_data_type(conf.src_dt); |
64 | |
65 | kernel_ctx.define_int("SUB_GROUP_SIZE" , 1); |
66 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
67 | kernel_ctx.define_int("OC_WO_PADDING" , conf.c); |
68 | kernel_ctx.define_int("ID" , conf.id); |
69 | kernel_ctx.define_int("IH" , conf.ih); |
70 | kernel_ctx.define_int("IW" , conf.iw); |
71 | kernel_ctx.define_int("OD" , conf.od); |
72 | kernel_ctx.define_int("OH" , conf.oh); |
73 | kernel_ctx.define_int("OW" , conf.ow); |
74 | kernel_ctx.define_int("KD" , conf.kd); |
75 | kernel_ctx.define_int("KH" , conf.kh); |
76 | kernel_ctx.define_int("KW" , conf.kw); |
77 | kernel_ctx.define_int("DD" , conf.dd); |
78 | kernel_ctx.define_int("DH" , conf.dh); |
79 | kernel_ctx.define_int("DW" , conf.dw); |
80 | kernel_ctx.define_int("SD" , conf.stride_d); |
81 | kernel_ctx.define_int("SH" , conf.stride_h); |
82 | kernel_ctx.define_int("SW" , conf.stride_w); |
83 | kernel_ctx.define_int("PD" , conf.f_pad); |
84 | kernel_ctx.define_int("PH" , conf.t_pad); |
85 | kernel_ctx.define_int("PW" , conf.l_pad); |
86 | kernel_ctx.define_int("IS_TRAINING" , conf.is_training); |
87 | kernel_ctx.define_int("IS_BWD" , conf.is_backward); |
88 | kernel_ctx.define_int("IS_FWD" , !conf.is_backward); |
89 | |
90 | kernel_ctx.define_int("ALG_MAX" , (conf.alg == pooling_max)); |
91 | kernel_ctx.define_int( |
92 | "ALG_AVG_NP" , (conf.alg == pooling_avg_exclude_padding)); |
93 | kernel_ctx.define_int( |
94 | "ALG_AVG_P" , (conf.alg == pooling_avg_include_padding)); |
95 | |
96 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
97 | |
98 | def_offsets(off.src_off, kernel_ctx, "SRC" , conf.ndims); |
99 | def_offsets(off.dst_off, kernel_ctx, "DST" , conf.ndims); |
100 | |
101 | def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC" ); |
102 | def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" ); |
103 | |
104 | def_dispatch(kernel_ctx, conf.dispatch); |
105 | |
106 | return status::success; |
107 | } |
108 | |
// Forward pd: build conf/offsets with is_bwd = false (dispatch over dst).
status_t ref_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine, false);
}
112 | |
// Forward pd: emit kernel defines from the previously initialized conf.
status_t ref_pooling_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
117 | |
118 | status_t ref_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
119 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
120 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
121 | auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); |
122 | |
123 | compute::kernel_arg_list_t arg_list; |
124 | arg_list.set(0, src); |
125 | arg_list.set(1, ws); |
126 | arg_list.set(2, dst); |
127 | append_post_ops_to_arg_list(ctx, arg_list, 3, pd()->attr()->post_ops_); |
128 | |
129 | auto nd_range = pd()->conf.dispatch.nd_range(); |
130 | |
131 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
132 | } |
133 | |
// Backward pd: build conf/offsets with is_bwd = true (dispatch over diff_src).
status_t ref_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine, true);
}
137 | |
// Backward pd: emit kernel defines from the previously initialized conf.
status_t ref_pooling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
142 | |
143 | status_t ref_pooling_bwd_t::execute_backward(const exec_ctx_t &ctx) const { |
144 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
145 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
146 | auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); |
147 | |
148 | compute::kernel_arg_list_t arg_list; |
149 | arg_list.set(0, diff_src); |
150 | arg_list.set(1, ws); |
151 | arg_list.set(2, diff_dst); |
152 | |
153 | auto nd_range = pd()->conf.dispatch.nd_range(); |
154 | |
155 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
156 | } |
157 | |
158 | } // namespace ocl |
159 | } // namespace gpu |
160 | } // namespace impl |
161 | } // namespace dnnl |
162 | |