1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | using tag = memory::format_tag; |
25 | using dt = memory::data_type; |
26 | |
27 | struct lrn_test_params_t { |
28 | dt src_dt; |
29 | dt dst_dt; // diff_dst_dt |
30 | dt diff_src_dt; |
31 | tag src_tag; |
32 | tag dst_tag; // diff_dst_tag |
33 | tag diff_src_tag; |
34 | memory::dims dims; |
35 | memory::dim local_size; |
36 | float alpha, beta, k; |
37 | bool expect_to_fail; |
38 | dnnl_status_t expected_status; |
39 | }; |
40 | |
41 | bool cuda_check_format_tag(tag atag) { |
42 | return impl::utils::one_of(atag, tag::ncdhw, tag::nchw, tag::nhwc, tag::ncw, |
43 | tag::nwc, tag::any); |
44 | } |
45 | |
46 | template <typename... Rest> |
47 | bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) { |
48 | const bool ok = cuda_check_format_tag(first_tag); |
49 | if (!ok) return ok; |
50 | return cuda_check_format_tag(rest_tags...); |
51 | } |
52 | |
53 | bool hip_check_format_tag(tag atag) { |
54 | return impl::utils::one_of(atag, tag::nchw, tag::any); |
55 | } |
56 | |
57 | template <typename... Rest> |
58 | bool hip_check_format_tag(tag first_tag, Rest... rest_tags) { |
59 | const bool ok = hip_check_format_tag(first_tag); |
60 | if (!ok) return ok; |
61 | return hip_check_format_tag(rest_tags...); |
62 | } |
63 | |
64 | class lrn_test_t : public ::testing::TestWithParam<lrn_test_params_t> { |
65 | private: |
66 | lrn_test_params_t p; |
67 | memory src, workspace; |
68 | std::shared_ptr<lrn_forward::primitive_desc> pd_fwd_hint; |
69 | |
70 | protected: |
71 | void SetUp() override { |
72 | p = ::testing::TestWithParam<lrn_test_params_t>::GetParam(); |
73 | |
74 | SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt), |
75 | "Engine does not support this data type." ); |
76 | |
77 | SKIP_IF_CUDA( |
78 | p.dst_dt == dt::s8, "Unsupported int8 destination data type" ); |
79 | SKIP_IF_HIP( |
80 | p.dst_dt == dt::s8, "Unsupported int8 destination data type" ); |
81 | |
82 | SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag), |
83 | "Unsupported format tag" ); |
84 | SKIP_IF_HIP(!hip_check_format_tag(p.src_tag, p.dst_tag), |
85 | "Unsupported format tag" ); |
86 | |
87 | SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
88 | && p.dst_dt != dt::undef, |
89 | "Unsupported different data types for source and " |
90 | "destination" ); |
91 | SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
92 | && p.dst_dt != dt::undef, |
93 | "Unsupported different data types for source and " |
94 | "destination" ); |
95 | |
96 | SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any |
97 | && p.dst_tag != tag::any, |
98 | "Unsupported different memory formats for source and " |
99 | "destination" ); |
100 | SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any |
101 | && p.dst_tag != tag::any, |
102 | "Unsupported different memory formats for source and " |
103 | "destination" ); |
104 | |
105 | catch_expected_failures( |
106 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
107 | } |
108 | |
109 | void Forward(prop_kind pk, algorithm aalgorithm) { |
110 | // lrn specific types and values |
111 | using pd_t = lrn_forward::primitive_desc; |
112 | |
113 | auto eng = get_test_engine(); |
114 | auto strm = make_stream(eng); |
115 | |
116 | auto aa = allows_attr_t {false}; |
117 | |
118 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
119 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
120 | |
121 | // default pd ctor |
122 | auto pd = pd_t(); |
123 | // regular pd ctor |
124 | pd = pd_t(eng, pk, aalgorithm, src_md, dst_md, p.local_size, p.alpha, |
125 | p.beta, p.k); |
126 | // test all pd ctors |
127 | test_fwd_pd_constructors<pd_t>(pd, aa, pk, aalgorithm, src_md, dst_md, |
128 | p.local_size, p.alpha, p.beta, p.k); |
129 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
130 | |
131 | EXPECT_ANY_THROW(lrn_forward(pd, {})); |
132 | // default primitive ctor |
133 | auto lrn = lrn_forward(); |
134 | // regular primitive ctor |
135 | lrn = lrn_forward(pd); |
136 | |
137 | // check primitive kind is lrn |
138 | ASSERT_TRUE(lrn.get_kind() == primitive::kind::lrn); |
139 | // query for descs from pd |
140 | const auto src_desc = pd.src_desc(); |
141 | const auto dst_desc = pd.dst_desc(); |
142 | // query for src_desc via exec arg |
143 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
144 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
145 | // query for dst_desc via exec arg |
146 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
147 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
148 | |
149 | // query primitive parameters |
150 | ASSERT_EQ(pd.get_prop_kind(), pk); |
151 | ASSERT_EQ(pd.get_algorithm(), aalgorithm); |
152 | ASSERT_EQ(pd.get_local_size(), p.local_size); |
153 | ASSERT_EQ(pd.get_alpha(), p.alpha); |
154 | ASSERT_EQ(pd.get_beta(), p.beta); |
155 | ASSERT_EQ(pd.get_k(), p.k); |
156 | |
157 | // query for workspace |
158 | const auto workspace_desc = pd.workspace_desc(); |
159 | |
160 | // check primitive returns zero_md for all rest md |
161 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
162 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
163 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
164 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
165 | |
166 | src = test::make_memory(src_desc, eng); |
167 | auto dst = test::make_memory(dst_desc, eng); |
168 | workspace = test::make_memory(workspace_desc, eng); |
169 | |
170 | fill_data(p.src_dt, src, 1, 1); |
171 | // test out-place mode |
172 | lrn.execute(strm, |
173 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, |
174 | {DNNL_ARG_WORKSPACE, workspace}}); |
175 | strm.wait(); |
176 | } |
177 | |
178 | void Backward(algorithm aalgorithm) { |
179 | // lrn specific types and values |
180 | using pd_t = lrn_backward::primitive_desc; |
181 | using hint_pd_t = lrn_forward::primitive_desc; |
182 | allows_attr_t aa {false}; // doesn't support anything |
183 | |
184 | auto eng = get_test_engine(); |
185 | auto strm = make_stream(eng); |
186 | auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag); |
187 | auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
188 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
189 | |
190 | // default pd ctor |
191 | auto pd = pd_t(); |
192 | // regular pd ctor |
193 | pd = pd_t(eng, aalgorithm, diff_src_md, diff_dst_md, src_md, |
194 | p.local_size, p.alpha, p.beta, p.k, *pd_fwd_hint); |
195 | // test all pd ctors |
196 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, |
197 | aalgorithm, diff_src_md, diff_dst_md, src_md, p.local_size, |
198 | p.alpha, p.beta, p.k); |
199 | |
200 | EXPECT_ANY_THROW(lrn_backward(pd, {})); |
201 | // default primitive ctor |
202 | auto lrn = lrn_backward(); |
203 | // regular primitive ctor |
204 | lrn = lrn_backward(pd); |
205 | |
206 | // check primitive kind is lrn |
207 | ASSERT_TRUE(lrn.get_kind() == primitive::kind::lrn); |
208 | |
209 | // query for descs from pd |
210 | const auto diff_src_desc = pd.diff_src_desc(); |
211 | const auto diff_dst_desc = pd.diff_dst_desc(); |
212 | const auto src_desc = pd.src_desc(); |
213 | // query for diff_src_desc via exec arg |
214 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
215 | == diff_src_desc); |
216 | if (p.diff_src_tag != tag::any) { |
217 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
218 | } |
219 | // query for diff_dst_desc via exec arg |
220 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
221 | == diff_dst_desc); |
222 | if (p.dst_tag != tag::any) { |
223 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
224 | } |
225 | // query for dst_desc via exec arg |
226 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
227 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
228 | |
229 | // query primitive parameters |
230 | ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data); |
231 | ASSERT_EQ(pd.get_algorithm(), aalgorithm); |
232 | ASSERT_EQ(pd.get_local_size(), p.local_size); |
233 | ASSERT_EQ(pd.get_alpha(), p.alpha); |
234 | ASSERT_EQ(pd.get_beta(), p.beta); |
235 | ASSERT_EQ(pd.get_k(), p.k); |
236 | |
237 | // check primitive returns zero_md for all rest md |
238 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
239 | ASSERT_TRUE(pd.dst_desc().is_zero()); |
240 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
241 | |
242 | auto diff_src = test::make_memory(diff_src_desc, eng); |
243 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
244 | |
245 | fill_data(p.diff_src_dt, diff_dst, 2, 2); |
246 | |
247 | // test out-place mode |
248 | lrn.execute(strm, |
249 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst}, |
250 | {DNNL_ARG_DIFF_SRC, diff_src}, |
251 | {DNNL_ARG_WORKSPACE, workspace}}); |
252 | strm.wait(); |
253 | } |
254 | |
255 | void Test() { |
256 | const std::vector<prop_kind> pks |
257 | = {prop_kind::forward_training, prop_kind::forward_inference}; |
258 | const std::vector<algorithm> algs = { |
259 | algorithm::lrn_across_channels, algorithm::lrn_within_channel}; |
260 | |
261 | for_(auto pk : pks) |
262 | for (auto alg : algs) { |
263 | SKIP_FOR_LOOP_CUDA(alg != algorithm::lrn_across_channels, |
264 | "Unsupported algorithm" ); |
265 | |
266 | Forward(pk, alg); |
267 | if (pk == prop_kind::forward_training |
268 | && p.diff_src_dt != memory::data_type::undef) { |
269 | SKIP_IF(unsupported_data_type(p.diff_src_dt), |
270 | "Engine does not support this data type." ); |
271 | SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag), |
272 | "Unsupported format tag" ); |
273 | SKIP_IF_HIP(!hip_check_format_tag(p.diff_src_tag), |
274 | "Unsupported format tag" ); |
275 | |
276 | SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef |
277 | && p.diff_src_dt != dt::undef, |
278 | "Unsupported different data types for diff_source and " |
279 | "diff_destination" ); |
280 | SKIP_IF_HIP(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef |
281 | && p.diff_src_dt != dt::undef, |
282 | "Unsupported different data types for diff_source and " |
283 | "diff_destination" ); |
284 | |
285 | SKIP_IF_CUDA(p.src_tag != p.diff_src_tag |
286 | && p.src_tag != tag::any |
287 | && p.diff_src_tag != tag::any, |
288 | "Unsupported different memory formats for diff_source " |
289 | "and diff_destination" ); |
290 | SKIP_IF_HIP(p.src_tag != p.diff_src_tag && p.src_tag != tag::any |
291 | && p.diff_src_tag != tag::any, |
292 | "Unsupported different memory formats for diff_source " |
293 | "and diff_destination" ); |
294 | |
295 | Backward(alg); |
296 | } |
297 | } |
298 | } |
299 | |
300 | bool is_fwd(prop_kind pk) const { |
301 | return pk == prop_kind::forward_training |
302 | || pk == prop_kind::forward_inference; |
303 | } |
304 | }; |
305 | |
306 | using tp = lrn_test_params_t; |
307 | |
308 | TEST_P(lrn_test_t, TestsLRN) {} |
309 | |
310 | INSTANTIATE_TEST_SUITE_P(Test_LRN_EF, lrn_test_t, |
311 | ::testing::Values( |
312 | // Negative dims |
313 | tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw, |
314 | tag::undef, {2, -2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, |
315 | true, dnnl_invalid_arguments}, |
316 | // Negative local size |
317 | tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw, |
318 | tag::undef, {2, 2, 128, 256}, -1, 1e-4f, 0.75f, 1.f, |
319 | true, dnnl_invalid_arguments}, |
320 | // Tag for src on forward is not specified |
321 | tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw, |
322 | tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, |
323 | true, dnnl_invalid_arguments}, |
324 | // Tag for src on backward is not specified |
325 | tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw, |
326 | {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true, |
327 | dnnl_invalid_arguments}, |
328 | // Data type for src is not specified |
329 | tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw, |
330 | tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, |
331 | true, dnnl_invalid_arguments}, |
332 | // Different data types are not supported |
333 | tp {dt::f32, dt::bf16, dt::undef, tag::nchw, tag::nchw, |
334 | tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, |
335 | true, dnnl_unimplemented}, |
336 | // Different data types are not supported |
337 | tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw, |
338 | {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true, |
339 | dnnl_unimplemented}, |
340 | // Different memory formats are not supported |
341 | tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nhwc, |
342 | tag::undef, {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, |
343 | true, dnnl_unimplemented}, |
344 | // Different memory formats are not supported |
345 | tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw, |
346 | {2, 2, 128, 256}, 5, 1e-4f, 0.75f, 1.f, true, |
347 | dnnl_unimplemented})); |
348 | |
349 | static auto all_cases = [](memory::data_type src_dt, memory::data_type dst_dt, |
350 | memory::data_type diff_src_dt) { |
351 | return ::testing::Values( |
352 | // Forward_nChw16c_padded_cases |
353 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
354 | tag::nChw16c, {2, 17, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
355 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
356 | tag::nChw16c, {2, 26, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f}, |
357 | // Forward_nChw8c_padded_cases |
358 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
359 | tag::nChw8c, {2, 7, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
360 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
361 | tag::nChw8c, {2, 26, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f}, |
362 | // Forward_cases |
363 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
364 | {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
365 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
366 | {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 3.0f}, |
367 | // ForwardNHWC_cases |
368 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
369 | {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
370 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
371 | {2, 10, 4, 4}, 5, 1.0e-4f, 0.75f, 4.85f}, |
372 | // Forward_nChw8c_cases |
373 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
374 | tag::nChw8c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
375 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
376 | tag::nChw8c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f}, |
377 | // Forward_nChw16c_cases |
378 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
379 | tag::nChw16c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 1.0f}, |
380 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
381 | tag::nChw16c, {2, 16, 4, 4}, 5, 1.0e-4f, 0.75f, 5.7f}, |
382 | // AlexnetForwardNCHW_cases |
383 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
384 | {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f}, |
385 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
386 | {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f}, |
387 | // AlexnetForwardNHWC_cases |
388 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
389 | {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f}, |
390 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
391 | {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f}, |
392 | // AlexnetForward_nChw8c_cases |
393 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
394 | tag::nChw8c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f}, |
395 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
396 | tag::nChw8c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f}, |
397 | // AlexnetForward_nChw16c_cases |
398 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
399 | tag::nChw16c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f}, |
400 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
401 | tag::nChw16c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f}, |
402 | // GoogleNetV1ForwardNCHW_cases |
403 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
404 | {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
405 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
406 | {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
407 | // GoogleNetV1Forward_nChw8c_cases |
408 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
409 | tag::nChw8c, {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
410 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
411 | tag::nChw8c, {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
412 | // GoogleNetV1Forward_nChw16c_cases |
413 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
414 | tag::nChw16c, {2, 64, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
415 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
416 | tag::nChw16c, {2, 192, 56, 56}, 5, 1.0e-4f, 0.75f, 1.0f}, |
417 | // RCNNForwardBlocked_cases |
418 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
419 | tag::nChw8c, {2, 96, 55, 55}, 3, 1.0e-4f, 0.75f, 1.0f}, |
420 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
421 | tag::nChw8c, {2, 256, 27, 27}, 3, 1.0e-4f, 0.75f, 1.0f}, |
422 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
423 | tag::nChw8c, {2, 96, 55, 55}, 5, 1.0e-4f, 0.75f, 1.0f}, |
424 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
425 | tag::nChw8c, {2, 256, 27, 27}, 5, 1.0e-4f, 0.75f, 1.0f}, |
426 | // ForwardNCHWTail_cases |
427 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
428 | {1, 64, 1, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
429 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
430 | {1, 64, 2, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
431 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
432 | {1, 64, 3, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
433 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
434 | {1, 64, 4, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
435 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
436 | {1, 64, 5, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
437 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
438 | {1, 64, 9, 6}, 5, 1.0e-4f, 0.75f, 1.0f}, |
439 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
440 | {1, 64, 7, 9}, 5, 1.0e-4f, 0.75f, 1.0f}, |
441 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
442 | {1, 64, 8, 9}, 5, 1.0e-4f, 0.75f, 1.0f}); |
443 | }; |
444 | |
// Turns three bare data type names into a comma-separated list of
// memory::data_type enumerators (src, dst, diff_src).
#define EXPAND_DTS(src_type, dst_type, diff_src_type) \
    memory::data_type::src_type, memory::data_type::dst_type, \
            memory::data_type::diff_src_type

// Instantiates lrn_test_t under `prefix` with the cases produced by calling
// `generator` with the remaining arguments.
#define INST_TEST_CASE(prefix, generator, ...) \
    INSTANTIATE_TEST_SUITE_P(prefix, lrn_test_t, generator(__VA_ARGS__));

// Same as INST_TEST_CASE, but goes through CPU_INSTANTIATE_TEST_SUITE_P.
#define CPU_INST_TEST_CASE(prefix, generator, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(prefix, lrn_test_t, generator(__VA_ARGS__));

// Same as INST_TEST_CASE, but goes through GPU_INSTANTIATE_TEST_SUITE_P.
#define GPU_INST_TEST_CASE(prefix, generator, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(prefix, lrn_test_t, generator(__VA_ARGS__));
456 | |
457 | INST_TEST_CASE(LRNSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32)); |
458 | INST_TEST_CASE(LRNSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16)); |
459 | INST_TEST_CASE(LRNSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef)); |
460 | } // namespace dnnl |
461 | |