1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
27struct lrn_test_params_t {
28 dt src_dt;
29 dt dst_dt; // diff_dst_dt
30 dt diff_src_dt;
31 tag src_tag;
32 tag dst_tag; // diff_dst_tag
33 tag diff_src_tag;
34 memory::dims dims;
35 memory::dim local_size;
36 float alpha, beta, k;
37 bool expect_to_fail;
38 dnnl_status_t expected_status;
39};
40
41bool cuda_check_format_tag(tag atag) {
42 return impl::utils::one_of(atag, tag::ncdhw, tag::nchw, tag::nhwc, tag::ncw,
43 tag::nwc, tag::any);
44}
45
46template <typename... Rest>
47bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) {
48 const bool ok = cuda_check_format_tag(first_tag);
49 if (!ok) return ok;
50 return cuda_check_format_tag(rest_tags...);
51}
52
53bool hip_check_format_tag(tag atag) {
54 return impl::utils::one_of(atag, tag::nchw, tag::any);
55}
56
57template <typename... Rest>
58bool hip_check_format_tag(tag first_tag, Rest... rest_tags) {
59 const bool ok = hip_check_format_tag(first_tag);
60 if (!ok) return ok;
61 return hip_check_format_tag(rest_tags...);
62}
63
64class lrn_test_t : public ::testing::TestWithParam<lrn_test_params_t> {
65private:
66 lrn_test_params_t p;
67 memory src, workspace;
68 std::shared_ptr<lrn_forward::primitive_desc> pd_fwd_hint;
69
70protected:
71 void SetUp() override {
72 p = ::testing::TestWithParam<lrn_test_params_t>::GetParam();
73
74 SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt),
75 "Engine does not support this data type.");
76
77 SKIP_IF_CUDA(
78 p.dst_dt == dt::s8, "Unsupported int8 destination data type");
79 SKIP_IF_HIP(
80 p.dst_dt == dt::s8, "Unsupported int8 destination data type");
81
82 SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag),
83 "Unsupported format tag");
84 SKIP_IF_HIP(!hip_check_format_tag(p.src_tag, p.dst_tag),
85 "Unsupported format tag");
86
87 SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef
88 && p.dst_dt != dt::undef,
89 "Unsupported different data types for source and "
90 "destination");
91 SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef
92 && p.dst_dt != dt::undef,
93 "Unsupported different data types for source and "
94 "destination");
95
96 SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any
97 && p.dst_tag != tag::any,
98 "Unsupported different memory formats for source and "
99 "destination");
100 SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any
101 && p.dst_tag != tag::any,
102 "Unsupported different memory formats for source and "
103 "destination");
104
105 catch_expected_failures(
106 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
107 }
108
109 void Forward(prop_kind pk, algorithm aalgorithm) {
110 // lrn specific types and values
111 using pd_t = lrn_forward::primitive_desc;
112
113 auto eng = get_test_engine();
114 auto strm = make_stream(eng);
115
116 auto aa = allows_attr_t {false};
117
118 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
119 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
120
121 // default pd ctor
122 auto pd = pd_t();
123 // regular pd ctor
124 pd = pd_t(eng, pk, aalgorithm, src_md, dst_md, p.local_size, p.alpha,
125 p.beta, p.k);
126 // test all pd ctors
127 test_fwd_pd_constructors<pd_t>(pd, aa, pk, aalgorithm, src_md, dst_md,
128 p.local_size, p.alpha, p.beta, p.k);
129 pd_fwd_hint = std::make_shared<pd_t>(pd);
130
131 EXPECT_ANY_THROW(lrn_forward(pd, {}));
132 // default primitive ctor
133 auto lrn = lrn_forward();
134 // regular primitive ctor
135 lrn = lrn_forward(pd);
136
137 // check primitive kind is lrn
138 ASSERT_TRUE(lrn.get_kind() == primitive::kind::lrn);
139 // query for descs from pd
140 const auto src_desc = pd.src_desc();
141 const auto dst_desc = pd.dst_desc();
142 // query for src_desc via exec arg
143 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
144 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
145 // query for dst_desc via exec arg
146 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
147 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
148
149 // query primitive parameters
150 ASSERT_EQ(pd.get_prop_kind(), pk);
151 ASSERT_EQ(pd.get_algorithm(), aalgorithm);
152 ASSERT_EQ(pd.get_local_size(), p.local_size);
153 ASSERT_EQ(pd.get_alpha(), p.alpha);
154 ASSERT_EQ(pd.get_beta(), p.beta);
155 ASSERT_EQ(pd.get_k(), p.k);
156
157 // query for workspace
158 const auto workspace_desc = pd.workspace_desc();
159
160 // check primitive returns zero_md for all rest md
161 ASSERT_TRUE(pd.weights_desc().is_zero());
162 ASSERT_TRUE(pd.diff_src_desc().is_zero());
163 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
164 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
165
166 src = test::make_memory(src_desc, eng);
167 auto dst = test::make_memory(dst_desc, eng);
168 workspace = test::make_memory(workspace_desc, eng);
169
170 fill_data(p.src_dt, src, 1, 1);
171 // test out-place mode
172 lrn.execute(strm,
173 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
174 {DNNL_ARG_WORKSPACE, workspace}});
175 strm.wait();
176 }
177
178 void Backward(algorithm aalgorithm) {
179 // lrn specific types and values
180 using pd_t = lrn_backward::primitive_desc;
181 using hint_pd_t = lrn_forward::primitive_desc;
182 allows_attr_t aa {false}; // doesn't support anything
183
184 auto eng = get_test_engine();
185 auto strm = make_stream(eng);
186 auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag);
187 auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
188 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
189
190 // default pd ctor
191 auto pd = pd_t();
192 // regular pd ctor
193 pd = pd_t(eng, aalgorithm, diff_src_md, diff_dst_md, src_md,
194 p.local_size, p.alpha, p.beta, p.k, *pd_fwd_hint);
195 // test all pd ctors
196 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa,
197 aalgorithm, diff_src_md, diff_dst_md, src_md, p.local_size,
198 p.alpha, p.beta, p.k);
199
200 EXPECT_ANY_THROW(lrn_backward(pd, {}));
201 // default primitive ctor
202 auto lrn = lrn_backward();
203 // regular primitive ctor
204 lrn = lrn_backward(pd);
205
206 // check primitive kind is lrn
207 ASSERT_TRUE(lrn.get_kind() == primitive::kind::lrn);
208
209 // query for descs from pd
210 const auto diff_src_desc = pd.diff_src_desc();
211 const auto diff_dst_desc = pd.diff_dst_desc();
212 const auto src_desc = pd.src_desc();
213 // query for diff_src_desc via exec arg
214 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
215 == diff_src_desc);
216 if (p.diff_src_tag != tag::any) {
217 ASSERT_TRUE(diff_src_md == diff_src_desc);
218 }
219 // query for diff_dst_desc via exec arg
220 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
221 == diff_dst_desc);
222 if (p.dst_tag != tag::any) {
223 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
224 }
225 // query for dst_desc via exec arg
226 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
227 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
228
229 // query primitive parameters
230 ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data);
231 ASSERT_EQ(pd.get_algorithm(), aalgorithm);
232 ASSERT_EQ(pd.get_local_size(), p.local_size);
233 ASSERT_EQ(pd.get_alpha(), p.alpha);
234 ASSERT_EQ(pd.get_beta(), p.beta);
235 ASSERT_EQ(pd.get_k(), p.k);
236
237 // check primitive returns zero_md for all rest md
238 ASSERT_TRUE(pd.weights_desc().is_zero());
239 ASSERT_TRUE(pd.dst_desc().is_zero());
240 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
241
242 auto diff_src = test::make_memory(diff_src_desc, eng);
243 auto diff_dst = test::make_memory(diff_dst_desc, eng);
244
245 fill_data(p.diff_src_dt, diff_dst, 2, 2);
246
247 // test out-place mode
248 lrn.execute(strm,
249 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst},
250 {DNNL_ARG_DIFF_SRC, diff_src},
251 {DNNL_ARG_WORKSPACE, workspace}});
252 strm.wait();
253 }
254
255 void Test() {
256 const std::vector<prop_kind> pks
257 = {prop_kind::forward_training, prop_kind::forward_inference};
258 const std::vector<algorithm> algs = {
259 algorithm::lrn_across_channels, algorithm::lrn_within_channel};
260
261 for_(auto pk : pks)
262 for (auto alg : algs) {
263 SKIP_FOR_LOOP_CUDA(alg != algorithm::lrn_across_channels,
264 "Unsupported algorithm");
265
266 Forward(pk, alg);
267 if (pk == prop_kind::forward_training
268 && p.diff_src_dt != memory::data_type::undef) {
269 SKIP_IF(unsupported_data_type(p.diff_src_dt),
270 "Engine does not support this data type.");
271 SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag),
272 "Unsupported format tag");
273 SKIP_IF_HIP(!hip_check_format_tag(p.diff_src_tag),
274 "Unsupported format tag");
275
276 SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef
277 && p.diff_src_dt != dt::undef,
278 "Unsupported different data types for diff_source and "
279 "diff_destination");
280 SKIP_IF_HIP(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef
281 && p.diff_src_dt != dt::undef,
282 "Unsupported different data types for diff_source and "
283 "diff_destination");
284
285 SKIP_IF_CUDA(p.src_tag != p.diff_src_tag
286 && p.src_tag != tag::any
287 && p.diff_src_tag != tag::any,
288 "Unsupported different memory formats for diff_source "
289 "and diff_destination");
290 SKIP_IF_HIP(p.src_tag != p.diff_src_tag && p.src_tag != tag::any
291 && p.diff_src_tag != tag::any,
292 "Unsupported different memory formats for diff_source "
293 "and diff_destination");
294
295 Backward(alg);
296 }
297 }
298 }
299
300 bool is_fwd(prop_kind pk) const {
301 return pk == prop_kind::forward_training
302 || pk == prop_kind::forward_inference;
303 }
304};
305
// Short alias used by the test-case tables below.
using tp = lrn_test_params_t;

// The test body is intentionally empty: all work happens in
// lrn_test_t::SetUp(), which calls Test() via catch_expected_failures().
TEST_P(lrn_test_t, TestsLRN) {}
309
// Cases that are expected to fail with the given status.
// NOTE(review): lrn_test_t::Test() always runs Forward() before Backward(),
// and Backward() reuses the same src/dst parameters. Every case labelled as a
// backward case below shares its src/dst data types and tags with a preceding
// case that is expected to fail, so it presumably fails in Forward() and never
// reaches Backward(). Confirm these cases cover the backward path as intended.
INSTANTIATE_TEST_SUITE_P(Test_LRN_EF, lrn_test_t,
        ::testing::Values(
                // Negative dims
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, -2, 128, 256}, 5, 1e-4f, 0.75f, 1.f,
                        true, dnnl_invalid_arguments},
                // Negative local size
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, -1, 1e-4f, 0.75f, 1.f,
                        true, dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f,
                        true, dnnl_invalid_arguments},
                // Tag for src on backward is not specified
                // (same src/dst params as the forward case above)
                tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true,
                        dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f,
                        true, dnnl_invalid_arguments},
                // Different data types are not supported
                tp {dt::f32, dt::bf16, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f,
                        true, dnnl_unimplemented},
                // Different data types are not supported (with diff_src;
                // same src/dst params as the case above)
                tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true,
                        dnnl_unimplemented},
                // Different memory formats are not supported
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nhwc,
                        tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f,
                        true, dnnl_unimplemented},
                // Different memory formats are not supported (with diff_src;
                // same src/dst params as the case above)
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw,
                        {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true,
                        dnnl_unimplemented}));
348
// Generator of the common LRN test cases, parameterized by the data types of
// src, dst (also used for diff_dst) and diff_src. Despite the "Forward" group
// names below, lrn_test_t::Test() also runs the backward pass for
// forward_training whenever diff_src_dt is not dt::undef.
// expect_to_fail and expected_status are not given, so aggregate
// initialization leaves them value-initialized (false / 0).
static auto all_cases = [](memory::data_type src_dt, memory::data_type dst_dt,
                                memory::data_type diff_src_dt) {
    return ::testing::Values(
            // Forward_nChw16c_padded_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 17, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 26, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f},
            // Forward_nChw8c_padded_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 7, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 26, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f},
            // Forward_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 3.0f},
            // ForwardNHWC_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
                    {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
                    {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 4.85f},
            // Forward_nChw8c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f},
            // Forward_nChw16c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f},
            // AlexnetForwardNCHW_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f},
            // AlexnetForwardNHWC_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
                    {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
                    {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f},
            // AlexnetForward_nChw8c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f},
            // AlexnetForward_nChw16c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f},
            // GoogleNetV1ForwardNCHW_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            // GoogleNetV1Forward_nChw8c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            // GoogleNetV1Forward_nChw16c_cases
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f},
            // RCNNForwardBlocked_cases (local_size 3 and 5)
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 96, 55, 55}, 3, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 256, 27, 27}, 3, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f},
            // ForwardNCHWTail_cases (small spatial sizes)
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 1, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 2, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 3, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 4, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 5, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            // NOTE(review): {1, 64, 9, 6} breaks the {1, 64, N, 9} pattern of
            // its neighbours; confirm {1, 64, 6, 9} was not intended.
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 9, 6}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 7, 9}, 5, 1.0e-4f, 0.75f, 1.0f},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {1, 64, 8, 9}, 5, 1.0e-4f, 0.75f, 1.0f});
};
444
// Expands three data type names into the memory::data_type enumerators
// expected by all_cases, in (src, dst / diff_dst, diff_src) order.
#define EXPAND_DTS(src, dst, diff_src) \
    memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src

// Instantiation helpers: instantiate lrn_test_t with suite(__VA_ARGS__).
// The terminating semicolon is supplied at the call site. The macros therefore
// must not end with one; otherwise each use leaves a stray empty declaration
// at namespace scope, which -Wextra-semi / -pedantic flag.
#define INST_TEST_CASE(name, suite, ...) \
    INSTANTIATE_TEST_SUITE_P(name, lrn_test_t, suite(__VA_ARGS__))

#define CPU_INST_TEST_CASE(name, suite, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(name, lrn_test_t, suite(__VA_ARGS__))

#define GPU_INST_TEST_CASE(name, suite, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(name, lrn_test_t, suite(__VA_ARGS__))
456
// Instantiate all_cases for f32, bf16 and f16. The f16 instantiation passes
// undef as diff_src_dt, so lrn_test_t::Test() runs only the forward pass for
// it.
INST_TEST_CASE(LRNSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
INST_TEST_CASE(LRNSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
INST_TEST_CASE(LRNSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef));
460} // namespace dnnl
461