1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
27struct softmax_test_params_t {
28 prop_kind aprop_kind;
29 algorithm aalgorithm;
30 dt src_dt; // diff_src_dt
31 dt dst_dt;
32 dt diff_dst_dt;
33 tag src_tag; // diff_src_tag
34 tag dst_tag;
35 tag diff_dst_tag;
36 memory::dims dims;
37 int axis;
38 bool expect_to_fail;
39 dnnl_status_t expected_status;
40};
41
42class softmax_test_t : public ::testing::TestWithParam<softmax_test_params_t> {
43private:
44 softmax_test_params_t p;
45 memory dst, workspace;
46 std::shared_ptr<softmax_forward::primitive_desc> pd_fwd_hint;
47
48protected:
49 void SetUp() override {
50 p = ::testing::TestWithParam<softmax_test_params_t>::GetParam();
51
52 SKIP_IF_CUDA(
53 !cuda_check_format_tag(p.src_tag), "Unsupported format tag");
54 SKIP_IF_CUDA(
55 !cuda_check_format_tag(p.dst_tag), "Unsupported format tag");
56
57 SKIP_IF_HIP(!hip_check_format_tag(p.src_tag), "Unsupported format tag");
58 SKIP_IF_HIP(!hip_check_format_tag(p.dst_tag), "Unsupported format tag");
59
60 if (!is_fwd(p.aprop_kind)) {
61 SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_dst_tag),
62 "Unsupported format tag");
63
64 SKIP_IF_HIP(!hip_check_format_tag(p.diff_dst_tag),
65 "Unsupported format tag");
66 }
67 SKIP_IF_CUDA((p.src_dt == dt::bf16 || p.dst_dt == dt::bf16),
68 "Unsupported datatype for CUDA");
69 SKIP_IF_HIP((p.src_dt == dt::bf16 || p.dst_dt == dt::bf16),
70 "Unsupported datatype for HIP");
71 if (!is_fwd(p.aprop_kind)) {
72 SKIP_IF_CUDA((p.diff_dst_dt == dt::bf16),
73 "Unsupported datatype for CUDA");
74 SKIP_IF_HIP((p.diff_dst_dt == dt::bf16),
75 "Unsupported datatype for HIP");
76 }
77
78 SKIP_IF(unsupported_data_type(p.src_dt)
79 || unsupported_data_type(p.dst_dt),
80 "Engine does not support this data type.");
81 if (!is_fwd(p.aprop_kind)) {
82 SKIP_IF(unsupported_data_type(p.diff_dst_dt),
83 "Engine does not support this data type.");
84 }
85
86 SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef
87 && p.dst_dt != dt::undef,
88 "Unsupported different data types for source and "
89 "destination");
90 SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef
91 && p.dst_dt != dt::undef,
92 "Unsupported different data types for source and "
93 "destination");
94 SKIP_IF_CUDA(!is_fwd(p.aprop_kind) && p.src_dt != p.diff_dst_dt
95 && p.src_dt != dt::undef && p.diff_dst_dt != dt::undef,
96 "Unsupported different data types for diff_source and "
97 "diff_destination");
98 SKIP_IF_HIP(!is_fwd(p.aprop_kind) && p.src_dt != p.diff_dst_dt
99 && p.src_dt != dt::undef && p.diff_dst_dt != dt::undef,
100 "Unsupported different data types for diff_source and "
101 "diff_destination");
102
103 SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any
104 && p.dst_tag != tag::any,
105 "Unsupported different memory formats for source and "
106 "destination");
107 SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any
108 && p.dst_tag != tag::any,
109 "Unsupported different memory formats for source and "
110 "destination");
111 SKIP_IF_CUDA(!is_fwd(p.aprop_kind) && p.src_tag != p.diff_dst_tag
112 && p.src_tag != tag::any && p.diff_dst_tag != tag::any,
113 "Unsupported different memory formats for diff_source and "
114 "diff_destination");
115 SKIP_IF_HIP(!is_fwd(p.aprop_kind) && p.src_tag != p.diff_dst_tag
116 && p.src_tag != tag::any && p.diff_dst_tag != tag::any,
117 "Unsupported different memory formats for diff_source and "
118 "diff_destination");
119 SKIP_IF_CUDA(p.dst_dt == dt::u8 || p.dst_dt == dt::s8,
120 "Unsupported int8 destination data type");
121 SKIP_IF_HIP(p.dst_dt == dt::u8 || p.dst_dt == dt::s8,
122 "Unsupported int8 destination data type");
123
124 SKIP_IF_HIP(p.axis != 1, "Unsupported axis for HIP");
125
126 catch_expected_failures(
127 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
128 }
129 bool cuda_check_format_tag(memory::format_tag tag) {
130 return (tag != memory::format_tag::aBcd8b
131 && tag != memory::format_tag::aBcd16b);
132 }
133 bool hip_check_format_tag(memory::format_tag tag) {
134 return (tag == memory::format_tag::a || tag == memory::format_tag::ab
135 || tag == memory::format_tag::abc
136 || tag == memory::format_tag::abcd
137 || tag == memory::format_tag::abcde);
138 }
139 void Forward() {
140 // softmax specific types and values
141 using pd_t = softmax_forward::primitive_desc;
142
143 auto eng = get_test_engine();
144 auto strm = make_stream(eng);
145 prop_kind pk = !is_fwd(p.aprop_kind) ? prop_kind::forward_training
146 : p.aprop_kind;
147
148 allows_attr_t aa {false};
149 if (!(is_nvidia_gpu(eng) || is_amd_gpu(eng))) { aa.scales = true; }
150
151 // To validate backward on valid tag::any settings reuse dst tag.
152 const bool src_bwd_any = !is_fwd(p.aprop_kind) && p.src_tag == tag::any;
153 auto src_tag = src_bwd_any ? p.dst_tag : p.src_tag;
154
155 auto src_md = memory::desc(p.dims, p.src_dt, src_tag);
156 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
157
158 // default pd ctor
159 auto pd = pd_t();
160 // regular pd ctor
161 pd = pd_t(eng, pk, p.aalgorithm, src_md, dst_md, p.axis);
162 // test all pd ctors
163 test_fwd_pd_constructors<pd_t>(
164 pd, aa, pk, p.aalgorithm, src_md, dst_md, p.axis);
165 pd_fwd_hint = std::make_shared<pd_t>(pd);
166
167 EXPECT_ANY_THROW(softmax_forward(pd, {}));
168 // default primitive ctor
169 auto softmax = softmax_forward();
170 // regular primitive ctor
171 softmax = softmax_forward(pd);
172
173 // check primitive kind is softmax
174 ASSERT_TRUE(softmax.get_kind() == primitive::kind::softmax);
175 // query for descs from pd
176 const auto src_desc = pd.src_desc();
177 const auto dst_desc = pd.dst_desc();
178 // query for src_desc via exec arg
179 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
180 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
181 // query for dst_desc via exec arg
182 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
183 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
184
185 // query primitive parameters
186 ASSERT_EQ(pd.get_prop_kind(), pk);
187 ASSERT_EQ(pd.get_axis(), p.axis);
188 ASSERT_EQ(pd.get_algorithm(), p.aalgorithm);
189
190 // query for workspace
191 const auto workspace_desc = pd.workspace_desc();
192
193 // check primitive returns zero_md for all rest md
194 ASSERT_TRUE(pd.weights_desc().is_zero());
195 ASSERT_TRUE(pd.diff_src_desc().is_zero());
196 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
197 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
198
199 auto src = test::make_memory(src_desc, eng);
200 dst = test::make_memory(dst_desc, eng);
201 workspace = test::make_memory(workspace_desc, eng);
202
203 fill_data(p.src_dt, src, 1, 1);
204 // test out-place mode
205 softmax.execute(strm,
206 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
207 {DNNL_ARG_WORKSPACE, workspace}});
208 strm.wait();
209
210 // test in-place mode on forward
211 if (p.aprop_kind != prop_kind::backward_data && p.src_tag == p.dst_tag
212 && p.src_dt == p.dst_dt) {
213 softmax.execute(strm,
214 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src},
215 {DNNL_ARG_WORKSPACE, workspace}});
216 strm.wait();
217 }
218 }
219
220 void Backward() {
221 // softmax specific types and values
222 using pd_t = softmax_backward::primitive_desc;
223 using hint_pd_t = softmax_forward::primitive_desc;
224 allows_attr_t aa {false}; // doesn't support anything
225
226 auto eng = get_test_engine();
227 auto strm = make_stream(eng);
228 auto diff_src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
229 auto diff_dst_md = memory::desc(p.dims, p.diff_dst_dt, p.diff_dst_tag);
230 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
231
232 // default pd ctor
233 auto pd = pd_t();
234 // regular pd ctor
235 pd = pd_t(eng, p.aalgorithm, diff_src_md, diff_dst_md, dst_md, p.axis,
236 *pd_fwd_hint);
237 // test all pd ctors
238 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa,
239 p.aalgorithm, diff_src_md, diff_dst_md, dst_md, p.axis);
240
241 EXPECT_ANY_THROW(softmax_backward(pd, {}));
242 // default primitive ctor
243 auto softmax = softmax_backward();
244 // regular primitive ctor
245 softmax = softmax_backward(pd);
246
247 // check primitive kind is softmax
248 ASSERT_TRUE(softmax.get_kind() == primitive::kind::softmax);
249
250 // query for descs from pd
251 const auto diff_src_desc = pd.diff_src_desc();
252 const auto diff_dst_desc = pd.diff_dst_desc();
253 const auto dst_desc = pd.dst_desc();
254 // query for diff_src_desc via exec arg
255 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
256 == diff_src_desc);
257 if (p.src_tag != tag::any) {
258 ASSERT_TRUE(diff_src_md == diff_src_desc);
259 }
260 // query for diff_dst_desc via exec arg
261 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
262 == diff_dst_desc);
263 if (p.diff_dst_tag != tag::any) {
264 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
265 }
266 // query for dst_desc via exec arg
267 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
268 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
269
270 // query primitive parameters
271 ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data);
272 ASSERT_EQ(pd.get_axis(), p.axis);
273 ASSERT_EQ(pd.get_algorithm(), p.aalgorithm);
274
275 // check primitive returns zero_md for all rest md
276 ASSERT_TRUE(pd.src_desc().is_zero());
277 ASSERT_TRUE(pd.weights_desc().is_zero());
278 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
279
280 auto diff_src = test::make_memory(diff_src_desc, eng);
281 auto diff_dst = test::make_memory(diff_dst_desc, eng);
282
283 fill_data(p.diff_dst_dt, diff_dst, 2, 2);
284
285 // test out-place mode
286 softmax.execute(strm,
287 {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst},
288 {DNNL_ARG_DIFF_SRC, diff_src},
289 {DNNL_ARG_WORKSPACE, workspace}});
290 strm.wait();
291
292 // test in-place mode
293 if (p.src_tag == p.diff_dst_tag && p.src_dt == p.diff_dst_dt) {
294 softmax.execute(strm,
295 {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst},
296 {DNNL_ARG_DIFF_SRC, diff_dst},
297 {DNNL_ARG_WORKSPACE, workspace}});
298 strm.wait();
299 }
300 }
301
302 void Test() {
303 Forward();
304 if (!is_fwd(p.aprop_kind)) Backward();
305 }
306
307 bool is_fwd(prop_kind pk) const {
308 return pk == prop_kind::forward_training
309 || pk == prop_kind::forward_inference;
310 }
311};
312
313using tp = softmax_test_params_t;
314
315static const auto training = prop_kind::forward_training;
316static const auto inference = prop_kind::forward_inference;
317static const auto backward = prop_kind::backward_data;
318static const auto alg_softmax = algorithm::softmax_accurate;
319static const auto alg_logsoftmax = algorithm::softmax_log;
320
321TEST_P(softmax_test_t, TestsSoftmax) {}
322
323INSTANTIATE_TEST_SUITE_P(Test_Softmax_EF, softmax_test_t,
324 ::testing::Values(
325 // Negative dims
326 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
327 tag::nchw, tag::nchw, tag::undef, {2, -2, 128, 256}, 0,
328 true, dnnl_invalid_arguments},
329 // Axis exceeds ndims
330 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
331 tag::nchw, tag::nchw, tag::undef, {2, 2, 128, 256}, 10,
332 true, dnnl_invalid_arguments},
333 // Not supported algorithm
334 tp {training, algorithm::eltwise_relu, dt::f32, dt::f32,
335 dt::undef, tag::nchw, tag::nchw, tag::undef,
336 {2, 2, 128, 256}, 3, true, dnnl_invalid_arguments},
337 // Tag for src on forward is not specified
338 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
339 tag::any, tag::nchw, tag::undef, {2, 2, 128, 256}, 3,
340 true, dnnl_invalid_arguments},
341 // Tag for dst on backward is not specified
342 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nchw,
343 tag::any, tag::nchw, {2, 2, 128, 256}, 3, true,
344 dnnl_invalid_arguments},
345 // Data type for src is not specified
346 tp {training, alg_softmax, dt::undef, dt::f32, dt::undef,
347 tag::nchw, tag::nchw, tag::undef, {2, 2, 128, 256}, 3,
348 true, dnnl_invalid_arguments}));
349
350INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Float, softmax_test_t,
351 ::testing::Values(
352 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
353 tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
354 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
355 tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
356 tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
357 tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
358 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
359 tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
360 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
361 tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
362 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
363 tag::nc, tag::any, tag::undef, {1, 13}, 1},
364 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
365 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
366 tp {inference, alg_logsoftmax, dt::f32, dt::f32, dt::undef,
367 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
368 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
369 tag::nChw16c, tag::nChw16c, tag::undef,
370 {64, 1011, 1, 1}, 1},
371 tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
372 tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
373 1},
374 tp {inference, alg_logsoftmax, dt::f32, dt::f32, dt::undef,
375 tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
376 2}));
377
378INSTANTIATE_TEST_SUITE_P(Test_Softmax_Backward_Float, softmax_test_t,
379 ::testing::Values(
380 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nchw,
381 tag::nchw, tag::nchw, {2, 0, 5, 5}, 1},
382 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nhwc,
383 tag::nhwc, tag::nhwc, {2, 19, 16, 64}, 1},
384 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::any,
385 tag::nchw, tag::any, {1, 8, 128, 1024}, 3},
386 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nc,
387 tag::nc, tag::nc, {2, 1000}, 0},
388 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nc,
389 tag::cn, tag::cn, {2, 1000}, 1},
390 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::any,
391 tag::nc, tag::nc, {1, 13}, 1},
392 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::ncw,
393 tag::ncw, tag::ncw, {16, 257, 32}, 1},
394 tp {backward, alg_logsoftmax, dt::f32, dt::f32, dt::f32,
395 tag::ncw, tag::ncw, tag::nwc, {16, 257, 32}, 2},
396 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32,
397 tag::nChw16c, tag::nChw16c, tag::nChw16c,
398 {64, 1011, 1, 1}, 1},
399 tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32,
400 tag::nChw8c, tag::nhwc, tag::nchw, {3, 1011, 1, 1}, 1},
401 tp {backward, alg_logsoftmax, dt::f32, dt::f32, dt::f32,
402 tag::nChw8c, tag::nChw8c, tag::nChw8c, {2, 1011, 32, 1},
403 2}));
404
405INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Bfloat16, softmax_test_t,
406 ::testing::Values(
407 tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
408 tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
409 tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
410 tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
411 tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
412 tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
413 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
414 tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
415 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
416 tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
417 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
418 tag::nc, tag::any, tag::undef, {1, 13}, 1},
419 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
420 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
421 tp {inference, alg_logsoftmax, dt::bf16, dt::bf16, dt::undef,
422 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
423 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
424 tag::nChw16c, tag::nChw16c, tag::undef,
425 {64, 1011, 1, 1}, 1},
426 tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
427 tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
428 1},
429 tp {inference, alg_logsoftmax, dt::bf16, dt::bf16, dt::undef,
430 tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
431 2}));
432
433INSTANTIATE_TEST_SUITE_P(Test_Softmax_Backward_Bfloat16, softmax_test_t,
434 ::testing::Values(
435 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
436 tag::nchw, tag::nchw, tag::nchw, {2, 0, 5, 5}, 1},
437 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
438 tag::nhwc, tag::nhwc, tag::nhwc, {2, 19, 16, 64}, 1},
439 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
440 tag::any, tag::nchw, tag::any, {1, 8, 128, 1024}, 3},
441 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
442 tag::nc, tag::nc, tag::nc, {2, 1000}, 0},
443 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
444 tag::nc, tag::cn, tag::cn, {2, 1000}, 1},
445 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
446 tag::any, tag::nc, tag::nc, {1, 13}, 1},
447 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
448 tag::ncw, tag::ncw, tag::ncw, {16, 257, 32}, 1},
449 tp {backward, alg_logsoftmax, dt::bf16, dt::bf16, dt::bf16,
450 tag::ncw, tag::ncw, tag::nwc, {16, 257, 32}, 2},
451 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
452 tag::nChw16c, tag::nChw16c, tag::nChw16c,
453 {64, 1011, 1, 1}, 1},
454 tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
455 tag::nChw8c, tag::nhwc, tag::nchw, {3, 1011, 1, 1}, 1},
456 tp {backward, alg_logsoftmax, dt::bf16, dt::bf16, dt::bf16,
457 tag::nChw8c, tag::nChw8c, tag::nChw8c, {2, 1011, 32, 1},
458 2}));
459
460GPU_INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Half, softmax_test_t,
461 ::testing::Values(
462 tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
463 tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
464 tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
465 tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
466 tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
467 tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
468 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
469 tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
470 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
471 tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
472 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
473 tag::nc, tag::any, tag::undef, {1, 13}, 1},
474 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
475 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
476 tp {inference, alg_logsoftmax, dt::f16, dt::f16, dt::undef,
477 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
478 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
479 tag::nChw16c, tag::nChw16c, tag::undef,
480 {64, 1011, 1, 1}, 1},
481 tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
482 tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
483 1},
484 tp {inference, alg_logsoftmax, dt::f16, dt::f16, dt::undef,
485 tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
486 2}));
487
488INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_U8, softmax_test_t,
489 ::testing::Values(
490 tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
491 tag::nhwc, tag::nhwc, tag::undef, {2, 0, 5, 5}, 1},
492 tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
493 tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
494 tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
495 tag::nhwc, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
496 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
497 tag::nc, tag::undef, {2, 1000}, 0},
498 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
499 tag::cn, tag::undef, {2, 1000}, 1},
500 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
501 tag::any, tag::undef, {1, 13}, 1},
502 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
503 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
504 tp {inference, alg_logsoftmax, dt::f32, dt::u8, dt::undef,
505 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
506 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
507 tag::nhwc, tag::nhwc, tag::undef, {64, 1011, 1, 1}, 1},
508 tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
509 tag::nhwc, tag::nhwc, tag::undef, {3, 1011, 1, 1}, 1},
510 tp {inference, alg_logsoftmax, dt::f32, dt::u8, dt::undef,
511 tag::nhwc, tag::nhwc, tag::undef, {2, 1011, 32, 1},
512 2}));
513
514INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_S8, softmax_test_t,
515 ::testing::Values(
516 tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
517 tag::nhwc, tag::nhwc, tag::undef, {2, 0, 5, 5}, 1},
518 tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
519 tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
520 tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
521 tag::nhwc, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
522 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
523 tag::nc, tag::undef, {2, 1000}, 0},
524 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
525 tag::cn, tag::undef, {2, 1000}, 1},
526 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
527 tag::any, tag::undef, {1, 13}, 1},
528 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
529 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
530 tp {inference, alg_logsoftmax, dt::f32, dt::s8, dt::undef,
531 tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
532 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
533 tag::nhwc, tag::nhwc, tag::undef, {64, 1011, 1, 1}, 1},
534 tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
535 tag::nhwc, tag::nhwc, tag::undef, {3, 1011, 1, 1}, 1},
536 tp {inference, alg_logsoftmax, dt::f32, dt::s8, dt::undef,
537 tag::nhwc, tag::nhwc, tag::undef, {2, 1011, 32, 1},
538 2}));
539
540} // namespace dnnl
541