1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | using tag = memory::format_tag; |
25 | using dt = memory::data_type; |
26 | |
27 | struct batch_normalization_test_params_t { |
28 | dt src_dt; |
29 | dt dst_dt; // diff_dst_dt |
30 | dt diff_src_dt; |
31 | tag src_tag; |
32 | tag dst_tag; // diff_dst_tag |
33 | tag diff_src_tag; |
34 | memory::dims dims; |
35 | bool expect_to_fail; |
36 | dnnl_status_t expected_status; |
37 | }; |
38 | |
39 | bool cuda_check_format_tag(tag atag) { |
40 | return impl::utils::one_of(atag, tag::ncdhw, tag::ndhwc, tag::nchw, |
41 | tag::nhwc, tag::ncw, tag::nwc, tag::any); |
42 | } |
43 | |
44 | template <typename... Rest> |
45 | bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) { |
46 | const bool ok = cuda_check_format_tag(first_tag); |
47 | if (!ok) return ok; |
48 | return cuda_check_format_tag(rest_tags...); |
49 | } |
50 | |
51 | class batch_normalization_test_t |
52 | : public ::testing::TestWithParam<batch_normalization_test_params_t> { |
53 | private: |
54 | batch_normalization_test_params_t p; |
55 | memory src, workspace, mean, variance, scale; |
56 | std::shared_ptr<batch_normalization_forward::primitive_desc> pd_fwd_hint; |
57 | |
58 | protected: |
59 | void SetUp() override { |
60 | p = ::testing::TestWithParam< |
61 | batch_normalization_test_params_t>::GetParam(); |
62 | |
63 | SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt), |
64 | "Engine does not support this data type." ); |
65 | |
66 | SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag), |
67 | "Unsupported format tag" ); |
68 | |
69 | SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
70 | && p.dst_dt != dt::undef, |
71 | "Unsupported different data types for source and " |
72 | "destination" ); |
73 | |
74 | SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any |
75 | && p.dst_tag != tag::any, |
76 | "Unsupported different memory formats for source and " |
77 | "destination" ); |
78 | |
79 | catch_expected_failures( |
80 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
81 | } |
82 | |
83 | void Forward(prop_kind pk, normalization_flags flags) { |
84 | // batch_normalization specific types and values |
85 | using pd_t = batch_normalization_forward::primitive_desc; |
86 | |
87 | auto eng = get_test_engine(); |
88 | auto strm = make_stream(eng); |
89 | |
90 | auto aa = allows_attr_t {false}; |
91 | aa.po_eltwise = !is_amd_gpu(eng); |
92 | |
93 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
94 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
95 | |
96 | // default pd ctor |
97 | auto pd = pd_t(); |
98 | // regular pd ctor |
99 | pd = pd_t(eng, pk, src_md, dst_md, /* epsilon = */ 1e-4f, flags); |
100 | // test all pd ctors |
101 | test_fwd_pd_constructors<pd_t>(pd, aa, pk, src_md, dst_md, |
102 | /* epsilon = */ 1e-4f, flags); |
103 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
104 | |
105 | EXPECT_ANY_THROW(batch_normalization_forward(pd, {})); |
106 | // default primitive ctor |
107 | auto batch_normalization = batch_normalization_forward(); |
108 | // regular primitive ctor |
109 | batch_normalization = batch_normalization_forward(pd); |
110 | |
111 | // check primitive kind is batch_normalization |
112 | ASSERT_TRUE(batch_normalization.get_kind() |
113 | == primitive::kind::batch_normalization); |
114 | // query for descs from pd |
115 | const auto src_desc = pd.src_desc(); |
116 | const auto dst_desc = pd.dst_desc(); |
117 | // query for src_desc via exec arg |
118 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
119 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
120 | // query for dst_desc via exec arg |
121 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
122 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
123 | // query for stats and scales via exec arg |
124 | const auto mean_desc = pd.mean_desc(); |
125 | const auto variance_desc = pd.variance_desc(); |
126 | const auto scale_desc = pd.weights_desc(); |
127 | const auto shift_desc = pd.weights_desc(); |
128 | ASSERT_TRUE( |
129 | pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) == mean_desc); |
130 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE) |
131 | == variance_desc); |
132 | |
133 | if (has_scale(flags)) { |
134 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE) |
135 | == scale_desc); |
136 | } |
137 | if (has_shift(flags)) { |
138 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SHIFT) |
139 | == shift_desc); |
140 | } |
141 | |
142 | // query for workspace |
143 | const auto workspace_desc = pd.workspace_desc(); |
144 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WORKSPACE) |
145 | == workspace_desc); |
146 | |
147 | // query primitive parameters |
148 | ASSERT_EQ(pd.get_prop_kind(), pk); |
149 | ASSERT_EQ(pd.get_flags(), flags); |
150 | ASSERT_EQ(pd.get_epsilon(), 1e-4f); |
151 | |
152 | // check primitive returns zero_md for all rest md |
153 | if (!has_scale(flags) && !has_shift(flags)) { |
154 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
155 | } |
156 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
157 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
158 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
159 | |
160 | src = test::make_memory(src_desc, eng); |
161 | auto dst = test::make_memory(dst_desc, eng); |
162 | workspace = test::make_memory(workspace_desc, eng); |
163 | mean = test::make_memory(mean_desc, eng); |
164 | variance = test::make_memory(variance_desc, eng); |
165 | scale = test::make_memory(scale_desc, eng); |
166 | auto shift = test::make_memory(shift_desc, eng); |
167 | |
168 | fill_data(p.src_dt, src, 1, 1); |
169 | if (has_scale(flags)) fill_data(dt::f32, scale, 1, 1); |
170 | if (has_shift(flags)) fill_data(dt::f32, shift, 1, 1); |
171 | if (use_global_stats(flags)) { |
172 | fill_data(dt::f32, mean, 1, 1); |
173 | fill_data(dt::f32, variance, 1, 1); |
174 | } |
175 | // test out-place mode |
176 | batch_normalization.execute(strm, |
177 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, |
178 | {DNNL_ARG_MEAN, mean}, {DNNL_ARG_VARIANCE, variance}, |
179 | {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift}, |
180 | {DNNL_ARG_WORKSPACE, workspace}}); |
181 | strm.wait(); |
182 | |
183 | // test in-place mode on forward |
184 | if (p.src_tag == p.dst_tag && p.src_dt == p.dst_dt) { |
185 | // TODO: add a copy of memory and result comparison with previous |
186 | // dst output. |
187 | batch_normalization.execute(strm, |
188 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src}, |
189 | {DNNL_ARG_MEAN, mean}, |
190 | {DNNL_ARG_VARIANCE, variance}, |
191 | {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift}, |
192 | {DNNL_ARG_WORKSPACE, workspace}}); |
193 | strm.wait(); |
194 | } |
195 | } |
196 | |
197 | void Backward(prop_kind pk, normalization_flags flags) { |
198 | // batch_normalization specific types and values |
199 | using pd_t = batch_normalization_backward::primitive_desc; |
200 | using hint_pd_t = batch_normalization_forward::primitive_desc; |
201 | allows_attr_t aa {false}; // doesn't support anything |
202 | |
203 | auto eng = get_test_engine(); |
204 | auto strm = make_stream(eng); |
205 | |
206 | auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag); |
207 | auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
208 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
209 | |
210 | // default pd ctor |
211 | auto pd = pd_t(); |
212 | // regular pd ctor |
213 | pd = pd_t(eng, pk, diff_src_md, diff_dst_md, src_md, |
214 | /* epsilon = */ 1e-4f, flags, *pd_fwd_hint); |
215 | // test all pd ctors |
216 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, pk, |
217 | diff_src_md, diff_dst_md, src_md, |
218 | /* epsilon = */ 1e-4f, flags); |
219 | |
220 | EXPECT_ANY_THROW(batch_normalization_backward(pd, {})); |
221 | // default primitive ctor |
222 | auto batch_normalization = batch_normalization_backward(); |
223 | // regular primitive ctor |
224 | batch_normalization = batch_normalization_backward(pd); |
225 | |
226 | // check primitive kind is batch_normalization |
227 | ASSERT_TRUE(batch_normalization.get_kind() |
228 | == primitive::kind::batch_normalization); |
229 | |
230 | // query for descs from pd |
231 | const auto diff_src_desc = pd.diff_src_desc(); |
232 | const auto diff_dst_desc = pd.diff_dst_desc(); |
233 | const auto src_desc = pd.src_desc(); |
234 | // query for diff_src_desc via exec arg |
235 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
236 | == diff_src_desc); |
237 | if (p.diff_src_tag != tag::any) { |
238 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
239 | } |
240 | // query for diff_dst_desc via exec arg |
241 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
242 | == diff_dst_desc); |
243 | if (p.dst_tag != tag::any) { |
244 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
245 | } |
246 | // query for dst_desc via exec arg |
247 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
248 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
249 | |
250 | // query for stats and scales via exec arg |
251 | const auto mean_desc = pd.mean_desc(); |
252 | const auto variance_desc = pd.variance_desc(); |
253 | const auto scale_desc = pd.weights_desc(); |
254 | const auto diff_scale_desc = pd.diff_weights_desc(); |
255 | const auto diff_shift_desc = pd.diff_weights_desc(); |
256 | ASSERT_TRUE( |
257 | pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) == mean_desc); |
258 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE) |
259 | == variance_desc); |
260 | |
261 | if (has_scale(flags)) { |
262 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE) |
263 | == scale_desc); |
264 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SCALE) |
265 | == diff_scale_desc); |
266 | } |
267 | if (has_shift(flags)) { |
268 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SHIFT) |
269 | == diff_shift_desc); |
270 | } |
271 | |
272 | // query for workspace |
273 | const auto workspace_desc = pd.workspace_desc(); |
274 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WORKSPACE) |
275 | == workspace_desc); |
276 | |
277 | // query primitive parameters |
278 | ASSERT_EQ(pd.get_prop_kind(), pk); |
279 | ASSERT_EQ(pd.get_flags(), flags); |
280 | ASSERT_EQ(pd.get_epsilon(), 1e-4f); |
281 | |
282 | // check primitive returns zero_md for all rest md |
283 | ASSERT_TRUE(pd.dst_desc().is_zero()); |
284 | if (!has_scale(flags) && !has_shift(flags)) { |
285 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
286 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
287 | } |
288 | |
289 | auto diff_src = test::make_memory(diff_src_desc, eng); |
290 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
291 | auto diff_scale = test::make_memory(diff_scale_desc, eng); |
292 | auto diff_shift = test::make_memory(diff_shift_desc, eng); |
293 | |
294 | fill_data(p.dst_dt, diff_dst, 2, 2); |
295 | |
296 | // test out-place mode |
297 | batch_normalization.execute(strm, |
298 | {{DNNL_ARG_DIFF_SRC, diff_src}, {DNNL_ARG_DIFF_DST, diff_dst}, |
299 | {DNNL_ARG_SRC, src}, {DNNL_ARG_MEAN, mean}, |
300 | {DNNL_ARG_VARIANCE, variance}, |
301 | {DNNL_ARG_DIFF_SCALE, diff_scale}, |
302 | {DNNL_ARG_DIFF_SHIFT, diff_shift}, |
303 | {DNNL_ARG_SCALE, scale}, |
304 | {DNNL_ARG_WORKSPACE, workspace}}); |
305 | strm.wait(); |
306 | |
307 | // test in-place mode |
308 | if (p.dst_tag == p.diff_src_tag && p.dst_dt == p.diff_src_dt) { |
309 | batch_normalization.execute(strm, |
310 | {{DNNL_ARG_DIFF_SRC, diff_dst}, |
311 | {DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_SRC, src}, |
312 | {DNNL_ARG_MEAN, mean}, |
313 | {DNNL_ARG_VARIANCE, variance}, |
314 | {DNNL_ARG_DIFF_SCALE, diff_scale}, |
315 | {DNNL_ARG_DIFF_SHIFT, diff_shift}, |
316 | {DNNL_ARG_SCALE, scale}, |
317 | {DNNL_ARG_WORKSPACE, workspace}}); |
318 | strm.wait(); |
319 | } |
320 | } |
321 | |
322 | void Test() { |
323 | using nf = normalization_flags; |
324 | std::vector<normalization_flags> inference_flags {nf::none, |
325 | nf::use_global_stats, nf::use_scale, nf::use_shift, |
326 | nf::use_global_stats | nf::use_scale, |
327 | nf::use_global_stats | nf::use_shift, |
328 | nf::use_scale | nf::use_shift, |
329 | nf::use_global_stats | nf::use_scale | nf::use_shift}; |
330 | |
331 | for (auto flags : inference_flags) { |
332 | SKIP_FOR_LOOP(p.src_dt == dt::s8 && !use_global_stats(flags), |
333 | "s8 doesn't support anything but use_global_stats flag" ); |
334 | SKIP_FOR_LOOP_CUDA(p.src_dt == dt::f16 && !use_global_stats(flags), |
335 | "s8 doesn't support anything but use_global_stats flag" ); |
336 | Forward(prop_kind::forward_inference, flags); |
337 | } |
338 | |
339 | // No training for int8. |
340 | if (p.src_dt == dt::s8) return; |
341 | // No training for cuda dor f16. |
342 | if (is_nvidia_gpu(get_test_engine()) && p.src_dt == dt::f16) return; |
343 | |
344 | // TODO: add fuse_norm_add_relu |
345 | std::vector<normalization_flags> training_flags {nf::none, |
346 | nf::fuse_norm_relu, nf::use_scale, nf::use_shift, |
347 | nf::fuse_norm_relu | nf::use_scale, |
348 | nf::fuse_norm_relu | nf::use_shift, |
349 | nf::use_scale | nf::use_shift, |
350 | nf::fuse_norm_relu | nf::use_scale | nf::use_shift}; |
351 | |
352 | for (auto flags : training_flags) { |
353 | Forward(prop_kind::forward_training, flags); |
354 | |
355 | if (p.diff_src_dt != dt::undef) { |
356 | SKIP_IF(unsupported_data_type(p.diff_src_dt), |
357 | "Engine does not support this data type." ); |
358 | SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag), |
359 | "Unsupported format tag" ); |
360 | |
361 | SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef |
362 | && p.diff_src_dt != dt::undef, |
363 | "Unsupported different data types for diff_source and " |
364 | "diff_destination" ); |
365 | |
366 | SKIP_IF_CUDA(p.src_tag != p.diff_src_tag |
367 | && p.src_tag != tag::any |
368 | && p.diff_src_tag != tag::any, |
369 | "Unsupported different memory formats for diff_source " |
370 | "and diff_destination" ); |
371 | |
372 | const prop_kind bwd_pk = (has_scale(flags) || has_shift(flags)) |
373 | ? prop_kind::backward |
374 | : prop_kind::backward_data; |
375 | Backward(bwd_pk, flags); |
376 | } |
377 | } |
378 | } |
379 | |
380 | bool use_global_stats(normalization_flags flags) const { |
381 | return static_cast<bool>(flags & normalization_flags::use_global_stats); |
382 | } |
383 | |
384 | bool has_scale(normalization_flags flags) const { |
385 | return static_cast<bool>(flags & normalization_flags::use_scale); |
386 | } |
387 | |
388 | bool has_shift(normalization_flags flags) const { |
389 | return static_cast<bool>(flags & normalization_flags::use_shift); |
390 | } |
391 | |
392 | bool is_training(prop_kind pk) const { |
393 | return pk == prop_kind::forward_training; |
394 | } |
395 | }; |
396 | |
397 | using tp = batch_normalization_test_params_t; |
398 | |
399 | TEST_P(batch_normalization_test_t, TestsBatchNormalization) {} |
400 | |
401 | INSTANTIATE_TEST_SUITE_P(Test_BatchNormalization_EF, batch_normalization_test_t, |
402 | ::testing::Values( |
403 | // Negative dims |
404 | tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw, |
405 | tag::undef, {2, -2, 128, 256}, true, |
406 | dnnl_invalid_arguments}, |
407 | // Tag for src on forward is not specified |
408 | tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw, |
409 | tag::undef, {2, 2, 128, 256}, true, |
410 | dnnl_invalid_arguments}, |
411 | // Tag for src on backward is not specified |
412 | tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw, |
413 | {2, 2, 128, 256}, true, dnnl_invalid_arguments}, |
414 | // Data type for src is not specified |
415 | tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw, |
416 | tag::undef, {2, 2, 128, 256}, true, |
417 | dnnl_invalid_arguments}, |
418 | // Different data types are not supported |
419 | tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw, |
420 | {2, 2, 128, 256}, true, dnnl_unimplemented}, |
421 | // Different memory formats are not supported |
422 | tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw, |
423 | {2, 2, 128, 256}, true, dnnl_unimplemented})); |
424 | |
425 | static auto all_cases = [](memory::data_type src_dt, memory::data_type dst_dt, |
426 | memory::data_type diff_src_dt) { |
427 | return ::testing::Values( |
428 | tp {src_dt, dst_dt, diff_src_dt, tag::nCdhw16c, tag::nCdhw16c, |
429 | tag::nCdhw16c, {2, 17, 5, 4, 4}}, |
430 | tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw, |
431 | {2, 7, 3, 4, 4}}, |
432 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
433 | tag::nChw16c, {2, 17, 4, 4}}, |
434 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
435 | tag::nChw8c, {2, 7, 4, 4}}, |
436 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
437 | {2, 10, 4, 4}}, |
438 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
439 | {2, 10, 4, 4}}, |
440 | tp {src_dt, dst_dt, diff_src_dt, tag::nCw8c, tag::nCw8c, tag::nCw8c, |
441 | {2, 7, 4}}, |
442 | tp {src_dt, dst_dt, diff_src_dt, tag::nwc, tag::nwc, tag::nwc, |
443 | {2, 10, 4}}); |
444 | }; // namespace dnnl |
445 | |
446 | #define EXPAND_DTS(src, dst, diff_src) \ |
447 | memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src |
448 | |
449 | #define INST_TEST_CASE(name, suite, ...) \ |
450 | INSTANTIATE_TEST_SUITE_P( \ |
451 | name, batch_normalization_test_t, suite(__VA_ARGS__)); |
452 | |
453 | #define CPU_INST_TEST_CASE(name, suite, ...) \ |
454 | CPU_INSTANTIATE_TEST_SUITE_P( \ |
455 | name, batch_normalization_test_t, suite(__VA_ARGS__)); |
456 | |
457 | #define GPU_INST_TEST_CASE(name, suite, ...) \ |
458 | GPU_INSTANTIATE_TEST_SUITE_P( \ |
459 | name, batch_normalization_test_t, suite(__VA_ARGS__)); |
460 | |
461 | INST_TEST_CASE( |
462 | BatchNormalizationSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32)); |
463 | INST_TEST_CASE( |
464 | BatchNormalizationSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16)); |
465 | INST_TEST_CASE( |
466 | BatchNormalizationSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef)); |
467 | INST_TEST_CASE( |
468 | BatchNormalizationSimpleS8, all_cases, EXPAND_DTS(s8, s8, undef)); |
469 | } // namespace dnnl |
470 | |