/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "lrn/lrn.hpp"

namespace lrn {

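// Fills the f32 reference buffer `mem_fp` with a small, deterministic
// pseudo-random pattern (SRC and DST use different generator constants so the
// two fillings differ), then reorders the result into the library memory
// object `mem_dt`.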
int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    const int range = 16;
    // LRN in MIOpen 2.17 and older doesn't support negative input; support
    // was added in https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1562.
    // For now, use only positive input on AMD GPUs; once MIOpen 2.18 is
    // released, bump the minimum required version and re-enable negative
    // input.
    const int f_min
            = prb->dt == dnnl_u8 ? 0 : (is_amd_gpu() ? range : -range) / 2;
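    // With `range` == 16 the generated values land in [f_min, f_min + 15]:
    // [0, 15] for u8, [8, 23] on AMD GPUs, and [-8, 7] otherwise.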

    benchdnn_parallel_nd(nelems, [&](int64_t i) {
        const int64_t gen = kind == SRC ? 1091 * i + 1637 : 1279 * i + 1009;
        const float value = f_min + gen % range;
        mem_fp.set_elem(i, value);
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

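// Thin wrappers over fill_dat for the source and (diff_)destination tensors.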
int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    return fill_dat(prb, SRC, mem_dt, mem_fp);
}

int fill_dst(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    return fill_dat(prb, DST, mem_dt, mem_fp);
}

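// Builds the LRN primitive descriptor for the requested direction. The data
// dimensions are chosen from the problem's ndims; destination and diff memory
// descriptors use tag::any so the library is free to pick the layout, while
// the source layout follows the user-provided tag.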
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    dnnl_dims_t data_dims_0d = {prb->mb, prb->ic};
    dnnl_dims_t data_dims_1d = {prb->mb, prb->ic, prb->iw};
    dnnl_dims_t data_dims_2d = {prb->mb, prb->ic, prb->ih, prb->iw};
    dnnl_dims_t data_dims_3d = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw};

    dnnl_dim_t *data_dims = prb->ndims == 5
            ? data_dims_3d
            : prb->ndims == 4 ? data_dims_2d
                              : prb->ndims == 3 ? data_dims_1d : data_dims_0d;

    auto src_d = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, prb->tag);
    auto dst_d = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any);

    dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    if (dir & FLAG_FWD) {
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;
        DNN_SAFE_STATUS(dnnl_lrn_forward_primitive_desc_create(&init_pd_args.pd,
                init_pd_args.engine, prop, alg, src_d, dst_d, prb->ls,
                prb->alpha, prb->beta, prb->k, dnnl_attr));
    } else {
        auto diff_src_d
                = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any);
        auto diff_dst_d
                = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any);
        DNN_SAFE_STATUS(dnnl_lrn_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, alg, diff_src_d,
                diff_dst_d, src_d, prb->ls, prb->alpha, prb->beta, prb->k,
                init_pd_args.hint, dnnl_attr));
    }

    return dnnl_success;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->dt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    // The factor of `3` adjusts the threshold for the error introduced by the
    // division.
    cmp.set_threshold(compute_n_summands(prb) * 3 * epsilon_dt(prb->dt));
}

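// Drives a single LRN test case: create the forward primitive (as a service
// primitive when only the backward pass is requested), fill the source,
// execute it and, in correctness mode, compare DST against the f32 reference.
// For backward problems the primitive is then re-created from the backward
// descriptor (hinted by the forward one), diff_dst is filled and DIFF_SRC is
// validated. Performance is measured last for whichever primitive ran.
//
// A representative benchdnn invocation exercising this driver (the flags and
// the problem descriptor below are illustrative, not exhaustive) might be:
//   ./benchdnn --lrn --dir=BWD_D --dt=f32 --tag=nchw --alg=ACROSS \
//       mb2ic16ih10iw10ls5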
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    // For backward problems the forward primitive only exists to produce the
    // workspace, so it is created as a service primitive.
    bool is_service_prim = prb->dir & FLAG_BWD;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    const auto &data_md = query_md(const_fpd, DNNL_ARG_SRC);
    const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(data_md, fp, tag, ref_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    dnn_mem_t dst_fp(data_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(data_md, test_engine);

    dnn_mem_t ws_fp(ws_md, ref_engine);
    dnn_mem_t ws_dt(ws_md, test_engine);
    // Inference must not require a workspace.
    if (prb->dir & FLAG_INF) SAFE(ws_dt.ndims() == 0 ? OK : FAIL, WARN);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, d_src_dt;

    SAFE(fill_src(prb, src_dt, src_fp), WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_WORKSPACE, ws_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (prb->dir & FLAG_FWD) {
        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    }

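    // Backward pass: re-create the primitive from the backward descriptor
    // (using the forward pd as a hint), reuse the forward src and workspace,
    // fill diff_dst and validate diff_src against the reference.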
    if (prb->dir & FLAG_BWD) {
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &d_data_md = query_md(const_bpd, DNNL_ARG_DIFF_DST);
        const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        dnn_mem_t d_dst_fp(d_data_md, fp, tag, ref_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t d_src_fp(d_data_md, fp, tag, ref_engine);
        d_src_dt = dnn_mem_t(d_data_md, test_engine);

        scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine);

        SAFE(fill_dst(prb, d_dst_dt, d_dst_fp), WARN);

        args.clear();
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_WORKSPACE, ws_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace lrn