1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | |
19 | #include "gpu/ocl/gen9_eltwise.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | static status_t init_conf_common( |
27 | eltwise_conf_t &conf, engine_t *engine, const eltwise_pd_t *pd) { |
28 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
29 | auto arch = compute_engine->device_info()->gpu_arch(); |
30 | bool is_pre_xe_hp = arch < compute::gpu_arch_t::xe_hp; |
31 | |
32 | const auto &data_md = pd->use_dst() ? pd->dst_md() : pd->src_md(); |
33 | const memory_desc_wrapper data_d(*data_md); |
34 | // Important hw features for code generation |
35 | const int dt_size = (int)data_d.data_type_size(); |
36 | const int max_load_size = is_pre_xe_hp ? 128 : 256; |
37 | |
38 | // Heuristics chosen by experimentation |
39 | // load_unroll hides computation overhead associated with kernel start |
40 | // local_threads hides workgroup scheduling overhead |
41 | const int load_unroll = is_pre_xe_hp ? 4 : 1; |
42 | const int local_threads = is_pre_xe_hp ? 1 : 16; |
43 | |
    // Prefer loading a multiple of the max load size to reduce the number of
    // load messages
45 | const int load_size = load_unroll * max_load_size; |
46 | |
47 | conf.alg = pd->desc()->alg_kind; |
48 | conf.with_zero_padding = data_d.nelems(false) != data_d.nelems(true); |
49 | |
50 | // Set simd size |
51 | conf.sub_group_size = compute_engine->device_info()->max_subgroup_size(); |
52 | |
53 | // VECT_DATA_T only supports vector sizes up to 8 |
54 | conf.vector_size = std::min(load_size / (dt_size * conf.sub_group_size), 8); |
55 | conf.work_group_size = local_threads * conf.sub_group_size; |
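    // For example, assuming f32 data (dt_size = 4) and a 16-wide subgroup on
    // pre-Xe-HP hardware: load_size = 4 * 128 = 512 bytes, so
    // vector_size = min(512 / (4 * 16), 8) = 8 and work_group_size = 1 * 16 = 16.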
56 | |
57 | return status::success; |
58 | } |
59 | |
60 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
61 | const eltwise_conf_t &conf, const offsets_t &off, |
62 | const memory_desc_wrapper &data_d) { |
63 | kernel_ctx.set_data_type(data_d.data_type()); |
64 | def_eltwise_alg_kinds(kernel_ctx); |
65 | |
66 | kernel_ctx.define_int("WITH_ELTWISE" , 1); |
67 | kernel_ctx.define_int("ELTWISE_ALG" , conf.alg); |
68 | |
69 | kernel_ctx.define_int("VECT_DT_N" , conf.vector_size); |
70 | |
71 | const int local_block_size = conf.work_group_size * conf.vector_size; |
72 | kernel_ctx.define_int("NELEMS_OVERFLOW" , |
73 | (data_d.nelems(conf.with_zero_padding) % local_block_size) != 0); |
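    // When nelems is not a multiple of the per-workgroup block
    // (work_group_size * vector_size), the generated kernel handles the final,
    // partially filled block separately.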
74 | |
75 | // attribute for wg-size and subgroup-size |
76 | kernel_ctx.define_int("GWS_WITH_SG_DEFAULT" , 1); |
77 | // wg-size |
78 | kernel_ctx.define_int("GWS_LWS0_DEFAULT" , conf.work_group_size); |
79 | kernel_ctx.define_int("GWS_LWS1_DEFAULT" , 1); |
80 | kernel_ctx.define_int("GWS_LWS2_DEFAULT" , 1); |
81 | // subgroup-size |
82 | kernel_ctx.define_int("GWS_SGS_DEFAULT" , conf.sub_group_size); |
83 | |
84 | return status::success; |
85 | } |
86 | |
status_t gen9_eltwise_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, engine, this);
}
91 | |
92 | status_t gen9_eltwise_fwd_t::pd_t::init_kernel_ctx( |
93 | compute::kernel_ctx_t &kernel_ctx) const { |
94 | const memory_desc_wrapper src_d(src_md()); |
95 | return init_kernel_ctx_common(kernel_ctx, conf, off, src_d); |
96 | } |
97 | |
98 | status_t gen9_eltwise_fwd_t::execute_forward_dense( |
99 | const exec_ctx_t &ctx) const { |
100 | status_t status = status::success; |
101 | |
102 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
103 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
104 | |
105 | const memory_desc_wrapper src_d(pd()->src_md()); |
106 | const int nelems = src_d.nelems(pd()->conf.with_zero_padding); |
107 | const float alpha = pd()->desc()->alpha; |
108 | const float beta = pd()->desc()->beta; |
109 | |
110 | compute::kernel_arg_list_t arg_list; |
111 | arg_list.set(0, src); |
112 | arg_list.set(1, dst); |
113 | arg_list.set(2, nelems); |
114 | arg_list.set(3, alpha); |
115 | arg_list.set(4, beta); |
116 | |
117 | size_t lws = pd()->conf.work_group_size; |
118 | size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size); |
119 | compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws}); |
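    // Each work item handles vector_size elements; the global size is rounded
    // up to a multiple of the work-group size. E.g., assuming nelems = 1000,
    // vector_size = 8 and lws = 16: total_wi = 125, global size = 128.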
120 | |
121 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
122 | |
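    // When zero padding is present the kernel also processes the padded
    // elements, so if the algorithm does not map 0 to 0 (e.g. linear with a
    // nonzero beta) the padding must be re-zeroed afterwards.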
123 | if (!gpu_eltwise_fwd_pd_t::eltwise_preserves_zero( |
124 | pd()->desc()->alg_kind, alpha, beta)) { |
125 | ctx.zero_pad_output(DNNL_ARG_DST); |
126 | } |
127 | |
128 | return status; |
129 | } |
130 | |
131 | status_t gen9_eltwise_bwd_t::pd_t::init_conf(engine_t *engine) { |
132 | using namespace dnnl::impl::format_tag; |
133 | |
134 | const memory_desc_wrapper data_d(data_md()); |
135 | const memory_desc_wrapper diff_data_d(diff_src_md()); |
136 | |
137 | // This kernel supports only matching data and diff formats |
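    // since src/dst and the diff tensors are addressed with the same flat
    // offsets, and mismatched layouts would pair up the wrong elements.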
138 | if (data_d != diff_data_d) return status::unimplemented; |
139 | |
140 | return init_conf_common(conf, engine, this); |
141 | } |
142 | |
143 | status_t gen9_eltwise_bwd_t::pd_t::init_kernel_ctx( |
144 | compute::kernel_ctx_t &kernel_ctx) const { |
145 | const memory_desc_wrapper data_d(data_md()); |
146 | return init_kernel_ctx_common(kernel_ctx, conf, off, data_d); |
147 | } |
148 | |
149 | status_t gen9_eltwise_bwd_t::execute_backward_dense( |
150 | const exec_ctx_t &ctx) const { |
151 | status_t status = status::success; |
152 | |
153 | auto &src = pd()->use_dst() ? CTX_IN_STORAGE(DNNL_ARG_DST) |
154 | : CTX_IN_STORAGE(DNNL_ARG_SRC); |
155 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
156 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
157 | |
158 | const memory_desc_wrapper data_d(pd()->data_md()); |
159 | const int nelems = data_d.nelems(pd()->conf.with_zero_padding); |
160 | const float alpha = pd()->desc()->alpha; |
161 | const float beta = pd()->desc()->beta; |
162 | |
163 | compute::kernel_arg_list_t arg_list; |
164 | arg_list.set(0, src); |
165 | arg_list.set(1, diff_src); |
166 | arg_list.set(2, diff_dst); |
167 | arg_list.set(3, nelems); |
168 | arg_list.set(4, alpha); |
169 | arg_list.set(5, beta); |
170 | |
171 | size_t lws = pd()->conf.work_group_size; |
172 | size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size); |
173 | compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws}); |
174 | |
175 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
176 | |
177 | if (!gpu_eltwise_bwd_pd_t::eltwise_preserves_zero( |
178 | pd()->desc()->alg_kind, alpha, beta)) { |
179 | ctx.zero_pad_output(DNNL_ARG_DIFF_SRC); |
180 | } |
181 | |
182 | return status; |
183 | } |
184 | |
185 | } // namespace ocl |
186 | } // namespace gpu |
187 | } // namespace impl |
188 | } // namespace dnnl |
189 | |