1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stddef.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | |
21 | #include <sstream> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "utils/parallel.hpp" |
26 | |
27 | #include "dnnl_common.hpp" |
28 | #include "dnnl_memory.hpp" |
29 | |
30 | #include "lrn/lrn.hpp" |
31 | |
32 | namespace lrn { |
33 | |
34 | int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt, |
35 | dnn_mem_t &mem_fp) { |
36 | const auto nelems = mem_fp.nelems(); |
37 | const int range = 16; |
38 | // LRN in MIOpen 2.17 and older doesn't support negative input. The support |
39 | // was added in https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1562. |
40 | // The plan is to use only positive input at this point but bump the |
41 | // minimum required MIOpen version to 2.18 once it's released and enable |
42 | // negative input back. |
43 | const int f_min |
44 | = prb->dt == dnnl_u8 ? 0 : (is_amd_gpu() ? range : -range) / 2; |
45 | |
46 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
47 | const int64_t gen = kind == SRC ? 1091 * i + 1637 : 1279 * i + 1009; |
48 | const float value = f_min + gen % range; |
49 | mem_fp.set_elem(i, value); |
50 | }); |
51 | |
52 | SAFE(mem_dt.reorder(mem_fp), WARN); |
53 | |
54 | return OK; |
55 | } |
56 | |
57 | int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
58 | return fill_dat(prb, SRC, mem_dt, mem_fp); |
59 | } |
60 | |
61 | int fill_dst(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
62 | return fill_dat(prb, DST, mem_dt, mem_fp); |
63 | } |
64 | |
65 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
66 | const prb_t *prb = init_pd_args.prb; |
67 | const dir_t dir = init_pd_args.dir; |
68 | |
69 | dnnl_dims_t data_dims_0d = {prb->mb, prb->ic}; |
70 | dnnl_dims_t data_dims_1d = {prb->mb, prb->ic, prb->iw}; |
71 | dnnl_dims_t data_dims_2d = {prb->mb, prb->ic, prb->ih, prb->iw}; |
72 | dnnl_dims_t data_dims_3d = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw}; |
73 | |
74 | dnnl_dim_t *data_dims = prb->ndims == 5 |
75 | ? data_dims_3d |
76 | : prb->ndims == 4 ? data_dims_2d |
77 | : prb->ndims == 3 ? data_dims_1d : data_dims_0d; |
78 | |
79 | auto src_d = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, prb->tag); |
80 | auto dst_d = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any); |
81 | |
82 | dnnl_alg_kind_t alg = alg2alg_kind(prb->alg); |
83 | |
84 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
85 | create_dnnl_attr(prb->attr, attr_args_t())); |
86 | |
87 | if (dir & FLAG_FWD) { |
88 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
89 | : dnnl_forward_training; |
90 | DNN_SAFE_STATUS(dnnl_lrn_forward_primitive_desc_create(&init_pd_args.pd, |
91 | init_pd_args.engine, prop, alg, src_d, dst_d, prb->ls, |
92 | prb->alpha, prb->beta, prb->k, dnnl_attr)); |
93 | } else { |
94 | auto diff_src_d |
95 | = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any); |
96 | auto diff_dst_d |
97 | = dnn_mem_t::init_md(prb->ndims, data_dims, prb->dt, tag::any); |
98 | DNN_SAFE_STATUS(dnnl_lrn_backward_primitive_desc_create( |
99 | &init_pd_args.pd, init_pd_args.engine, alg, diff_src_d, |
100 | diff_dst_d, src_d, prb->ls, prb->alpha, prb->beta, prb->k, |
101 | init_pd_args.hint, dnnl_attr)); |
102 | } |
103 | |
104 | return dnnl_success; |
105 | } |
106 | |
107 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
108 | skip_unimplemented_data_type({prb->dt}, prb->dir, res); |
109 | skip_unimplemented_sum_po(prb->attr, res); |
110 | } |
111 | |
112 | void skip_invalid_prb(const prb_t *prb, res_t *res) {} |
113 | |
114 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
115 | const args_t &ref_args) { |
116 | // `3` is a const needed to adjust division error |
117 | cmp.set_threshold(compute_n_summands(prb) * 3 * epsilon_dt(prb->dt)); |
118 | } |
119 | |
120 | int doit(const prb_t *prb, res_t *res) { |
121 | if (bench_mode == LIST) return res->state = LISTED, OK; |
122 | |
123 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
124 | bool is_service_prim = prb->dir & FLAG_BWD; |
125 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr, |
126 | is_service_prim), |
127 | WARN); |
128 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
129 | |
130 | auto const_fpd = query_pd(prim); |
131 | |
132 | const auto &data_md = query_md(const_fpd, DNNL_ARG_SRC); |
133 | const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE); |
134 | const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD); |
135 | |
136 | const auto fp = dnnl_f32; |
137 | const auto tag = tag::abx; |
138 | |
139 | const auto &test_engine = get_test_engine(); |
140 | const auto &ref_engine = get_cpu_engine(); |
141 | |
142 | dnn_mem_t src_fp(data_md, fp, tag, ref_engine); |
143 | dnn_mem_t src_dt(data_md, test_engine); |
144 | |
145 | dnn_mem_t dst_fp(data_md, fp, tag, ref_engine); |
146 | dnn_mem_t dst_dt(data_md, test_engine); |
147 | |
148 | dnn_mem_t ws_fp(ws_md, ref_engine); |
149 | dnn_mem_t ws_dt(ws_md, test_engine); |
150 | if (prb->dir & FLAG_INF) SAFE(ws_dt.ndims() == 0 ? OK : FAIL, WARN); |
151 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
152 | |
153 | dnn_mem_t d_dst_dt, d_src_dt; |
154 | |
155 | SAFE(fill_src(prb, src_dt, src_fp), WARN); |
156 | |
157 | args_t args, ref_args; |
158 | |
159 | args.set(DNNL_ARG_SRC, src_dt); |
160 | args.set(DNNL_ARG_DST, dst_dt); |
161 | args.set(DNNL_ARG_WORKSPACE, ws_dt); |
162 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
163 | |
164 | SAFE(execute_and_wait(prim, args, res), WARN); |
165 | |
166 | if (prb->dir & FLAG_FWD) { |
167 | if (is_bench_mode(CORR)) { |
168 | ref_args.set(DNNL_ARG_SRC, src_fp); |
169 | ref_args.set(DNNL_ARG_DST, dst_fp); |
170 | |
171 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
172 | } |
173 | } |
174 | |
175 | if (prb->dir & FLAG_BWD) { |
176 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim; |
177 | SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN); |
178 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
179 | prim.reset(tmp_prim.release()); |
180 | |
181 | auto const_bpd = query_pd(prim); |
182 | |
183 | const auto &d_data_md = query_md(const_bpd, DNNL_ARG_DIFF_DST); |
184 | const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD); |
185 | |
186 | dnn_mem_t d_dst_fp(d_data_md, fp, tag, ref_engine); |
187 | d_dst_dt = dnn_mem_t(d_data_md, test_engine); |
188 | |
189 | dnn_mem_t d_src_fp(d_data_md, fp, tag, ref_engine); |
190 | d_src_dt = dnn_mem_t(d_data_md, test_engine); |
191 | |
192 | scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine); |
193 | |
194 | SAFE(fill_dst(prb, d_dst_dt, d_dst_fp), WARN); |
195 | |
196 | args.clear(); |
197 | args.set(DNNL_ARG_SRC, src_dt); |
198 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
199 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
200 | args.set(DNNL_ARG_WORKSPACE, ws_dt); |
201 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
202 | |
203 | SAFE(execute_and_wait(prim, args, res), WARN); |
204 | |
205 | if (is_bench_mode(CORR)) { |
206 | ref_args.set(DNNL_ARG_SRC, src_fp); |
207 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
208 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
209 | |
210 | check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res); |
211 | } |
212 | } |
213 | |
214 | return measure_perf(prb->ctx_exe, res, prim, args); |
215 | } |
216 | |
217 | } // namespace lrn |
218 | |