1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
// Parameter pack describing a single batch normalization test case.
// Note: member order is relied upon by the aggregate initializers in the
// INSTANTIATE_TEST_SUITE_P lists below — do not reorder.
struct batch_normalization_test_params_t {
    dt src_dt; // source data type
    dt dst_dt; // destination data type; reused as diff_dst_dt on backward
    dt diff_src_dt; // dt::undef skips the backward pass entirely
    tag src_tag; // source memory format
    tag dst_tag; // destination format; reused as diff_dst_tag on backward
    tag diff_src_tag; // diff source memory format
    memory::dims dims; // problem tensor dimensions
    bool expect_to_fail; // omitted in positive cases (value-inits to false)
    dnnl_status_t expected_status; // status expected when expect_to_fail
};
38
39bool cuda_check_format_tag(tag atag) {
40 return impl::utils::one_of(atag, tag::ncdhw, tag::ndhwc, tag::nchw,
41 tag::nhwc, tag::ncw, tag::nwc, tag::any);
42}
43
44template <typename... Rest>
45bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) {
46 const bool ok = cuda_check_format_tag(first_tag);
47 if (!ok) return ok;
48 return cuda_check_format_tag(rest_tags...);
49}
50
51class batch_normalization_test_t
52 : public ::testing::TestWithParam<batch_normalization_test_params_t> {
53private:
54 batch_normalization_test_params_t p;
55 memory src, workspace, mean, variance, scale;
56 std::shared_ptr<batch_normalization_forward::primitive_desc> pd_fwd_hint;
57
58protected:
59 void SetUp() override {
60 p = ::testing::TestWithParam<
61 batch_normalization_test_params_t>::GetParam();
62
63 SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt),
64 "Engine does not support this data type.");
65
66 SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag),
67 "Unsupported format tag");
68
69 SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef
70 && p.dst_dt != dt::undef,
71 "Unsupported different data types for source and "
72 "destination");
73
74 SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any
75 && p.dst_tag != tag::any,
76 "Unsupported different memory formats for source and "
77 "destination");
78
79 catch_expected_failures(
80 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
81 }
82
83 void Forward(prop_kind pk, normalization_flags flags) {
84 // batch_normalization specific types and values
85 using pd_t = batch_normalization_forward::primitive_desc;
86
87 auto eng = get_test_engine();
88 auto strm = make_stream(eng);
89
90 auto aa = allows_attr_t {false};
91 aa.po_eltwise = !is_amd_gpu(eng);
92
93 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
94 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
95
96 // default pd ctor
97 auto pd = pd_t();
98 // regular pd ctor
99 pd = pd_t(eng, pk, src_md, dst_md, /* epsilon = */ 1e-4f, flags);
100 // test all pd ctors
101 test_fwd_pd_constructors<pd_t>(pd, aa, pk, src_md, dst_md,
102 /* epsilon = */ 1e-4f, flags);
103 pd_fwd_hint = std::make_shared<pd_t>(pd);
104
105 EXPECT_ANY_THROW(batch_normalization_forward(pd, {}));
106 // default primitive ctor
107 auto batch_normalization = batch_normalization_forward();
108 // regular primitive ctor
109 batch_normalization = batch_normalization_forward(pd);
110
111 // check primitive kind is batch_normalization
112 ASSERT_TRUE(batch_normalization.get_kind()
113 == primitive::kind::batch_normalization);
114 // query for descs from pd
115 const auto src_desc = pd.src_desc();
116 const auto dst_desc = pd.dst_desc();
117 // query for src_desc via exec arg
118 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
119 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
120 // query for dst_desc via exec arg
121 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
122 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
123 // query for stats and scales via exec arg
124 const auto mean_desc = pd.mean_desc();
125 const auto variance_desc = pd.variance_desc();
126 const auto scale_desc = pd.weights_desc();
127 const auto shift_desc = pd.weights_desc();
128 ASSERT_TRUE(
129 pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) == mean_desc);
130 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
131 == variance_desc);
132
133 if (has_scale(flags)) {
134 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE)
135 == scale_desc);
136 }
137 if (has_shift(flags)) {
138 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SHIFT)
139 == shift_desc);
140 }
141
142 // query for workspace
143 const auto workspace_desc = pd.workspace_desc();
144 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WORKSPACE)
145 == workspace_desc);
146
147 // query primitive parameters
148 ASSERT_EQ(pd.get_prop_kind(), pk);
149 ASSERT_EQ(pd.get_flags(), flags);
150 ASSERT_EQ(pd.get_epsilon(), 1e-4f);
151
152 // check primitive returns zero_md for all rest md
153 if (!has_scale(flags) && !has_shift(flags)) {
154 ASSERT_TRUE(pd.weights_desc().is_zero());
155 }
156 ASSERT_TRUE(pd.diff_src_desc().is_zero());
157 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
158 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
159
160 src = test::make_memory(src_desc, eng);
161 auto dst = test::make_memory(dst_desc, eng);
162 workspace = test::make_memory(workspace_desc, eng);
163 mean = test::make_memory(mean_desc, eng);
164 variance = test::make_memory(variance_desc, eng);
165 scale = test::make_memory(scale_desc, eng);
166 auto shift = test::make_memory(shift_desc, eng);
167
168 fill_data(p.src_dt, src, 1, 1);
169 if (has_scale(flags)) fill_data(dt::f32, scale, 1, 1);
170 if (has_shift(flags)) fill_data(dt::f32, shift, 1, 1);
171 if (use_global_stats(flags)) {
172 fill_data(dt::f32, mean, 1, 1);
173 fill_data(dt::f32, variance, 1, 1);
174 }
175 // test out-place mode
176 batch_normalization.execute(strm,
177 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
178 {DNNL_ARG_MEAN, mean}, {DNNL_ARG_VARIANCE, variance},
179 {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift},
180 {DNNL_ARG_WORKSPACE, workspace}});
181 strm.wait();
182
183 // test in-place mode on forward
184 if (p.src_tag == p.dst_tag && p.src_dt == p.dst_dt) {
185 // TODO: add a copy of memory and result comparison with previous
186 // dst output.
187 batch_normalization.execute(strm,
188 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src},
189 {DNNL_ARG_MEAN, mean},
190 {DNNL_ARG_VARIANCE, variance},
191 {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift},
192 {DNNL_ARG_WORKSPACE, workspace}});
193 strm.wait();
194 }
195 }
196
197 void Backward(prop_kind pk, normalization_flags flags) {
198 // batch_normalization specific types and values
199 using pd_t = batch_normalization_backward::primitive_desc;
200 using hint_pd_t = batch_normalization_forward::primitive_desc;
201 allows_attr_t aa {false}; // doesn't support anything
202
203 auto eng = get_test_engine();
204 auto strm = make_stream(eng);
205
206 auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag);
207 auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
208 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
209
210 // default pd ctor
211 auto pd = pd_t();
212 // regular pd ctor
213 pd = pd_t(eng, pk, diff_src_md, diff_dst_md, src_md,
214 /* epsilon = */ 1e-4f, flags, *pd_fwd_hint);
215 // test all pd ctors
216 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, pk,
217 diff_src_md, diff_dst_md, src_md,
218 /* epsilon = */ 1e-4f, flags);
219
220 EXPECT_ANY_THROW(batch_normalization_backward(pd, {}));
221 // default primitive ctor
222 auto batch_normalization = batch_normalization_backward();
223 // regular primitive ctor
224 batch_normalization = batch_normalization_backward(pd);
225
226 // check primitive kind is batch_normalization
227 ASSERT_TRUE(batch_normalization.get_kind()
228 == primitive::kind::batch_normalization);
229
230 // query for descs from pd
231 const auto diff_src_desc = pd.diff_src_desc();
232 const auto diff_dst_desc = pd.diff_dst_desc();
233 const auto src_desc = pd.src_desc();
234 // query for diff_src_desc via exec arg
235 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
236 == diff_src_desc);
237 if (p.diff_src_tag != tag::any) {
238 ASSERT_TRUE(diff_src_md == diff_src_desc);
239 }
240 // query for diff_dst_desc via exec arg
241 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
242 == diff_dst_desc);
243 if (p.dst_tag != tag::any) {
244 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
245 }
246 // query for dst_desc via exec arg
247 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
248 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
249
250 // query for stats and scales via exec arg
251 const auto mean_desc = pd.mean_desc();
252 const auto variance_desc = pd.variance_desc();
253 const auto scale_desc = pd.weights_desc();
254 const auto diff_scale_desc = pd.diff_weights_desc();
255 const auto diff_shift_desc = pd.diff_weights_desc();
256 ASSERT_TRUE(
257 pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) == mean_desc);
258 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
259 == variance_desc);
260
261 if (has_scale(flags)) {
262 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE)
263 == scale_desc);
264 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SCALE)
265 == diff_scale_desc);
266 }
267 if (has_shift(flags)) {
268 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SHIFT)
269 == diff_shift_desc);
270 }
271
272 // query for workspace
273 const auto workspace_desc = pd.workspace_desc();
274 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_WORKSPACE)
275 == workspace_desc);
276
277 // query primitive parameters
278 ASSERT_EQ(pd.get_prop_kind(), pk);
279 ASSERT_EQ(pd.get_flags(), flags);
280 ASSERT_EQ(pd.get_epsilon(), 1e-4f);
281
282 // check primitive returns zero_md for all rest md
283 ASSERT_TRUE(pd.dst_desc().is_zero());
284 if (!has_scale(flags) && !has_shift(flags)) {
285 ASSERT_TRUE(pd.weights_desc().is_zero());
286 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
287 }
288
289 auto diff_src = test::make_memory(diff_src_desc, eng);
290 auto diff_dst = test::make_memory(diff_dst_desc, eng);
291 auto diff_scale = test::make_memory(diff_scale_desc, eng);
292 auto diff_shift = test::make_memory(diff_shift_desc, eng);
293
294 fill_data(p.dst_dt, diff_dst, 2, 2);
295
296 // test out-place mode
297 batch_normalization.execute(strm,
298 {{DNNL_ARG_DIFF_SRC, diff_src}, {DNNL_ARG_DIFF_DST, diff_dst},
299 {DNNL_ARG_SRC, src}, {DNNL_ARG_MEAN, mean},
300 {DNNL_ARG_VARIANCE, variance},
301 {DNNL_ARG_DIFF_SCALE, diff_scale},
302 {DNNL_ARG_DIFF_SHIFT, diff_shift},
303 {DNNL_ARG_SCALE, scale},
304 {DNNL_ARG_WORKSPACE, workspace}});
305 strm.wait();
306
307 // test in-place mode
308 if (p.dst_tag == p.diff_src_tag && p.dst_dt == p.diff_src_dt) {
309 batch_normalization.execute(strm,
310 {{DNNL_ARG_DIFF_SRC, diff_dst},
311 {DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_SRC, src},
312 {DNNL_ARG_MEAN, mean},
313 {DNNL_ARG_VARIANCE, variance},
314 {DNNL_ARG_DIFF_SCALE, diff_scale},
315 {DNNL_ARG_DIFF_SHIFT, diff_shift},
316 {DNNL_ARG_SCALE, scale},
317 {DNNL_ARG_WORKSPACE, workspace}});
318 strm.wait();
319 }
320 }
321
322 void Test() {
323 using nf = normalization_flags;
324 std::vector<normalization_flags> inference_flags {nf::none,
325 nf::use_global_stats, nf::use_scale, nf::use_shift,
326 nf::use_global_stats | nf::use_scale,
327 nf::use_global_stats | nf::use_shift,
328 nf::use_scale | nf::use_shift,
329 nf::use_global_stats | nf::use_scale | nf::use_shift};
330
331 for (auto flags : inference_flags) {
332 SKIP_FOR_LOOP(p.src_dt == dt::s8 && !use_global_stats(flags),
333 "s8 doesn't support anything but use_global_stats flag");
334 SKIP_FOR_LOOP_CUDA(p.src_dt == dt::f16 && !use_global_stats(flags),
335 "s8 doesn't support anything but use_global_stats flag");
336 Forward(prop_kind::forward_inference, flags);
337 }
338
339 // No training for int8.
340 if (p.src_dt == dt::s8) return;
341 // No training for cuda dor f16.
342 if (is_nvidia_gpu(get_test_engine()) && p.src_dt == dt::f16) return;
343
344 // TODO: add fuse_norm_add_relu
345 std::vector<normalization_flags> training_flags {nf::none,
346 nf::fuse_norm_relu, nf::use_scale, nf::use_shift,
347 nf::fuse_norm_relu | nf::use_scale,
348 nf::fuse_norm_relu | nf::use_shift,
349 nf::use_scale | nf::use_shift,
350 nf::fuse_norm_relu | nf::use_scale | nf::use_shift};
351
352 for (auto flags : training_flags) {
353 Forward(prop_kind::forward_training, flags);
354
355 if (p.diff_src_dt != dt::undef) {
356 SKIP_IF(unsupported_data_type(p.diff_src_dt),
357 "Engine does not support this data type.");
358 SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag),
359 "Unsupported format tag");
360
361 SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef
362 && p.diff_src_dt != dt::undef,
363 "Unsupported different data types for diff_source and "
364 "diff_destination");
365
366 SKIP_IF_CUDA(p.src_tag != p.diff_src_tag
367 && p.src_tag != tag::any
368 && p.diff_src_tag != tag::any,
369 "Unsupported different memory formats for diff_source "
370 "and diff_destination");
371
372 const prop_kind bwd_pk = (has_scale(flags) || has_shift(flags))
373 ? prop_kind::backward
374 : prop_kind::backward_data;
375 Backward(bwd_pk, flags);
376 }
377 }
378 }
379
380 bool use_global_stats(normalization_flags flags) const {
381 return static_cast<bool>(flags & normalization_flags::use_global_stats);
382 }
383
384 bool has_scale(normalization_flags flags) const {
385 return static_cast<bool>(flags & normalization_flags::use_scale);
386 }
387
388 bool has_shift(normalization_flags flags) const {
389 return static_cast<bool>(flags & normalization_flags::use_shift);
390 }
391
392 bool is_training(prop_kind pk) const {
393 return pk == prop_kind::forward_training;
394 }
395};
396
// Short alias used by the instantiation lists below.
using tp = batch_normalization_test_params_t;

// The whole test body runs from SetUp(); nothing to do here.
TEST_P(batch_normalization_test_t, TestsBatchNormalization) {}
400
// Expected-failure ("EF") cases: invalid or unsupported descriptors that
// must fail primitive descriptor creation with the listed status.
INSTANTIATE_TEST_SUITE_P(Test_BatchNormalization_EF, batch_normalization_test_t,
        ::testing::Values(
                // Negative dims
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, -2, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Tag for src on backward is not specified
                tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, true, dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Different data types are not supported
                tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, true, dnnl_unimplemented},
                // Different memory formats are not supported
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw,
                        {2, 2, 128, 256}, true, dnnl_unimplemented}));
424
// Positive cases covering 3D, 4D and 5D tensors in plain and blocked
// memory formats, parameterized by the three data types.
static auto all_cases = [](memory::data_type src_dt, memory::data_type dst_dt,
                                memory::data_type diff_src_dt) {
    return ::testing::Values(
            tp {src_dt, dst_dt, diff_src_dt, tag::nCdhw16c, tag::nCdhw16c,
                    tag::nCdhw16c, {2, 17, 5, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw,
                    {2, 7, 3, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
                    tag::nChw16c, {2, 17, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
                    tag::nChw8c, {2, 7, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
                    {2, 10, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
                    {2, 10, 4, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nCw8c, tag::nCw8c, tag::nCw8c,
                    {2, 7, 4}},
            tp {src_dt, dst_dt, diff_src_dt, tag::nwc, tag::nwc, tag::nwc,
                    {2, 10, 4}});
}; // all_cases
445
// Expands three data type short names into memory::data_type values.
#define EXPAND_DTS(src, dst, diff_src) \
    memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src

// Instantiates `suite` for all engines.
#define INST_TEST_CASE(name, suite, ...) \
    INSTANTIATE_TEST_SUITE_P( \
            name, batch_normalization_test_t, suite(__VA_ARGS__));

// Instantiates `suite` for CPU engines only.
#define CPU_INST_TEST_CASE(name, suite, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            name, batch_normalization_test_t, suite(__VA_ARGS__));

// Instantiates `suite` for GPU engines only.
#define GPU_INST_TEST_CASE(name, suite, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            name, batch_normalization_test_t, suite(__VA_ARGS__));
460
INST_TEST_CASE(
        BatchNormalizationSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
INST_TEST_CASE(
        BatchNormalizationSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
// diff_src_dt == undef: the backward pass is skipped for f16 and s8.
INST_TEST_CASE(
        BatchNormalizationSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef));
INST_TEST_CASE(
        BatchNormalizationSimpleS8, all_cases, EXPAND_DTS(s8, s8, undef));
} // namespace dnnl
470