1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <math.h>
18
19#include "common/primitive_exec_types.hpp"
20
21#include "gpu/ocl/ocl_utils.hpp"
22#include "gpu/ocl/ref_reduction.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace ocl {
28
29status_t ref_reduction_t::pd_t::init_conf(engine_t *engine) {
30 const reduction_pd_t *pd = this;
31
32 const memory_desc_wrapper src_mdw(pd->src_md());
33 const memory_desc_wrapper dst_mdw(pd->dst_md());
34
35 const int ndims = src_mdw.ndims();
36 const auto src_dims = src_mdw.md_->dims;
37 const auto dst_dims = dst_mdw.md_->dims;
38 const auto *compute_engine
39 = utils::downcast<compute::compute_engine_t *>(engine);
40
41 conf.alg = pd->desc()->alg_kind;
42 conf.src_md_info = memory_desc_info_t::create(src_mdw);
43 conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
44 conf.dst_type = dst_mdw.data_type();
45 conf.src_type = src_mdw.data_type();
46 conf.ndims = ndims;
47 conf.power = pd->desc()->p;
48 conf.eps = pd->desc()->eps;
49 conf.dispatch = compute_engine->create_dispatch(src_mdw.md_);
50 conf.div = 1;
51
52 for (int d = 0; d < ndims; d++) {
53 conf.reduce_dims[d] = conf.dst_dims[d] = dim_t {1};
54 const bool is_reduction_dim = src_dims[d] != dst_dims[d];
55 conf.is_reduction_dim[d] = is_reduction_dim;
56
57 if (is_reduction_dim) {
58 conf.reduce_dims[d] = src_dims[d];
59 conf.div *= conf.reduce_dims[d];
60 }
61 conf.dst_dims[d] = dst_mdw.md_->padded_dims[d];
62 }
63
64 conf.dispatch.define_dim("D0", 0, conf.dst_dims[0]);
65 conf.dispatch.define_dim("D1", 0, ndims >= 2 ? conf.dst_dims[1] : 1);
66 conf.dispatch.define_dim(
67 "D2", 0, ndims >= 6 ? conf.dst_dims[ndims - 4] : 1);
68 conf.dispatch.define_dim(
69 "D3", 0, ndims >= 5 ? conf.dst_dims[ndims - 3] : 1);
70 conf.dispatch.define_dim(
71 "D4", 0, ndims >= 4 ? conf.dst_dims[ndims - 2] : 1);
72 conf.dispatch.define_dim(
73 "D5", 0, ndims >= 3 ? conf.dst_dims[ndims - 1] : 1);
74 conf.dispatch.generate(false);
75
76 conf.attr_info = attr_info_t::create(pd->attr());
77 set_offsets(src_mdw, conf.off.src_off);
78 set_offsets(dst_mdw, conf.off.dst_off);
79
80 return status::success;
81}
82
83static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
84 const reduction_conf_t &conf, const post_ops_t &post_ops) {
85 using namespace alg_kind;
86
87 kernel_ctx.set_data_type(conf.src_type);
88
89 kernel_ctx.define_int("D0", conf.dst_dims[0]);
90 kernel_ctx.define_int("D1", conf.ndims >= 2 ? conf.dst_dims[1] : 1);
91 kernel_ctx.define_int(
92 "D2", conf.ndims >= 6 ? conf.dst_dims[conf.ndims - 4] : 1);
93 kernel_ctx.define_int(
94 "D3", conf.ndims >= 5 ? conf.dst_dims[conf.ndims - 3] : 1);
95 kernel_ctx.define_int(
96 "D4", conf.ndims >= 4 ? conf.dst_dims[conf.ndims - 2] : 1);
97 kernel_ctx.define_int(
98 "D5", conf.ndims >= 3 ? conf.dst_dims[conf.ndims - 1] : 1);
99
100 kernel_ctx.define_int("REDUCTION_D0", conf.reduce_dims[0]);
101 kernel_ctx.define_int(
102 "REDUCTION_D1", conf.ndims >= 2 ? conf.reduce_dims[1] : 1);
103 kernel_ctx.define_int("REDUCTION_D2",
104 conf.ndims >= 6 ? conf.reduce_dims[conf.ndims - 4] : 1);
105 kernel_ctx.define_int("REDUCTION_D3",
106 conf.ndims >= 5 ? conf.reduce_dims[conf.ndims - 3] : 1);
107 kernel_ctx.define_int("REDUCTION_D4",
108 conf.ndims >= 4 ? conf.reduce_dims[conf.ndims - 2] : 1);
109 kernel_ctx.define_int("REDUCTION_D5",
110 conf.ndims >= 3 ? conf.reduce_dims[conf.ndims - 1] : 1);
111
112 switch (conf.alg) {
113 case reduction_max: kernel_ctx.define_int("IS_MAX", 1); break;
114 case reduction_min: kernel_ctx.define_int("IS_MIN", 1); break;
115 case reduction_mean: kernel_ctx.define_int("IS_MEAN", 1); break;
116 case reduction_sum: kernel_ctx.define_int("IS_SUM", 1); break;
117 case reduction_mul: kernel_ctx.define_int("IS_MUL", 1); break;
118 case reduction_norm_lp_max:
119 kernel_ctx.define_int("IS_LP_MAX", 1);
120 break;
121 case reduction_norm_lp_sum:
122 kernel_ctx.define_int("IS_LP_SUM", 1);
123 break;
124 case reduction_norm_lp_power_p_max:
125 kernel_ctx.define_int("IS_P_MAX", 1);
126 break;
127 case reduction_norm_lp_power_p_sum:
128 kernel_ctx.define_int("IS_P_SUM", 1);
129 break;
130 default: return status::invalid_arguments;
131 }
132
133 def_offsets(conf.off.src_off, kernel_ctx, "SRC", conf.ndims);
134 def_offsets(conf.off.dst_off, kernel_ctx, "DST", conf.ndims);
135
136 kernel_ctx.define_int("DIV", conf.div);
137 kernel_ctx.define_int("NDIMS", conf.ndims);
138 kernel_ctx.define_int("POWER", conf.power);
139 kernel_ctx.define_float("EPS", conf.eps);
140
141 def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
142 def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
143
144 def_attr_info(kernel_ctx, conf.attr_info, post_ops);
145
146 def_dispatch(kernel_ctx, conf.dispatch);
147
148 return status::success;
149}
150
151status_t ref_reduction_t::pd_t::init_kernel_ctx(
152 compute::kernel_ctx_t &kernel_ctx) const {
153 return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_);
154}
155
156status_t ref_reduction_t::execute_ref(const exec_ctx_t &ctx) const {
157 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
158 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
159
160 const auto &conf = pd()->conf;
161
162 compute::kernel_arg_list_t reduction_arg_list;
163
164 reduction_arg_list.set(0, src);
165 reduction_arg_list.set(1, dst);
166 append_post_ops_to_arg_list(
167 ctx, reduction_arg_list, 2, pd()->attr()->post_ops_);
168
169 auto nd_range = conf.dispatch.nd_range();
170
171 return parallel_for(ctx, nd_range, kernel, reduction_arg_list);
172}
173
174} // namespace ocl
175} // namespace gpu
176} // namespace impl
177} // namespace dnnl
178