1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cmath> |
18 | #include <memory> |
19 | |
20 | #include "dnnl_test_common.hpp" |
21 | #include "gtest/gtest.h" |
22 | |
23 | #include "oneapi/dnnl/dnnl.hpp" |
24 | |
25 | namespace dnnl { |
26 | |
// Epsilon handed to every layer normalization primitive descriptor created
// by these tests and compared back against get_epsilon().
static constexpr float epsilon {1e-5f};
28 | |
29 | struct test_lnorm_params_t { |
30 | memory::format_tag src_tag; |
31 | memory::format_tag stat_tag; |
32 | memory::format_tag diff_src_tag; |
33 | memory::data_type src_dt; |
34 | memory::data_type dst_dt; |
35 | memory::data_type diff_src_dt; |
36 | memory::dims dims; |
37 | bool expect_to_fail; |
38 | dnnl_status_t expected_status; |
39 | }; |
40 | |
41 | template <typename T> |
42 | void fill(const memory &m) { |
43 | auto numElements = m.get_desc().get_size() / sizeof(T); |
44 | fill_data<T>(numElements, m); |
45 | } |
46 | |
47 | class lnorm_test_t : public ::testing::TestWithParam<test_lnorm_params_t> { |
48 | private: |
49 | std::shared_ptr<test_memory> src, dst, diff_src, diff_dst; |
50 | memory weights, bias, diff_weights, diff_bias, mean, variance; |
51 | |
52 | std::shared_ptr<memory::desc> src_md; |
53 | std::shared_ptr<memory::desc> dst_md; |
54 | std::shared_ptr<memory::desc> stat_d; |
55 | std::shared_ptr<memory::desc> diff_src_md; |
56 | |
57 | layer_normalization_forward::primitive_desc lnorm_fwd_pd; |
58 | layer_normalization_backward::primitive_desc lnorm_bwd_pd; |
59 | |
60 | test_lnorm_params_t p; |
61 | engine eng; |
62 | stream strm; |
63 | |
64 | protected: |
65 | void SetUp() override { |
66 | SKIP_IF_CUDA(true, "Layer normalization not supported by CUDA." ); |
67 | p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
68 | |
69 | SKIP_IF(unsupported_data_type(p.src_dt) |
70 | || unsupported_data_type(p.dst_dt), |
71 | "Engine does not support this data type." ); |
72 | if (p.diff_src_dt != memory::data_type::undef) { |
73 | SKIP_IF(unsupported_data_type(p.diff_src_dt), |
74 | "Engine does not support this data type." ); |
75 | } |
76 | |
77 | catch_expected_failures( |
78 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
79 | } |
80 | |
81 | void Test() { |
82 | eng = get_test_engine(); |
83 | strm = make_stream(eng); |
84 | |
85 | src_md = std::make_shared<memory::desc>(p.dims, p.src_dt, p.src_tag); |
86 | dst_md = std::make_shared<memory::desc>(p.dims, p.dst_dt, p.src_tag); |
87 | memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1); |
88 | stat_d = std::make_shared<memory::desc>( |
89 | stat_dims, memory::data_type::f32, p.stat_tag); |
90 | |
91 | auto training = prop_kind::forward_training; |
92 | auto inference = prop_kind::forward_inference; |
93 | |
94 | using flags = normalization_flags; |
95 | Forward(training); |
96 | Forward(training, flags::use_global_stats); |
97 | Forward(training, flags::use_scale); |
98 | Forward(training, flags::use_shift); |
99 | Forward(training, flags::use_scale | flags::use_shift); |
100 | Forward(training, |
101 | flags::use_scale | flags::use_shift | flags::use_global_stats); |
102 | Forward(inference); |
103 | Forward(inference, flags::use_global_stats); |
104 | Forward(inference, flags::use_scale | flags::use_shift); |
105 | |
106 | if (!impl::utils::one_of(p.dst_dt, memory::data_type::f16, |
107 | memory::data_type::s8, memory::data_type::u8)) { |
108 | diff_src_md = std::make_shared<memory::desc>( |
109 | p.dims, p.diff_src_dt, p.diff_src_tag); |
110 | |
111 | Backward(prop_kind::backward_data); |
112 | Backward(prop_kind::backward_data, flags::use_global_stats); |
113 | Backward(prop_kind::backward, flags::use_scale); |
114 | Backward(prop_kind::backward, flags::use_shift); |
115 | Backward(prop_kind::backward, flags::use_scale | flags::use_shift); |
116 | Backward(prop_kind::backward, |
117 | flags::use_scale | flags::use_shift |
118 | | flags::use_global_stats); |
119 | } |
120 | } |
121 | |
122 | void Forward(prop_kind pk, |
123 | normalization_flags flags = normalization_flags::none) { |
124 | fwd_iface_test_stat_any(pk, flags); |
125 | |
126 | bool useScale = (bool)(flags & normalization_flags::use_scale); |
127 | bool useShift = (bool)(flags & normalization_flags::use_shift); |
128 | bool useGlobalStats |
129 | = (bool)(flags & normalization_flags::use_global_stats); |
130 | bool isTraining = pk == prop_kind::forward_training; |
131 | |
132 | lnorm_fwd_pd = layer_normalization_forward::primitive_desc( |
133 | eng, pk, *src_md, *dst_md, *stat_d, epsilon, flags); |
134 | lnorm_fwd_pd = layer_normalization_forward::primitive_desc( |
135 | lnorm_fwd_pd.get()); // test construction from a C pd |
136 | |
137 | ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) |
138 | == lnorm_fwd_pd.src_desc()); |
139 | ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DST) |
140 | == lnorm_fwd_pd.dst_desc()); |
141 | ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) |
142 | == lnorm_fwd_pd.mean_desc()); |
143 | ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE) |
144 | == lnorm_fwd_pd.variance_desc()); |
145 | if (p.src_tag != memory::format_tag::any) { |
146 | ASSERT_TRUE(*src_md == lnorm_fwd_pd.src_desc()); |
147 | } |
148 | |
149 | ASSERT_EQ(lnorm_fwd_pd.get_prop_kind(), pk); |
150 | ASSERT_EQ(lnorm_fwd_pd.get_epsilon(), epsilon); |
151 | ASSERT_EQ(lnorm_fwd_pd.get_flags(), flags); |
152 | |
153 | src = std::make_shared<test_memory>(lnorm_fwd_pd.src_desc(), eng); |
154 | dst = std::make_shared<test_memory>(lnorm_fwd_pd.dst_desc(), eng); |
155 | |
156 | if (useScale) |
157 | weights = test::make_memory(lnorm_fwd_pd.weights_desc(), eng); |
158 | if (useShift) |
159 | bias = test::make_memory(lnorm_fwd_pd.weights_desc(), eng); |
160 | if (isTraining || useGlobalStats) { |
161 | mean = test::make_memory(*stat_d, eng); |
162 | variance = test::make_memory(*stat_d, eng); |
163 | } |
164 | |
165 | fill<float>(src->get()); |
166 | fill<float>(dst->get()); |
167 | if (useScale) fill<float>(weights); |
168 | if (useShift) fill<float>(bias); |
169 | if (useGlobalStats) { |
170 | fill<float>(mean); |
171 | fill<float>(variance); |
172 | } |
173 | |
174 | execlnormFwd(isTraining, useGlobalStats, useScale, useShift); |
175 | } |
176 | |
177 | void Backward(prop_kind pk, |
178 | normalization_flags flags = normalization_flags::none) { |
179 | bwd_iface_test_stat_any(pk, flags); |
180 | |
181 | bool useScale = (bool)(flags & normalization_flags::use_scale); |
182 | bool useShift = (bool)(flags & normalization_flags::use_shift); |
183 | |
184 | lnorm_fwd_pd = layer_normalization_forward::primitive_desc(eng, |
185 | prop_kind::forward_training, *src_md, *dst_md, *stat_d, epsilon, |
186 | flags); |
187 | |
188 | lnorm_bwd_pd = layer_normalization_backward::primitive_desc(eng, pk, |
189 | *diff_src_md, *dst_md, *src_md, *stat_d, epsilon, flags, |
190 | lnorm_fwd_pd); |
191 | lnorm_bwd_pd = layer_normalization_backward::primitive_desc( |
192 | lnorm_bwd_pd.get()); // test construction from a C pd |
193 | |
194 | ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) |
195 | == lnorm_bwd_pd.src_desc()); |
196 | ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
197 | == lnorm_bwd_pd.diff_src_desc()); |
198 | ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
199 | == lnorm_bwd_pd.diff_dst_desc()); |
200 | ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN) |
201 | == lnorm_bwd_pd.mean_desc()); |
202 | ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE) |
203 | == lnorm_bwd_pd.variance_desc()); |
204 | if (p.diff_src_tag != memory::format_tag::any) { |
205 | ASSERT_TRUE(*diff_src_md == lnorm_bwd_pd.diff_src_desc()); |
206 | } |
207 | |
208 | ASSERT_EQ(lnorm_bwd_pd.get_prop_kind(), pk); |
209 | ASSERT_EQ(lnorm_bwd_pd.get_epsilon(), epsilon); |
210 | ASSERT_EQ(lnorm_bwd_pd.get_flags(), flags); |
211 | |
212 | diff_src = std::make_shared<test_memory>( |
213 | lnorm_bwd_pd.diff_src_desc(), eng); |
214 | diff_dst = std::make_shared<test_memory>( |
215 | lnorm_bwd_pd.diff_dst_desc(), eng); |
216 | |
217 | if (useScale) |
218 | weights = test::make_memory(lnorm_bwd_pd.weights_desc(), eng); |
219 | if (useShift) |
220 | bias = test::make_memory(lnorm_bwd_pd.weights_desc(), eng); |
221 | if (useScale) |
222 | diff_weights |
223 | = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng); |
224 | if (useShift) |
225 | diff_bias |
226 | = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng); |
227 | mean = test::make_memory(*stat_d, eng); |
228 | variance = test::make_memory(*stat_d, eng); |
229 | |
230 | if (useScale) fill<float>(weights); |
231 | if (useShift) fill<float>(bias); |
232 | fill<float>(diff_src->get()); |
233 | fill<float>(diff_dst->get()); |
234 | fill<float>(mean); |
235 | fill<float>(variance); |
236 | |
237 | execlnormBwd(useScale, useShift, pk); |
238 | } |
239 | |
240 | void execlnormFwd(bool isTraining, bool useGlobalStats, bool useScale, |
241 | bool useShift) { |
242 | std::unordered_map<int, memory> args = { |
243 | {DNNL_ARG_SRC, src->get()}, |
244 | {DNNL_ARG_DST, dst->get()}, |
245 | }; |
246 | |
247 | if (useScale) args.insert({DNNL_ARG_SCALE, weights}); |
248 | if (useShift) args.insert({DNNL_ARG_SHIFT, bias}); |
249 | |
250 | if (isTraining || useGlobalStats) { |
251 | args.insert({DNNL_ARG_MEAN, mean}); |
252 | args.insert({DNNL_ARG_VARIANCE, variance}); |
253 | } |
254 | |
255 | EXPECT_ANY_THROW(layer_normalization_forward(lnorm_fwd_pd, {})); |
256 | layer_normalization_forward(lnorm_fwd_pd).execute(strm, args); |
257 | strm.wait(); |
258 | } |
259 | |
260 | void execlnormBwd(bool useScale, bool useShift, prop_kind pk) { |
261 | std::unordered_map<int, memory> args = { |
262 | {DNNL_ARG_SRC, src->get()}, |
263 | {DNNL_ARG_DIFF_DST, dst->get()}, |
264 | {DNNL_ARG_MEAN, mean}, |
265 | {DNNL_ARG_VARIANCE, variance}, |
266 | {DNNL_ARG_DIFF_SRC, diff_src->get()}, |
267 | }; |
268 | |
269 | if (useScale) { |
270 | args.insert({DNNL_ARG_SCALE, weights}); |
271 | if (pk == prop_kind::backward) |
272 | args.insert({DNNL_ARG_DIFF_SCALE, diff_weights}); |
273 | } |
274 | |
275 | if (useShift) { |
276 | args.insert({DNNL_ARG_SHIFT, bias}); |
277 | if (pk == prop_kind::backward) |
278 | args.insert({DNNL_ARG_DIFF_SHIFT, diff_bias}); |
279 | } |
280 | |
281 | EXPECT_ANY_THROW(layer_normalization_backward(lnorm_bwd_pd, {})); |
282 | layer_normalization_backward(lnorm_bwd_pd).execute(strm, args); |
283 | strm.wait(); |
284 | } |
285 | |
286 | void fwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) { |
287 | // non stats if inference w/o use global stats |
288 | if (pk == prop_kind::forward_inference |
289 | && !(bool)(flags & normalization_flags::use_global_stats)) |
290 | return; |
291 | |
292 | using tag = memory::format_tag; |
293 | |
294 | tag expect_stat_tag = derive_stat_tag(); |
295 | if (expect_stat_tag == tag::undef) return; // optimism |
296 | |
297 | memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1); |
298 | memory::desc expect_stat_md( |
299 | stat_dims, memory::data_type::f32, expect_stat_tag); |
300 | |
301 | // no stat_md provided at all |
302 | { |
303 | layer_normalization_forward::primitive_desc fwd_pd( |
304 | eng, pk, *src_md, *dst_md, epsilon, flags); |
305 | |
306 | EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md); |
307 | EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md); |
308 | } |
309 | |
310 | // stat_md with format_tag::any |
311 | { |
312 | memory::desc any_stat_md( |
313 | stat_dims, memory::data_type::f32, tag::any); |
314 | layer_normalization_forward::primitive_desc fwd_pd( |
315 | eng, pk, *src_md, *dst_md, any_stat_md, epsilon, flags); |
316 | |
317 | EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md); |
318 | EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md); |
319 | } |
320 | } |
321 | |
322 | void bwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) { |
323 | using tag = memory::format_tag; |
324 | |
325 | tag expect_stat_tag = derive_stat_tag(); |
326 | if (expect_stat_tag == tag::undef) return; // optimism |
327 | |
328 | memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1); |
329 | memory::desc expect_stat_md( |
330 | stat_dims, memory::data_type::f32, expect_stat_tag); |
331 | |
332 | layer_normalization_forward::primitive_desc fwd_pd(eng, |
333 | prop_kind::forward_training, *src_md, *dst_md, epsilon, flags); |
334 | |
335 | // stat_md with format_tag::any |
336 | { |
337 | memory::desc any_stat_md( |
338 | stat_dims, memory::data_type::f32, tag::any); |
339 | layer_normalization_backward::primitive_desc bwd_pd(eng, pk, |
340 | *diff_src_md, *dst_md, *src_md, any_stat_md, epsilon, flags, |
341 | fwd_pd); |
342 | |
343 | EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md); |
344 | EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md); |
345 | } |
346 | } |
347 | |
348 | private: |
349 | memory::format_tag derive_stat_tag() const { |
350 | using tag = memory::format_tag; |
351 | tag expect_stat_tag = tag::undef; |
352 | |
353 | // TODO: add more cases and test cases |
354 | // XXX: currently test only simple cases like `abc`, `acb`. Extend, |
355 | // if possible, to blocked formats too. |
356 | switch (p.src_tag) { |
357 | case tag::abc: expect_stat_tag = tag::ab; break; |
358 | case tag::bac: expect_stat_tag = tag::ba; break; |
359 | default: break; |
360 | } |
361 | |
362 | return expect_stat_tag; |
363 | } |
364 | }; |
365 | |
// Expands to the three memory::format_tag fields of test_lnorm_params_t, in
// declaration order: src_tag, stat_tag, diff_src_tag.
#define EXPAND_FORMATS(src_fmt, stat_fmt, diff_src_fmt) \
    memory::format_tag::src_fmt, memory::format_tag::stat_fmt, \
            memory::format_tag::diff_src_fmt

// Expands to the three memory::data_type fields of test_lnorm_params_t, in
// declaration order: src_dt, dst_dt, diff_src_dt.
#define EXPAND_DTS(src_type, dst_type, diff_src_type) \
    memory::data_type::src_type, memory::data_type::dst_type, \
            memory::data_type::diff_src_type

// Layout presets, each listing (src/dst tag, stat tag, diff_src tag).
// The `c`-prefixed presets keep the plain src layout of their counterpart
// but pick a different stat layout.
#define TAGS_NC EXPAND_FORMATS(ab, a, ab)
#define TAGS_TNC EXPAND_FORMATS(abc, ab, abc)
#define TAGS_cTNC EXPAND_FORMATS(abc, ba, abc)
#define TAGS_NTC EXPAND_FORMATS(bac, ba, bac)
#define TAGS_LDSNC EXPAND_FORMATS(abcde, abcd, abcde)
#define TAGS_cLDSNC EXPAND_FORMATS(abcde, acdb, abcde)

// Builds a test case that is expected to succeed. Arguments are the leading
// fields of test_lnorm_params_t, from src_tag through dims.
#define LNORM_TEST_CASE(...) \
    test_lnorm_params_t { __VA_ARGS__, false, dnnl_success }
382 | |
383 | static auto expected_failure_cases = []() { |
384 | // clang-format off |
385 | return ::testing::Values( |
386 | // Negative dimension |
387 | test_lnorm_params_t {TAGS_NC, EXPAND_DTS(f32, f32, f32), {-1, 10}, true, dnnl_invalid_arguments}, |
388 | // Undef data type |
389 | test_lnorm_params_t {TAGS_NC, EXPAND_DTS(undef, f32, f32), {1, 10}, true, dnnl_invalid_arguments}, |
390 | // Only `any` tags |
391 | test_lnorm_params_t {EXPAND_FORMATS(any, any, any), EXPAND_DTS(f32, f32, f32), {1, 10}, true, dnnl_invalid_arguments} |
392 | ); |
393 | // clang-format on |
394 | }; |
395 | |
396 | static auto zero_dim_cases = [](memory::data_type src_dt, |
397 | memory::data_type dst_dt, |
398 | memory::data_type diff_src_dt) { |
399 | // clang-format off |
400 | return ::testing::Values( |
401 | LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {0, 100}), |
402 | LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {6, 0, 8}), |
403 | LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {6, 32, 0}), |
404 | LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 0}) |
405 | ); |
406 | // clang-format on |
407 | }; |
408 | |
409 | static auto simple_cases = [](memory::data_type src_dt, |
410 | memory::data_type dst_dt, |
411 | memory::data_type diff_src_dt) { |
412 | // clang-format off |
413 | return ::testing::Values( |
414 | LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {1, 100}), |
415 | LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {20, 8}), |
416 | LNORM_TEST_CASE(TAGS_NC, src_dt, dst_dt, diff_src_dt, {2, 10}), |
417 | LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {6, 32, 8}), |
418 | LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {2, 8, 16}), |
419 | LNORM_TEST_CASE(TAGS_TNC, src_dt, dst_dt, diff_src_dt, {2, 10, 4}), |
420 | LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {6, 32, 8}), |
421 | LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {2, 8, 16}), |
422 | LNORM_TEST_CASE(TAGS_cTNC, src_dt, dst_dt, diff_src_dt, {2, 10, 4}), |
423 | LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {64, 32, 8}), |
424 | LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {12, 8, 16}), |
425 | LNORM_TEST_CASE(TAGS_NTC, src_dt, dst_dt, diff_src_dt, {32, 10, 4}), |
426 | LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 8}), |
427 | LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 8, 16}), |
428 | LNORM_TEST_CASE(TAGS_LDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 10, 4}), |
429 | LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {6, 2, 2, 32, 8}), |
430 | LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 8, 16}), |
431 | LNORM_TEST_CASE(TAGS_cLDSNC, src_dt, dst_dt, diff_src_dt, {2, 2, 2, 10, 4}) |
432 | ); |
433 | // clang-format on |
434 | }; |
435 | |
436 | TEST_P(lnorm_test_t, TestsLnormV2) {} |
437 | |
// Instantiate lnorm_test_t over simple_cases() for the given data types.
// The trailing semicolon is part of each macro, so call sites omit it.
// Plain instantiation:
#define INST_TEST_CASE(prefix, ...) \
    INSTANTIATE_TEST_SUITE_P(prefix, lnorm_test_t, simple_cases(__VA_ARGS__));

// Via the harness's CPU-specific instantiation macro:
#define CPU_INST_TEST_CASE(prefix, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            prefix, lnorm_test_t, simple_cases(__VA_ARGS__));

// Via the harness's GPU-specific instantiation macro:
#define GPU_INST_TEST_CASE(prefix, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            prefix, lnorm_test_t, simple_cases(__VA_ARGS__));
446 | |
447 | INSTANTIATE_TEST_SUITE_P(LnormEF, lnorm_test_t, expected_failure_cases()); |
448 | INSTANTIATE_TEST_SUITE_P( |
449 | LnormZeroDim, lnorm_test_t, zero_dim_cases(EXPAND_DTS(f32, f32, f32))); |
450 | |
451 | INST_TEST_CASE(LnormSimpleF32, EXPAND_DTS(f32, f32, f32)) |
452 | INST_TEST_CASE(LnormSimpleBF16, EXPAND_DTS(bf16, bf16, bf16)) |
453 | INST_TEST_CASE(LnormSimpleF16, EXPAND_DTS(f16, f16, undef)) |
454 | CPU_INST_TEST_CASE(LnormSimpleF32BF16, EXPAND_DTS(f32, bf16, f32)) |
455 | CPU_INST_TEST_CASE(LnormSimpleBF16F32, EXPAND_DTS(bf16, f32, bf16)) |
456 | CPU_INST_TEST_CASE(LnormSimpleF32S8, EXPAND_DTS(f32, s8, undef)) |
457 | CPU_INST_TEST_CASE(LnormSimpleBF16U8, EXPAND_DTS(bf16, u8, undef)) |
458 | |
459 | } // namespace dnnl |
460 | |