1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <cmath>
#include <memory>
#include <unordered_map>
19
20#include "dnnl_test_common.hpp"
21#include "gtest/gtest.h"
22
23#include "oneapi/dnnl/dnnl.hpp"
24
25namespace dnnl {
26
// Variance-stabilizing constant passed to every lnorm primitive descriptor
// created by this test (also validated via get_epsilon() queries below).
static constexpr float epsilon = 1e-5f;
28
// One parameterized test case. NOTE: member order is relied upon by the
// aggregate initializers built from the TAGS_*/EXPAND_* macros below —
// do not reorder fields.
struct test_lnorm_params_t {
    memory::format_tag src_tag; // layout of src (and dst) tensors
    memory::format_tag stat_tag; // layout of mean/variance tensors
    memory::format_tag diff_src_tag; // layout of diff_src/diff_dst tensors
    memory::data_type src_dt;
    memory::data_type dst_dt;
    memory::data_type diff_src_dt; // undef => backward pass not exercised
    memory::dims dims; // last dim is the normalized (channel) axis
    bool expect_to_fail; // true for negative (invalid-arguments) cases
    dnnl_status_t expected_status;
};
40
41template <typename T>
42void fill(const memory &m) {
43 auto numElements = m.get_desc().get_size() / sizeof(T);
44 fill_data<T>(numElements, m);
45}
46
47class lnorm_test_t : public ::testing::TestWithParam<test_lnorm_params_t> {
48private:
49 std::shared_ptr<test_memory> src, dst, diff_src, diff_dst;
50 memory weights, bias, diff_weights, diff_bias, mean, variance;
51
52 std::shared_ptr<memory::desc> src_md;
53 std::shared_ptr<memory::desc> dst_md;
54 std::shared_ptr<memory::desc> stat_d;
55 std::shared_ptr<memory::desc> diff_src_md;
56
57 layer_normalization_forward::primitive_desc lnorm_fwd_pd;
58 layer_normalization_backward::primitive_desc lnorm_bwd_pd;
59
60 test_lnorm_params_t p;
61 engine eng;
62 stream strm;
63
64protected:
65 void SetUp() override {
66 SKIP_IF_CUDA(true, "Layer normalization not supported by CUDA.");
67 p = ::testing::TestWithParam<decltype(p)>::GetParam();
68
69 SKIP_IF(unsupported_data_type(p.src_dt)
70 || unsupported_data_type(p.dst_dt),
71 "Engine does not support this data type.");
72 if (p.diff_src_dt != memory::data_type::undef) {
73 SKIP_IF(unsupported_data_type(p.diff_src_dt),
74 "Engine does not support this data type.");
75 }
76
77 catch_expected_failures(
78 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
79 }
80
81 void Test() {
82 eng = get_test_engine();
83 strm = make_stream(eng);
84
85 src_md = std::make_shared<memory::desc>(p.dims, p.src_dt, p.src_tag);
86 dst_md = std::make_shared<memory::desc>(p.dims, p.dst_dt, p.src_tag);
87 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
88 stat_d = std::make_shared<memory::desc>(
89 stat_dims, memory::data_type::f32, p.stat_tag);
90
91 auto training = prop_kind::forward_training;
92 auto inference = prop_kind::forward_inference;
93
94 using flags = normalization_flags;
95 Forward(training);
96 Forward(training, flags::use_global_stats);
97 Forward(training, flags::use_scale);
98 Forward(training, flags::use_shift);
99 Forward(training, flags::use_scale | flags::use_shift);
100 Forward(training,
101 flags::use_scale | flags::use_shift | flags::use_global_stats);
102 Forward(inference);
103 Forward(inference, flags::use_global_stats);
104 Forward(inference, flags::use_scale | flags::use_shift);
105
106 if (!impl::utils::one_of(p.dst_dt, memory::data_type::f16,
107 memory::data_type::s8, memory::data_type::u8)) {
108 diff_src_md = std::make_shared<memory::desc>(
109 p.dims, p.diff_src_dt, p.diff_src_tag);
110
111 Backward(prop_kind::backward_data);
112 Backward(prop_kind::backward_data, flags::use_global_stats);
113 Backward(prop_kind::backward, flags::use_scale);
114 Backward(prop_kind::backward, flags::use_shift);
115 Backward(prop_kind::backward, flags::use_scale | flags::use_shift);
116 Backward(prop_kind::backward,
117 flags::use_scale | flags::use_shift
118 | flags::use_global_stats);
119 }
120 }
121
122 void Forward(prop_kind pk,
123 normalization_flags flags = normalization_flags::none) {
124 fwd_iface_test_stat_any(pk, flags);
125
126 bool useScale = (bool)(flags & normalization_flags::use_scale);
127 bool useShift = (bool)(flags & normalization_flags::use_shift);
128 bool useGlobalStats
129 = (bool)(flags & normalization_flags::use_global_stats);
130 bool isTraining = pk == prop_kind::forward_training;
131
132 lnorm_fwd_pd = layer_normalization_forward::primitive_desc(
133 eng, pk, *src_md, *dst_md, *stat_d, epsilon, flags);
134 lnorm_fwd_pd = layer_normalization_forward::primitive_desc(
135 lnorm_fwd_pd.get()); // test construction from a C pd
136
137 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
138 == lnorm_fwd_pd.src_desc());
139 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
140 == lnorm_fwd_pd.dst_desc());
141 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
142 == lnorm_fwd_pd.mean_desc());
143 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
144 == lnorm_fwd_pd.variance_desc());
145 if (p.src_tag != memory::format_tag::any) {
146 ASSERT_TRUE(*src_md == lnorm_fwd_pd.src_desc());
147 }
148
149 ASSERT_EQ(lnorm_fwd_pd.get_prop_kind(), pk);
150 ASSERT_EQ(lnorm_fwd_pd.get_epsilon(), epsilon);
151 ASSERT_EQ(lnorm_fwd_pd.get_flags(), flags);
152
153 src = std::make_shared<test_memory>(lnorm_fwd_pd.src_desc(), eng);
154 dst = std::make_shared<test_memory>(lnorm_fwd_pd.dst_desc(), eng);
155
156 if (useScale)
157 weights = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
158 if (useShift)
159 bias = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
160 if (isTraining || useGlobalStats) {
161 mean = test::make_memory(*stat_d, eng);
162 variance = test::make_memory(*stat_d, eng);
163 }
164
165 fill<float>(src->get());
166 fill<float>(dst->get());
167 if (useScale) fill<float>(weights);
168 if (useShift) fill<float>(bias);
169 if (useGlobalStats) {
170 fill<float>(mean);
171 fill<float>(variance);
172 }
173
174 execlnormFwd(isTraining, useGlobalStats, useScale, useShift);
175 }
176
177 void Backward(prop_kind pk,
178 normalization_flags flags = normalization_flags::none) {
179 bwd_iface_test_stat_any(pk, flags);
180
181 bool useScale = (bool)(flags & normalization_flags::use_scale);
182 bool useShift = (bool)(flags & normalization_flags::use_shift);
183
184 lnorm_fwd_pd = layer_normalization_forward::primitive_desc(eng,
185 prop_kind::forward_training, *src_md, *dst_md, *stat_d, epsilon,
186 flags);
187
188 lnorm_bwd_pd = layer_normalization_backward::primitive_desc(eng, pk,
189 *diff_src_md, *dst_md, *src_md, *stat_d, epsilon, flags,
190 lnorm_fwd_pd);
191 lnorm_bwd_pd = layer_normalization_backward::primitive_desc(
192 lnorm_bwd_pd.get()); // test construction from a C pd
193
194 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
195 == lnorm_bwd_pd.src_desc());
196 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
197 == lnorm_bwd_pd.diff_src_desc());
198 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
199 == lnorm_bwd_pd.diff_dst_desc());
200 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
201 == lnorm_bwd_pd.mean_desc());
202 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
203 == lnorm_bwd_pd.variance_desc());
204 if (p.diff_src_tag != memory::format_tag::any) {
205 ASSERT_TRUE(*diff_src_md == lnorm_bwd_pd.diff_src_desc());
206 }
207
208 ASSERT_EQ(lnorm_bwd_pd.get_prop_kind(), pk);
209 ASSERT_EQ(lnorm_bwd_pd.get_epsilon(), epsilon);
210 ASSERT_EQ(lnorm_bwd_pd.get_flags(), flags);
211
212 diff_src = std::make_shared<test_memory>(
213 lnorm_bwd_pd.diff_src_desc(), eng);
214 diff_dst = std::make_shared<test_memory>(
215 lnorm_bwd_pd.diff_dst_desc(), eng);
216
217 if (useScale)
218 weights = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
219 if (useShift)
220 bias = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
221 if (useScale)
222 diff_weights
223 = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
224 if (useShift)
225 diff_bias
226 = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
227 mean = test::make_memory(*stat_d, eng);
228 variance = test::make_memory(*stat_d, eng);
229
230 if (useScale) fill<float>(weights);
231 if (useShift) fill<float>(bias);
232 fill<float>(diff_src->get());
233 fill<float>(diff_dst->get());
234 fill<float>(mean);
235 fill<float>(variance);
236
237 execlnormBwd(useScale, useShift, pk);
238 }
239
240 void execlnormFwd(bool isTraining, bool useGlobalStats, bool useScale,
241 bool useShift) {
242 std::unordered_map<int, memory> args = {
243 {DNNL_ARG_SRC, src->get()},
244 {DNNL_ARG_DST, dst->get()},
245 };
246
247 if (useScale) args.insert({DNNL_ARG_SCALE, weights});
248 if (useShift) args.insert({DNNL_ARG_SHIFT, bias});
249
250 if (isTraining || useGlobalStats) {
251 args.insert({DNNL_ARG_MEAN, mean});
252 args.insert({DNNL_ARG_VARIANCE, variance});
253 }
254
255 EXPECT_ANY_THROW(layer_normalization_forward(lnorm_fwd_pd, {}));
256 layer_normalization_forward(lnorm_fwd_pd).execute(strm, args);
257 strm.wait();
258 }
259
260 void execlnormBwd(bool useScale, bool useShift, prop_kind pk) {
261 std::unordered_map<int, memory> args = {
262 {DNNL_ARG_SRC, src->get()},
263 {DNNL_ARG_DIFF_DST, dst->get()},
264 {DNNL_ARG_MEAN, mean},
265 {DNNL_ARG_VARIANCE, variance},
266 {DNNL_ARG_DIFF_SRC, diff_src->get()},
267 };
268
269 if (useScale) {
270 args.insert({DNNL_ARG_SCALE, weights});
271 if (pk == prop_kind::backward)
272 args.insert({DNNL_ARG_DIFF_SCALE, diff_weights});
273 }
274
275 if (useShift) {
276 args.insert({DNNL_ARG_SHIFT, bias});
277 if (pk == prop_kind::backward)
278 args.insert({DNNL_ARG_DIFF_SHIFT, diff_bias});
279 }
280
281 EXPECT_ANY_THROW(layer_normalization_backward(lnorm_bwd_pd, {}));
282 layer_normalization_backward(lnorm_bwd_pd).execute(strm, args);
283 strm.wait();
284 }
285
286 void fwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
287 // non stats if inference w/o use global stats
288 if (pk == prop_kind::forward_inference
289 && !(bool)(flags & normalization_flags::use_global_stats))
290 return;
291
292 using tag = memory::format_tag;
293
294 tag expect_stat_tag = derive_stat_tag();
295 if (expect_stat_tag == tag::undef) return; // optimism
296
297 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
298 memory::desc expect_stat_md(
299 stat_dims, memory::data_type::f32, expect_stat_tag);
300
301 // no stat_md provided at all
302 {
303 layer_normalization_forward::primitive_desc fwd_pd(
304 eng, pk, *src_md, *dst_md, epsilon, flags);
305
306 EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
307 EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
308 }
309
310 // stat_md with format_tag::any
311 {
312 memory::desc any_stat_md(
313 stat_dims, memory::data_type::f32, tag::any);
314 layer_normalization_forward::primitive_desc fwd_pd(
315 eng, pk, *src_md, *dst_md, any_stat_md, epsilon, flags);
316
317 EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
318 EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
319 }
320 }
321
322 void bwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
323 using tag = memory::format_tag;
324
325 tag expect_stat_tag = derive_stat_tag();
326 if (expect_stat_tag == tag::undef) return; // optimism
327
328 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
329 memory::desc expect_stat_md(
330 stat_dims, memory::data_type::f32, expect_stat_tag);
331
332 layer_normalization_forward::primitive_desc fwd_pd(eng,
333 prop_kind::forward_training, *src_md, *dst_md, epsilon, flags);
334
335 // stat_md with format_tag::any
336 {
337 memory::desc any_stat_md(
338 stat_dims, memory::data_type::f32, tag::any);
339 layer_normalization_backward::primitive_desc bwd_pd(eng, pk,
340 *diff_src_md, *dst_md, *src_md, any_stat_md, epsilon, flags,
341 fwd_pd);
342
343 EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
344 EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
345 }
346 }
347
348private:
349 memory::format_tag derive_stat_tag() const {
350 using tag = memory::format_tag;
351 tag expect_stat_tag = tag::undef;
352
353 // TODO: add more cases and test cases
354 // XXX: currently test only simple cases like `abc`, `acb`. Extend,
355 // if possible, to blocked formats too.
356 switch (p.src_tag) {
357 case tag::abc: expect_stat_tag = tag::ab; break;
358 case tag::bac: expect_stat_tag = tag::ba; break;
359 default: break;
360 }
361
362 return expect_stat_tag;
363 }
364};
365
// Expands to the three format tags of a test case: src, stat, diff_src.
#define EXPAND_FORMATS(src, stat, diff_src) \
    memory::format_tag::src, memory::format_tag::stat, \
            memory::format_tag::diff_src

// Expands to the three data types of a test case: src, dst, diff_src.
#define EXPAND_DTS(src, dst, diff_src) \
    memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src

// Common tag triples named after logical tensor layouts (RNN-style).
#define TAGS_NC EXPAND_FORMATS(ab, a, ab)
#define TAGS_TNC EXPAND_FORMATS(abc, ab, abc)
#define TAGS_cTNC EXPAND_FORMATS(abc, ba, abc)
#define TAGS_NTC EXPAND_FORMATS(bac, ba, bac)
#define TAGS_LDSNC EXPAND_FORMATS(abcde, abcd, abcde)
#define TAGS_cLDSNC EXPAND_FORMATS(abcde, acdb, abcde)

// Positive test case: expected to construct and run successfully.
#define LNORM_TEST_CASE(...) \
    test_lnorm_params_t { __VA_ARGS__, false, dnnl_success }
382
// Negative cases: each is expected to fail pd creation with
// dnnl_invalid_arguments (checked via catch_expected_failures in SetUp).
static auto expected_failure_cases = []() {
    // clang-format off
    return ::testing::Values(
        // Negative dimension
        test_lnorm_params_t {TAGS_NC, EXPAND_DTS(f32, f32, f32), {-1, 10}, true, dnnl_invalid_arguments},
        // Undef data type
        test_lnorm_params_t {TAGS_NC, EXPAND_DTS(undef, f32, f32), {1, 10}, true, dnnl_invalid_arguments},
        // Only `any` tags
        test_lnorm_params_t {EXPAND_FORMATS(any, any, any), EXPAND_DTS(f32, f32, f32), {1, 10}, true, dnnl_invalid_arguments}
    );
    // clang-format on
};
395
// Cases with a zero dimension: primitives must still create and execute
// (as no-ops over the empty tensor) without error.
static auto zero_dim_cases = [](memory::data_type src_dt,
                                     memory::data_type dst_dt,
                                     memory::data_type diff_src_dt) {
    // clang-format off
    return ::testing::Values(
        LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {0, 100}),
        LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {6, 0, 8}),
        LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {6, 32, 0}),
        LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 0})
    );
    // clang-format on
};
408
// Regular positive cases covering 2D/3D/5D shapes over all the tag
// triples defined above, parameterized by the three data types.
static auto simple_cases = [](memory::data_type src_dt,
                                   memory::data_type dst_dt,
                                   memory::data_type diff_src_dt) {
    // clang-format off
    return ::testing::Values(
        LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {1, 100}),
        LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {20, 8}),
        LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {2, 10}),
        LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {6, 32, 8}),
        LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {2, 8, 16}),
        LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {2, 10, 4}),
        LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {6, 32, 8}),
        LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {2, 8, 16}),
        LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {2, 10, 4}),
        LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {64, 32, 8}),
        LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {12, 8, 16}),
        LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {32, 10, 4}),
        LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 8}),
        LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 8, 16}),
        LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 10, 4}),
        LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 8}),
        LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 8, 16}),
        LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 10, 4})
    );
    // clang-format on
};
435
// Body is empty: all the work happens in the fixture's SetUp()/Test().
TEST_P(lnorm_test_t, TestsLnormV2) {}

// Instantiation helpers: plain, CPU-only, and GPU-only suites built from
// simple_cases with the given (src, dst, diff_src) data types.
#define INST_TEST_CASE(name, ...) \
    INSTANTIATE_TEST_SUITE_P(name, lnorm_test_t, simple_cases(__VA_ARGS__));

#define CPU_INST_TEST_CASE(name, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(name, lnorm_test_t, simple_cases(__VA_ARGS__));

#define GPU_INST_TEST_CASE(name, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(name, lnorm_test_t, simple_cases(__VA_ARGS__));

INSTANTIATE_TEST_SUITE_P(LnormEF, lnorm_test_t, expected_failure_cases());
INSTANTIATE_TEST_SUITE_P(
        LnormZeroDim, lnorm_test_t, zero_dim_cases(EXPAND_DTS(f32, f32, f32)));

INST_TEST_CASE(LnormSimpleF32, EXPAND_DTS(f32, f32, f32))
INST_TEST_CASE(LnormSimpleBF16, EXPAND_DTS(bf16, bf16, bf16))
// diff_src_dt == undef for the low-precision dst cases: backward is
// skipped for them in Test().
INST_TEST_CASE(LnormSimpleF16, EXPAND_DTS(f16, f16, undef))
CPU_INST_TEST_CASE(LnormSimpleF32BF16, EXPAND_DTS(f32, bf16, f32))
CPU_INST_TEST_CASE(LnormSimpleBF16F32, EXPAND_DTS(bf16, f32, bf16))
CPU_INST_TEST_CASE(LnormSimpleF32S8, EXPAND_DTS(f32, s8, undef))
CPU_INST_TEST_CASE(LnormSimpleBF16U8, EXPAND_DTS(bf16, u8, undef))
458
459} // namespace dnnl
460