1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_eltwise.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace ocl {
23
24static status_t init_conf_common(eltwise_conf_t &conf, offsets_t &off,
25 const eltwise_pd_t *pd, engine_t *engine) {
26 alg_kind_t alg = pd->desc()->alg_kind;
27 const bool is_forward = pd->is_fwd();
28 const auto &src_md = pd->use_dst() ? pd->dst_md() : pd->src_md();
29 const memory_desc_wrapper src_d(src_md);
30 const memory_desc_wrapper diff_data_d(
31 is_forward ? &glob_zero_md : pd->diff_src_md());
32
33 conf.data_md_info = memory_desc_info_t::create(src_d);
34 if (!is_forward)
35 conf.data_diff_md_info = memory_desc_info_t::create(diff_data_d);
36
37 const int ndims = src_d.ndims();
38 conf.ndims = ndims;
39
40 conf.data_type = src_d.data_type();
41 conf.alg = alg;
42 conf.is_forward = is_forward;
43 conf.attr_info = attr_info_t::create(pd->attr());
44 conf.sub_group_size = 32;
45
46 set_offsets(src_d, off.src_off);
47 set_offsets(diff_data_d, off.dst_off);
48
49 const auto &dims = src_d.padded_dims();
50
51 conf.with_zero_padding = src_d.nelems(false) != src_d.nelems(true);
52
53 int max_ndims = 6;
54 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
55 conf.dispatch = compute_engine->create_dispatch(
56 is_forward ? src_d.md_ : diff_data_d.md_);
57 for (int i = 0; i < max_ndims; ++i) {
58 if (i < ndims)
59 conf.dispatch.define_dim(utils::format("D%d", i), i, dims[i]);
60 else
61 conf.dispatch.define_dim(utils::format("D%d", i), 1);
62 }
63 conf.dispatch.generate(/*generate_lws=*/false);
64
65 return status::success;
66}
67
68static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
69 const eltwise_conf_t &conf, const offsets_t &off,
70 const post_ops_t &post_ops) {
71 kernel_ctx.set_data_type(conf.data_type);
72
73 def_eltwise_alg_kinds(kernel_ctx);
74
75 kernel_ctx.define_int("WITH_ELTWISE", 1);
76 kernel_ctx.define_int("ELTWISE_ALG", conf.alg);
77 kernel_ctx.define_int("NDIMS", conf.ndims);
78 kernel_ctx.define_int("GWS0", conf.dispatch.nd_range().global_range()[0]);
79 kernel_ctx.define_int("GWS1", conf.dispatch.nd_range().global_range()[1]);
80 kernel_ctx.define_int("GWS2", conf.dispatch.nd_range().global_range()[2]);
81 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
82
83 bool with_binary_post_ops
84 = post_ops.find(primitive_kind_t::dnnl_binary) != -1;
85 kernel_ctx.define_int(
86 "USE_GWS_GET", conf.with_zero_padding || with_binary_post_ops);
87
88 def_memory_desc_info(kernel_ctx, conf.data_md_info, "DATA");
89
90 if (!conf.is_forward) {
91 def_memory_desc_info(kernel_ctx, conf.data_diff_md_info, "DIFF_DATA");
92 } else {
93 kernel_ctx.define_int("IS_FWD", 1);
94 }
95
96 def_attr_info(kernel_ctx, conf.attr_info, post_ops);
97 def_dispatch(kernel_ctx, conf.dispatch);
98
99 return status::success;
100}
101
102status_t ref_eltwise_fwd_t::pd_t::init_conf(engine_t *engine) {
103 return init_conf_common(conf, off, this, engine);
104}
105
106status_t ref_eltwise_fwd_t::pd_t::init_kernel_ctx(
107 compute::kernel_ctx_t &kernel_ctx) const {
108 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
109}
110
111status_t ref_eltwise_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
112
113 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
114 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
115
116 const float alpha = pd()->desc()->alpha;
117 const float beta = pd()->desc()->beta;
118
119 const auto &conf = pd()->conf;
120
121 compute::kernel_arg_list_t arg_list;
122 arg_list.set(0, src);
123 arg_list.set(1, dst);
124 arg_list.set(2, alpha);
125 arg_list.set(3, beta);
126
127 append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_);
128
129 auto nd_range = conf.dispatch.nd_range();
130 return parallel_for(ctx, nd_range, kernel_, arg_list);
131}
132
133status_t ref_eltwise_bwd_t::pd_t::init_conf(engine_t *engine) {
134 return init_conf_common(conf, off, this, engine);
135}
136
137status_t ref_eltwise_bwd_t::pd_t::init_kernel_ctx(
138 compute::kernel_ctx_t &kernel_ctx) const {
139 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
140}
141
142status_t ref_eltwise_bwd_t::execute_backward_dense(
143 const exec_ctx_t &ctx) const {
144
145 auto &src = pd()->use_dst() ? CTX_IN_STORAGE(DNNL_ARG_DST)
146 : CTX_IN_STORAGE(DNNL_ARG_SRC);
147 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
148 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
149
150 const float alpha = pd()->desc()->alpha;
151 const float beta = pd()->desc()->beta;
152
153 const auto &conf = pd()->conf;
154
155 compute::kernel_arg_list_t arg_list;
156 arg_list.set(0, src);
157 arg_list.set(1, diff_src);
158 arg_list.set(2, diff_dst);
159 arg_list.set(3, alpha);
160 arg_list.set(4, beta);
161
162 auto nd_range = conf.dispatch.nd_range();
163 return parallel_for(ctx, nd_range, kernel_, arg_list);
164}
165
166} // namespace ocl
167} // namespace gpu
168} // namespace impl
169} // namespace dnnl
170