1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_eltwise.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | static status_t init_conf_common(eltwise_conf_t &conf, offsets_t &off, |
25 | const eltwise_pd_t *pd, engine_t *engine) { |
26 | alg_kind_t alg = pd->desc()->alg_kind; |
27 | const bool is_forward = pd->is_fwd(); |
28 | const auto &src_md = pd->use_dst() ? pd->dst_md() : pd->src_md(); |
29 | const memory_desc_wrapper src_d(src_md); |
30 | const memory_desc_wrapper diff_data_d( |
31 | is_forward ? &glob_zero_md : pd->diff_src_md()); |
32 | |
33 | conf.data_md_info = memory_desc_info_t::create(src_d); |
34 | if (!is_forward) |
35 | conf.data_diff_md_info = memory_desc_info_t::create(diff_data_d); |
36 | |
37 | const int ndims = src_d.ndims(); |
38 | conf.ndims = ndims; |
39 | |
40 | conf.data_type = src_d.data_type(); |
41 | conf.alg = alg; |
42 | conf.is_forward = is_forward; |
43 | conf.attr_info = attr_info_t::create(pd->attr()); |
44 | conf.sub_group_size = 32; |
45 | |
46 | set_offsets(src_d, off.src_off); |
47 | set_offsets(diff_data_d, off.dst_off); |
48 | |
49 | const auto &dims = src_d.padded_dims(); |
50 | |
51 | conf.with_zero_padding = src_d.nelems(false) != src_d.nelems(true); |
52 | |
53 | int max_ndims = 6; |
54 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
55 | conf.dispatch = compute_engine->create_dispatch( |
56 | is_forward ? src_d.md_ : diff_data_d.md_); |
57 | for (int i = 0; i < max_ndims; ++i) { |
58 | if (i < ndims) |
59 | conf.dispatch.define_dim(utils::format("D%d" , i), i, dims[i]); |
60 | else |
61 | conf.dispatch.define_dim(utils::format("D%d" , i), 1); |
62 | } |
63 | conf.dispatch.generate(/*generate_lws=*/false); |
64 | |
65 | return status::success; |
66 | } |
67 | |
68 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
69 | const eltwise_conf_t &conf, const offsets_t &off, |
70 | const post_ops_t &post_ops) { |
71 | kernel_ctx.set_data_type(conf.data_type); |
72 | |
73 | def_eltwise_alg_kinds(kernel_ctx); |
74 | |
75 | kernel_ctx.define_int("WITH_ELTWISE" , 1); |
76 | kernel_ctx.define_int("ELTWISE_ALG" , conf.alg); |
77 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
78 | kernel_ctx.define_int("GWS0" , conf.dispatch.nd_range().global_range()[0]); |
79 | kernel_ctx.define_int("GWS1" , conf.dispatch.nd_range().global_range()[1]); |
80 | kernel_ctx.define_int("GWS2" , conf.dispatch.nd_range().global_range()[2]); |
81 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
82 | |
83 | bool with_binary_post_ops |
84 | = post_ops.find(primitive_kind_t::dnnl_binary) != -1; |
85 | kernel_ctx.define_int( |
86 | "USE_GWS_GET" , conf.with_zero_padding || with_binary_post_ops); |
87 | |
88 | def_memory_desc_info(kernel_ctx, conf.data_md_info, "DATA" ); |
89 | |
90 | if (!conf.is_forward) { |
91 | def_memory_desc_info(kernel_ctx, conf.data_diff_md_info, "DIFF_DATA" ); |
92 | } else { |
93 | kernel_ctx.define_int("IS_FWD" , 1); |
94 | } |
95 | |
96 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
97 | def_dispatch(kernel_ctx, conf.dispatch); |
98 | |
99 | return status::success; |
100 | } |
101 | |
102 | status_t ref_eltwise_fwd_t::pd_t::init_conf(engine_t *engine) { |
103 | return init_conf_common(conf, off, this, engine); |
104 | } |
105 | |
106 | status_t ref_eltwise_fwd_t::pd_t::init_kernel_ctx( |
107 | compute::kernel_ctx_t &kernel_ctx) const { |
108 | return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_); |
109 | } |
110 | |
111 | status_t ref_eltwise_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const { |
112 | |
113 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
114 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
115 | |
116 | const float alpha = pd()->desc()->alpha; |
117 | const float beta = pd()->desc()->beta; |
118 | |
119 | const auto &conf = pd()->conf; |
120 | |
121 | compute::kernel_arg_list_t arg_list; |
122 | arg_list.set(0, src); |
123 | arg_list.set(1, dst); |
124 | arg_list.set(2, alpha); |
125 | arg_list.set(3, beta); |
126 | |
127 | append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_); |
128 | |
129 | auto nd_range = conf.dispatch.nd_range(); |
130 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
131 | } |
132 | |
133 | status_t ref_eltwise_bwd_t::pd_t::init_conf(engine_t *engine) { |
134 | return init_conf_common(conf, off, this, engine); |
135 | } |
136 | |
137 | status_t ref_eltwise_bwd_t::pd_t::init_kernel_ctx( |
138 | compute::kernel_ctx_t &kernel_ctx) const { |
139 | return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_); |
140 | } |
141 | |
142 | status_t ref_eltwise_bwd_t::execute_backward_dense( |
143 | const exec_ctx_t &ctx) const { |
144 | |
145 | auto &src = pd()->use_dst() ? CTX_IN_STORAGE(DNNL_ARG_DST) |
146 | : CTX_IN_STORAGE(DNNL_ARG_SRC); |
147 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
148 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
149 | |
150 | const float alpha = pd()->desc()->alpha; |
151 | const float beta = pd()->desc()->beta; |
152 | |
153 | const auto &conf = pd()->conf; |
154 | |
155 | compute::kernel_arg_list_t arg_list; |
156 | arg_list.set(0, src); |
157 | arg_list.set(1, diff_src); |
158 | arg_list.set(2, diff_dst); |
159 | arg_list.set(3, alpha); |
160 | arg_list.set(4, beta); |
161 | |
162 | auto nd_range = conf.dispatch.nd_range(); |
163 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
164 | } |
165 | |
166 | } // namespace ocl |
167 | } // namespace gpu |
168 | } // namespace impl |
169 | } // namespace dnnl |
170 | |