/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>

#include "gpu/ocl/gen9_eltwise.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
static status_t init_conf_common(
        eltwise_conf_t &conf, engine_t *engine, const eltwise_pd_t *pd) {
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto arch = compute_engine->device_info()->gpu_arch();
    bool is_pre_xe_hp = arch < compute::gpu_arch_t::xe_hp;

    const auto &data_md = pd->use_dst() ? pd->dst_md() : pd->src_md();
    const memory_desc_wrapper data_d(*data_md);
    // Important hw features for code generation
    const int dt_size = (int)data_d.data_type_size();
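    // Assumption, not stated in the source: the 128- vs 256-byte split below
    // tracks the widest load a single memory message can service on
    // pre-Xe-HP vs Xe-HP+ parts.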
    const int max_load_size = is_pre_xe_hp ? 128 : 256;

    // Heuristics chosen by experimentation:
    // load_unroll hides the computation overhead associated with kernel
    // start; local_threads hides work-group scheduling overhead.
    const int load_unroll = is_pre_xe_hp ? 4 : 1;
    const int local_threads = is_pre_xe_hp ? 1 : 16;

    // Prefer loading a multiple of the max load size to reduce the number
    // of load messages
    const int load_size = load_unroll * max_load_size;

    conf.alg = pd->desc()->alg_kind;
    conf.with_zero_padding = data_d.nelems(false) != data_d.nelems(true);
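    // nelems(false) counts only the logical elements, while nelems(true)
    // includes the physically zero-padded area; a mismatch means the tensor
    // carries padding the kernel must account for.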

    // Set SIMD size
    conf.sub_group_size = compute_engine->device_info()->max_subgroup_size();

    // VECT_DATA_T only supports vector sizes up to 8
    conf.vector_size = std::min(load_size / (dt_size * conf.sub_group_size), 8);
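    // Illustrative arithmetic (values not from the source): for f32
    // (dt_size = 4) with SIMD16 on a pre-Xe-HP part, load_size = 4 * 128
    // = 512 bytes, so vector_size = min(512 / (4 * 16), 8) = 8 elements
    // per work-item.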
    conf.work_group_size = local_threads * conf.sub_group_size;

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const eltwise_conf_t &conf, const offsets_t &off,
        const memory_desc_wrapper &data_d) {
    kernel_ctx.set_data_type(data_d.data_type());
    def_eltwise_alg_kinds(kernel_ctx);

    kernel_ctx.define_int("WITH_ELTWISE", 1);
    kernel_ctx.define_int("ELTWISE_ALG", conf.alg);

    kernel_ctx.define_int("VECT_DT_N", conf.vector_size);

    const int local_block_size = conf.work_group_size * conf.vector_size;
    kernel_ctx.define_int("NELEMS_OVERFLOW",
            (data_d.nelems(conf.with_zero_padding) % local_block_size) != 0);
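    // One work-group consumes work_group_size * vector_size elements per
    // pass; when the total element count is not a multiple of that block,
    // NELEMS_OVERFLOW is set so the kernel can guard the tail.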

    // Attributes for the work-group size and subgroup size
    kernel_ctx.define_int("GWS_WITH_SG_DEFAULT", 1);
    // Work-group size
    kernel_ctx.define_int("GWS_LWS0_DEFAULT", conf.work_group_size);
    kernel_ctx.define_int("GWS_LWS1_DEFAULT", 1);
    kernel_ctx.define_int("GWS_LWS2_DEFAULT", 1);
    // Subgroup size
    kernel_ctx.define_int("GWS_SGS_DEFAULT", conf.sub_group_size);

    return status::success;
}

status_t gen9_eltwise_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, engine, this);
}

status_t gen9_eltwise_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    const memory_desc_wrapper src_d(src_md());
    return init_kernel_ctx_common(kernel_ctx, conf, off, src_d);
}

status_t gen9_eltwise_fwd_t::execute_forward_dense(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const int nelems = src_d.nelems(pd()->conf.with_zero_padding);
    const float alpha = pd()->desc()->alpha;
    const float beta = pd()->desc()->beta;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, dst);
    arg_list.set(2, nelems);
    arg_list.set(3, alpha);
    arg_list.set(4, beta);

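    // Each work-item processes vector_size elements, so the global size is
    // nelems / vector_size rounded up to whole work-groups. Illustrative
    // example (values not from the source): nelems = 1000 with
    // vector_size = 8 gives total_wi = 125, which lws = 256 rounds up to a
    // single 256-item work-group; the NELEMS_OVERFLOW guard covers the
    // surplus work-items.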
    size_t lws = pd()->conf.work_group_size;
    size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
    compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws});

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

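    // Algorithms that do not map zero to zero (e.g. linear with beta != 0,
    // where f(0) = beta) corrupt the zero-padded area, so the padding is
    // re-zeroed after the kernel runs.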
    if (!gpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
                pd()->desc()->alg_kind, alpha, beta)) {
        ctx.zero_pad_output(DNNL_ARG_DST);
    }

    return status;
}

status_t gen9_eltwise_bwd_t::pd_t::init_conf(engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const memory_desc_wrapper data_d(data_md());
    const memory_desc_wrapper diff_data_d(diff_src_md());

    // This kernel supports only matching data and diff formats
    if (data_d != diff_data_d) return status::unimplemented;

    return init_conf_common(conf, engine, this);
}

status_t gen9_eltwise_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    const memory_desc_wrapper data_d(data_md());
    return init_kernel_ctx_common(kernel_ctx, conf, off, data_d);
}

status_t gen9_eltwise_bwd_t::execute_backward_dense(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;

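    // Some algorithms (e.g. the *_use_dst_for_bwd ReLU variants) compute the
    // gradient from dst rather than src, so read whichever tensor the
    // primitive descriptor requested.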
    auto &src = pd()->use_dst() ? CTX_IN_STORAGE(DNNL_ARG_DST)
                                : CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper data_d(pd()->data_md());
    const int nelems = data_d.nelems(pd()->conf.with_zero_padding);
    const float alpha = pd()->desc()->alpha;
    const float beta = pd()->desc()->beta;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, diff_src);
    arg_list.set(2, diff_dst);
    arg_list.set(3, nelems);
    arg_list.set(4, alpha);
    arg_list.set(5, beta);

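    // Same dispatch scheme as the forward pass.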
    size_t lws = pd()->conf.work_group_size;
    size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
    compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws});

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    if (!gpu_eltwise_bwd_pd_t::eltwise_preserves_zero(
                pd()->desc()->alg_kind, alpha, beta)) {
        ctx.zero_pad_output(DNNL_ARG_DIFF_SRC);
    }

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl