1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | |
19 | #include "common/primitive_exec_types.hpp" |
20 | |
21 | #include "gpu/ocl/ocl_utils.hpp" |
22 | #include "gpu/ocl/ref_reduction.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | status_t ref_reduction_t::pd_t::init_conf(engine_t *engine) { |
30 | const reduction_pd_t *pd = this; |
31 | |
32 | const memory_desc_wrapper src_mdw(pd->src_md()); |
33 | const memory_desc_wrapper dst_mdw(pd->dst_md()); |
34 | |
35 | const int ndims = src_mdw.ndims(); |
36 | const auto src_dims = src_mdw.md_->dims; |
37 | const auto dst_dims = dst_mdw.md_->dims; |
38 | const auto *compute_engine |
39 | = utils::downcast<compute::compute_engine_t *>(engine); |
40 | |
41 | conf.alg = pd->desc()->alg_kind; |
42 | conf.src_md_info = memory_desc_info_t::create(src_mdw); |
43 | conf.dst_md_info = memory_desc_info_t::create(dst_mdw); |
44 | conf.dst_type = dst_mdw.data_type(); |
45 | conf.src_type = src_mdw.data_type(); |
46 | conf.ndims = ndims; |
47 | conf.power = pd->desc()->p; |
48 | conf.eps = pd->desc()->eps; |
49 | conf.dispatch = compute_engine->create_dispatch(src_mdw.md_); |
50 | conf.div = 1; |
51 | |
52 | for (int d = 0; d < ndims; d++) { |
53 | conf.reduce_dims[d] = conf.dst_dims[d] = dim_t {1}; |
54 | const bool is_reduction_dim = src_dims[d] != dst_dims[d]; |
55 | conf.is_reduction_dim[d] = is_reduction_dim; |
56 | |
57 | if (is_reduction_dim) { |
58 | conf.reduce_dims[d] = src_dims[d]; |
59 | conf.div *= conf.reduce_dims[d]; |
60 | } |
61 | conf.dst_dims[d] = dst_mdw.md_->padded_dims[d]; |
62 | } |
63 | |
64 | conf.dispatch.define_dim("D0" , 0, conf.dst_dims[0]); |
65 | conf.dispatch.define_dim("D1" , 0, ndims >= 2 ? conf.dst_dims[1] : 1); |
66 | conf.dispatch.define_dim( |
67 | "D2" , 0, ndims >= 6 ? conf.dst_dims[ndims - 4] : 1); |
68 | conf.dispatch.define_dim( |
69 | "D3" , 0, ndims >= 5 ? conf.dst_dims[ndims - 3] : 1); |
70 | conf.dispatch.define_dim( |
71 | "D4" , 0, ndims >= 4 ? conf.dst_dims[ndims - 2] : 1); |
72 | conf.dispatch.define_dim( |
73 | "D5" , 0, ndims >= 3 ? conf.dst_dims[ndims - 1] : 1); |
74 | conf.dispatch.generate(false); |
75 | |
76 | conf.attr_info = attr_info_t::create(pd->attr()); |
77 | set_offsets(src_mdw, conf.off.src_off); |
78 | set_offsets(dst_mdw, conf.off.dst_off); |
79 | |
80 | return status::success; |
81 | } |
82 | |
83 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
84 | const reduction_conf_t &conf, const post_ops_t &post_ops) { |
85 | using namespace alg_kind; |
86 | |
87 | kernel_ctx.set_data_type(conf.src_type); |
88 | |
89 | kernel_ctx.define_int("D0" , conf.dst_dims[0]); |
90 | kernel_ctx.define_int("D1" , conf.ndims >= 2 ? conf.dst_dims[1] : 1); |
91 | kernel_ctx.define_int( |
92 | "D2" , conf.ndims >= 6 ? conf.dst_dims[conf.ndims - 4] : 1); |
93 | kernel_ctx.define_int( |
94 | "D3" , conf.ndims >= 5 ? conf.dst_dims[conf.ndims - 3] : 1); |
95 | kernel_ctx.define_int( |
96 | "D4" , conf.ndims >= 4 ? conf.dst_dims[conf.ndims - 2] : 1); |
97 | kernel_ctx.define_int( |
98 | "D5" , conf.ndims >= 3 ? conf.dst_dims[conf.ndims - 1] : 1); |
99 | |
100 | kernel_ctx.define_int("REDUCTION_D0" , conf.reduce_dims[0]); |
101 | kernel_ctx.define_int( |
102 | "REDUCTION_D1" , conf.ndims >= 2 ? conf.reduce_dims[1] : 1); |
103 | kernel_ctx.define_int("REDUCTION_D2" , |
104 | conf.ndims >= 6 ? conf.reduce_dims[conf.ndims - 4] : 1); |
105 | kernel_ctx.define_int("REDUCTION_D3" , |
106 | conf.ndims >= 5 ? conf.reduce_dims[conf.ndims - 3] : 1); |
107 | kernel_ctx.define_int("REDUCTION_D4" , |
108 | conf.ndims >= 4 ? conf.reduce_dims[conf.ndims - 2] : 1); |
109 | kernel_ctx.define_int("REDUCTION_D5" , |
110 | conf.ndims >= 3 ? conf.reduce_dims[conf.ndims - 1] : 1); |
111 | |
112 | switch (conf.alg) { |
113 | case reduction_max: kernel_ctx.define_int("IS_MAX" , 1); break; |
114 | case reduction_min: kernel_ctx.define_int("IS_MIN" , 1); break; |
115 | case reduction_mean: kernel_ctx.define_int("IS_MEAN" , 1); break; |
116 | case reduction_sum: kernel_ctx.define_int("IS_SUM" , 1); break; |
117 | case reduction_mul: kernel_ctx.define_int("IS_MUL" , 1); break; |
118 | case reduction_norm_lp_max: |
119 | kernel_ctx.define_int("IS_LP_MAX" , 1); |
120 | break; |
121 | case reduction_norm_lp_sum: |
122 | kernel_ctx.define_int("IS_LP_SUM" , 1); |
123 | break; |
124 | case reduction_norm_lp_power_p_max: |
125 | kernel_ctx.define_int("IS_P_MAX" , 1); |
126 | break; |
127 | case reduction_norm_lp_power_p_sum: |
128 | kernel_ctx.define_int("IS_P_SUM" , 1); |
129 | break; |
130 | default: return status::invalid_arguments; |
131 | } |
132 | |
133 | def_offsets(conf.off.src_off, kernel_ctx, "SRC" , conf.ndims); |
134 | def_offsets(conf.off.dst_off, kernel_ctx, "DST" , conf.ndims); |
135 | |
136 | kernel_ctx.define_int("DIV" , conf.div); |
137 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
138 | kernel_ctx.define_int("POWER" , conf.power); |
139 | kernel_ctx.define_float("EPS" , conf.eps); |
140 | |
141 | def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC" ); |
142 | def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" ); |
143 | |
144 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
145 | |
146 | def_dispatch(kernel_ctx, conf.dispatch); |
147 | |
148 | return status::success; |
149 | } |
150 | |
151 | status_t ref_reduction_t::pd_t::init_kernel_ctx( |
152 | compute::kernel_ctx_t &kernel_ctx) const { |
153 | return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_); |
154 | } |
155 | |
156 | status_t ref_reduction_t::execute_ref(const exec_ctx_t &ctx) const { |
157 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
158 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
159 | |
160 | const auto &conf = pd()->conf; |
161 | |
162 | compute::kernel_arg_list_t reduction_arg_list; |
163 | |
164 | reduction_arg_list.set(0, src); |
165 | reduction_arg_list.set(1, dst); |
166 | append_post_ops_to_arg_list( |
167 | ctx, reduction_arg_list, 2, pd()->attr()->post_ops_); |
168 | |
169 | auto nd_range = conf.dispatch.nd_range(); |
170 | |
171 | return parallel_for(ctx, nd_range, kernel, reduction_arg_list); |
172 | } |
173 | |
174 | } // namespace ocl |
175 | } // namespace gpu |
176 | } // namespace impl |
177 | } // namespace dnnl |
178 | |