1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Short aliases for the memory format-tag and data-type types; they are used
// throughout the test-case tables in this file.
using tag = memory::format_tag;
using dt = memory::data_type;
26 | |
// Parameters describing a single softmax test case.
//
// Instances are created with positional aggregate initialization (see the
// `tp {...}` lists in the test instantiations), so the order of the fields
// is part of the interface and must not be changed.
struct softmax_test_params_t {
    // Propagation kind under test. Non-forward kinds run a forward pass first
    // and then the backward pass.
    prop_kind aprop_kind;
    // Algorithm passed to the forward/backward primitive descriptors.
    algorithm aalgorithm;
    dt src_dt; // src data type; also used for diff_src on backward
    dt dst_dt; // dst data type
    dt diff_dst_dt; // diff_dst data type (only read for non-forward cases)
    tag src_tag; // src format tag; also used for diff_src on backward
    tag dst_tag; // dst format tag
    tag diff_dst_tag; // diff_dst format tag (only read for non-forward cases)
    memory::dims dims; // dimensions shared by all memory descriptors
    int axis; // axis argument passed to the primitive descriptors
    // When true, the test is expected to fail with `expected_status`; both
    // values are forwarded to catch_expected_failures().
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
41 | |
42 | class softmax_test_t : public ::testing::TestWithParam<softmax_test_params_t> { |
43 | private: |
44 | softmax_test_params_t p; |
45 | memory dst, workspace; |
46 | std::shared_ptr<softmax_forward::primitive_desc> pd_fwd_hint; |
47 | |
48 | protected: |
49 | void SetUp() override { |
50 | p = ::testing::TestWithParam<softmax_test_params_t>::GetParam(); |
51 | |
52 | SKIP_IF_CUDA( |
53 | !cuda_check_format_tag(p.src_tag), "Unsupported format tag" ); |
54 | SKIP_IF_CUDA( |
55 | !cuda_check_format_tag(p.dst_tag), "Unsupported format tag" ); |
56 | |
57 | SKIP_IF_HIP(!hip_check_format_tag(p.src_tag), "Unsupported format tag" ); |
58 | SKIP_IF_HIP(!hip_check_format_tag(p.dst_tag), "Unsupported format tag" ); |
59 | |
60 | if (!is_fwd(p.aprop_kind)) { |
61 | SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_dst_tag), |
62 | "Unsupported format tag" ); |
63 | |
64 | SKIP_IF_HIP(!hip_check_format_tag(p.diff_dst_tag), |
65 | "Unsupported format tag" ); |
66 | } |
67 | SKIP_IF_CUDA((p.src_dt == dt::bf16 || p.dst_dt == dt::bf16), |
68 | "Unsupported datatype for CUDA" ); |
69 | SKIP_IF_HIP((p.src_dt == dt::bf16 || p.dst_dt == dt::bf16), |
70 | "Unsupported datatype for HIP" ); |
71 | if (!is_fwd(p.aprop_kind)) { |
72 | SKIP_IF_CUDA((p.diff_dst_dt == dt::bf16), |
73 | "Unsupported datatype for CUDA" ); |
74 | SKIP_IF_HIP((p.diff_dst_dt == dt::bf16), |
75 | "Unsupported datatype for HIP" ); |
76 | } |
77 | |
78 | SKIP_IF(unsupported_data_type(p.src_dt) |
79 | || unsupported_data_type(p.dst_dt), |
80 | "Engine does not support this data type." ); |
81 | if (!is_fwd(p.aprop_kind)) { |
82 | SKIP_IF(unsupported_data_type(p.diff_dst_dt), |
83 | "Engine does not support this data type." ); |
84 | } |
85 | |
86 | SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
87 | && p.dst_dt != dt::undef, |
88 | "Unsupported different data types for source and " |
89 | "destination" ); |
90 | SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
91 | && p.dst_dt != dt::undef, |
92 | "Unsupported different data types for source and " |
93 | "destination" ); |
94 | SKIP_IF_CUDA(!is_fwd(p.aprop_kind) && p.src_dt != p.diff_dst_dt |
95 | && p.src_dt != dt::undef && p.diff_dst_dt != dt::undef, |
96 | "Unsupported different data types for diff_source and " |
97 | "diff_destination" ); |
98 | SKIP_IF_HIP(!is_fwd(p.aprop_kind) && p.src_dt != p.diff_dst_dt |
99 | && p.src_dt != dt::undef && p.diff_dst_dt != dt::undef, |
100 | "Unsupported different data types for diff_source and " |
101 | "diff_destination" ); |
102 | |
103 | SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any |
104 | && p.dst_tag != tag::any, |
105 | "Unsupported different memory formats for source and " |
106 | "destination" ); |
107 | SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any |
108 | && p.dst_tag != tag::any, |
109 | "Unsupported different memory formats for source and " |
110 | "destination" ); |
111 | SKIP_IF_CUDA(!is_fwd(p.aprop_kind) && p.src_tag != p.diff_dst_tag |
112 | && p.src_tag != tag::any && p.diff_dst_tag != tag::any, |
113 | "Unsupported different memory formats for diff_source and " |
114 | "diff_destination" ); |
115 | SKIP_IF_HIP(!is_fwd(p.aprop_kind) && p.src_tag != p.diff_dst_tag |
116 | && p.src_tag != tag::any && p.diff_dst_tag != tag::any, |
117 | "Unsupported different memory formats for diff_source and " |
118 | "diff_destination" ); |
119 | SKIP_IF_CUDA(p.dst_dt == dt::u8 || p.dst_dt == dt::s8, |
120 | "Unsupported int8 destination data type" ); |
121 | SKIP_IF_HIP(p.dst_dt == dt::u8 || p.dst_dt == dt::s8, |
122 | "Unsupported int8 destination data type" ); |
123 | |
124 | SKIP_IF_HIP(p.axis != 1, "Unsupported axis for HIP" ); |
125 | |
126 | catch_expected_failures( |
127 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
128 | } |
129 | bool cuda_check_format_tag(memory::format_tag tag) { |
130 | return (tag != memory::format_tag::aBcd8b |
131 | && tag != memory::format_tag::aBcd16b); |
132 | } |
133 | bool hip_check_format_tag(memory::format_tag tag) { |
134 | return (tag == memory::format_tag::a || tag == memory::format_tag::ab |
135 | || tag == memory::format_tag::abc |
136 | || tag == memory::format_tag::abcd |
137 | || tag == memory::format_tag::abcde); |
138 | } |
139 | void Forward() { |
140 | // softmax specific types and values |
141 | using pd_t = softmax_forward::primitive_desc; |
142 | |
143 | auto eng = get_test_engine(); |
144 | auto strm = make_stream(eng); |
145 | prop_kind pk = !is_fwd(p.aprop_kind) ? prop_kind::forward_training |
146 | : p.aprop_kind; |
147 | |
148 | allows_attr_t aa {false}; |
149 | if (!(is_nvidia_gpu(eng) || is_amd_gpu(eng))) { aa.scales = true; } |
150 | |
151 | // To validate backward on valid tag::any settings reuse dst tag. |
152 | const bool src_bwd_any = !is_fwd(p.aprop_kind) && p.src_tag == tag::any; |
153 | auto src_tag = src_bwd_any ? p.dst_tag : p.src_tag; |
154 | |
155 | auto src_md = memory::desc(p.dims, p.src_dt, src_tag); |
156 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
157 | |
158 | // default pd ctor |
159 | auto pd = pd_t(); |
160 | // regular pd ctor |
161 | pd = pd_t(eng, pk, p.aalgorithm, src_md, dst_md, p.axis); |
162 | // test all pd ctors |
163 | test_fwd_pd_constructors<pd_t>( |
164 | pd, aa, pk, p.aalgorithm, src_md, dst_md, p.axis); |
165 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
166 | |
167 | EXPECT_ANY_THROW(softmax_forward(pd, {})); |
168 | // default primitive ctor |
169 | auto softmax = softmax_forward(); |
170 | // regular primitive ctor |
171 | softmax = softmax_forward(pd); |
172 | |
173 | // check primitive kind is softmax |
174 | ASSERT_TRUE(softmax.get_kind() == primitive::kind::softmax); |
175 | // query for descs from pd |
176 | const auto src_desc = pd.src_desc(); |
177 | const auto dst_desc = pd.dst_desc(); |
178 | // query for src_desc via exec arg |
179 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
180 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
181 | // query for dst_desc via exec arg |
182 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
183 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
184 | |
185 | // query primitive parameters |
186 | ASSERT_EQ(pd.get_prop_kind(), pk); |
187 | ASSERT_EQ(pd.get_axis(), p.axis); |
188 | ASSERT_EQ(pd.get_algorithm(), p.aalgorithm); |
189 | |
190 | // query for workspace |
191 | const auto workspace_desc = pd.workspace_desc(); |
192 | |
193 | // check primitive returns zero_md for all rest md |
194 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
195 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
196 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
197 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
198 | |
199 | auto src = test::make_memory(src_desc, eng); |
200 | dst = test::make_memory(dst_desc, eng); |
201 | workspace = test::make_memory(workspace_desc, eng); |
202 | |
203 | fill_data(p.src_dt, src, 1, 1); |
204 | // test out-place mode |
205 | softmax.execute(strm, |
206 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, |
207 | {DNNL_ARG_WORKSPACE, workspace}}); |
208 | strm.wait(); |
209 | |
210 | // test in-place mode on forward |
211 | if (p.aprop_kind != prop_kind::backward_data && p.src_tag == p.dst_tag |
212 | && p.src_dt == p.dst_dt) { |
213 | softmax.execute(strm, |
214 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src}, |
215 | {DNNL_ARG_WORKSPACE, workspace}}); |
216 | strm.wait(); |
217 | } |
218 | } |
219 | |
220 | void Backward() { |
221 | // softmax specific types and values |
222 | using pd_t = softmax_backward::primitive_desc; |
223 | using hint_pd_t = softmax_forward::primitive_desc; |
224 | allows_attr_t aa {false}; // doesn't support anything |
225 | |
226 | auto eng = get_test_engine(); |
227 | auto strm = make_stream(eng); |
228 | auto diff_src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
229 | auto diff_dst_md = memory::desc(p.dims, p.diff_dst_dt, p.diff_dst_tag); |
230 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
231 | |
232 | // default pd ctor |
233 | auto pd = pd_t(); |
234 | // regular pd ctor |
235 | pd = pd_t(eng, p.aalgorithm, diff_src_md, diff_dst_md, dst_md, p.axis, |
236 | *pd_fwd_hint); |
237 | // test all pd ctors |
238 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, |
239 | p.aalgorithm, diff_src_md, diff_dst_md, dst_md, p.axis); |
240 | |
241 | EXPECT_ANY_THROW(softmax_backward(pd, {})); |
242 | // default primitive ctor |
243 | auto softmax = softmax_backward(); |
244 | // regular primitive ctor |
245 | softmax = softmax_backward(pd); |
246 | |
247 | // check primitive kind is softmax |
248 | ASSERT_TRUE(softmax.get_kind() == primitive::kind::softmax); |
249 | |
250 | // query for descs from pd |
251 | const auto diff_src_desc = pd.diff_src_desc(); |
252 | const auto diff_dst_desc = pd.diff_dst_desc(); |
253 | const auto dst_desc = pd.dst_desc(); |
254 | // query for diff_src_desc via exec arg |
255 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
256 | == diff_src_desc); |
257 | if (p.src_tag != tag::any) { |
258 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
259 | } |
260 | // query for diff_dst_desc via exec arg |
261 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
262 | == diff_dst_desc); |
263 | if (p.diff_dst_tag != tag::any) { |
264 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
265 | } |
266 | // query for dst_desc via exec arg |
267 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
268 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
269 | |
270 | // query primitive parameters |
271 | ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data); |
272 | ASSERT_EQ(pd.get_axis(), p.axis); |
273 | ASSERT_EQ(pd.get_algorithm(), p.aalgorithm); |
274 | |
275 | // check primitive returns zero_md for all rest md |
276 | ASSERT_TRUE(pd.src_desc().is_zero()); |
277 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
278 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
279 | |
280 | auto diff_src = test::make_memory(diff_src_desc, eng); |
281 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
282 | |
283 | fill_data(p.diff_dst_dt, diff_dst, 2, 2); |
284 | |
285 | // test out-place mode |
286 | softmax.execute(strm, |
287 | {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst}, |
288 | {DNNL_ARG_DIFF_SRC, diff_src}, |
289 | {DNNL_ARG_WORKSPACE, workspace}}); |
290 | strm.wait(); |
291 | |
292 | // test in-place mode |
293 | if (p.src_tag == p.diff_dst_tag && p.src_dt == p.diff_dst_dt) { |
294 | softmax.execute(strm, |
295 | {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst}, |
296 | {DNNL_ARG_DIFF_SRC, diff_dst}, |
297 | {DNNL_ARG_WORKSPACE, workspace}}); |
298 | strm.wait(); |
299 | } |
300 | } |
301 | |
302 | void Test() { |
303 | Forward(); |
304 | if (!is_fwd(p.aprop_kind)) Backward(); |
305 | } |
306 | |
307 | bool is_fwd(prop_kind pk) const { |
308 | return pk == prop_kind::forward_training |
309 | || pk == prop_kind::forward_inference; |
310 | } |
311 | }; |
312 | |
313 | using tp = softmax_test_params_t; |
314 | |
315 | static const auto training = prop_kind::forward_training; |
316 | static const auto inference = prop_kind::forward_inference; |
317 | static const auto backward = prop_kind::backward_data; |
318 | static const auto alg_softmax = algorithm::softmax_accurate; |
319 | static const auto alg_logsoftmax = algorithm::softmax_log; |
320 | |
321 | TEST_P(softmax_test_t, TestsSoftmax) {} |
322 | |
// Expected-failure cases: every entry sets expect_to_fail to true and
// expects dnnl_invalid_arguments. The comment above each entry names the
// invalid setting being exercised.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_EF, softmax_test_t,
        ::testing::Values(
                // Negative dims
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, -2, 128, 256}, 0,
                        true, dnnl_invalid_arguments},
                // Axis exceeds ndims
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, 2, 128, 256}, 10,
                        true, dnnl_invalid_arguments},
                // Not supported algorithm
                tp {training, algorithm::eltwise_relu, dt::f32, dt::f32,
                        dt::undef, tag::nchw, tag::nchw, tag::undef,
                        {2, 2, 128, 256}, 3, true, dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::any, tag::nchw, tag::undef, {2, 2, 128, 256}, 3,
                        true, dnnl_invalid_arguments},
                // Tag for dst on backward is not specified
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nchw,
                        tag::any, tag::nchw, {2, 2, 128, 256}, 3, true,
                        dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {training, alg_softmax, dt::undef, dt::f32, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, 2, 128, 256}, 3,
                        true, dnnl_invalid_arguments}));
349 | |
// Forward (training and inference) cases with f32 src and dst.
// The trailing expect_to_fail / expected_status fields are omitted, so they
// are value-initialized (expect_to_fail is false). diff_dst fields are
// undef because they are not read on forward.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Float, softmax_test_t,
        ::testing::Values(
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
                tp {training, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nc, tag::any, tag::undef, {1, 13}, 1},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
                tp {inference, alg_logsoftmax, dt::f32, dt::f32, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nChw16c, tag::nChw16c, tag::undef,
                        {64, 1011, 1, 1}, 1},
                tp {inference, alg_softmax, dt::f32, dt::f32, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
                        1},
                tp {inference, alg_logsoftmax, dt::f32, dt::f32, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
                        2}));
377 | |
// Backward cases with f32 for src (diff_src), dst and diff_dst.
// Includes entries where the src / diff_dst tags are tag::any and entries
// where the three tags differ from each other.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Backward_Float, softmax_test_t,
        ::testing::Values(
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nchw,
                        tag::nchw, tag::nchw, {2, 0, 5, 5}, 1},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nhwc,
                        tag::nhwc, tag::nhwc, {2, 19, 16, 64}, 1},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::any,
                        tag::nchw, tag::any, {1, 8, 128, 1024}, 3},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nc,
                        tag::nc, tag::nc, {2, 1000}, 0},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::nc,
                        tag::cn, tag::cn, {2, 1000}, 1},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::any,
                        tag::nc, tag::nc, {1, 13}, 1},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32, tag::ncw,
                        tag::ncw, tag::ncw, {16, 257, 32}, 1},
                tp {backward, alg_logsoftmax, dt::f32, dt::f32, dt::f32,
                        tag::ncw, tag::ncw, tag::nwc, {16, 257, 32}, 2},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32,
                        tag::nChw16c, tag::nChw16c, tag::nChw16c,
                        {64, 1011, 1, 1}, 1},
                tp {backward, alg_softmax, dt::f32, dt::f32, dt::f32,
                        tag::nChw8c, tag::nhwc, tag::nchw, {3, 1011, 1, 1}, 1},
                tp {backward, alg_logsoftmax, dt::f32, dt::f32, dt::f32,
                        tag::nChw8c, tag::nChw8c, tag::nChw8c, {2, 1011, 32, 1},
                        2}));
404 | |
// Forward (training and inference) cases with bf16 src and dst; the shapes,
// tags and axes mirror the f32 forward suite. The fixture's SetUp() skips
// bf16 cases through its CUDA/HIP and unsupported-data-type checks.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Bfloat16, softmax_test_t,
        ::testing::Values(
                tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
                tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
                tp {training, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nc, tag::any, tag::undef, {1, 13}, 1},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
                tp {inference, alg_logsoftmax, dt::bf16, dt::bf16, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nChw16c, tag::nChw16c, tag::undef,
                        {64, 1011, 1, 1}, 1},
                tp {inference, alg_softmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
                        1},
                tp {inference, alg_logsoftmax, dt::bf16, dt::bf16, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
                        2}));
432 | |
// Backward cases with bf16 for src (diff_src), dst and diff_dst; the shapes,
// tags and axes mirror the f32 backward suite.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Backward_Bfloat16, softmax_test_t,
        ::testing::Values(
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nchw, tag::nchw, tag::nchw, {2, 0, 5, 5}, 1},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nhwc, tag::nhwc, tag::nhwc, {2, 19, 16, 64}, 1},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::any, tag::nchw, tag::any, {1, 8, 128, 1024}, 3},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nc, tag::nc, tag::nc, {2, 1000}, 0},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nc, tag::cn, tag::cn, {2, 1000}, 1},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::any, tag::nc, tag::nc, {1, 13}, 1},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::ncw, tag::ncw, tag::ncw, {16, 257, 32}, 1},
                tp {backward, alg_logsoftmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::ncw, tag::ncw, tag::nwc, {16, 257, 32}, 2},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nChw16c, tag::nChw16c, tag::nChw16c,
                        {64, 1011, 1, 1}, 1},
                tp {backward, alg_softmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nChw8c, tag::nhwc, tag::nchw, {3, 1011, 1, 1}, 1},
                tp {backward, alg_logsoftmax, dt::bf16, dt::bf16, dt::bf16,
                        tag::nChw8c, tag::nChw8c, tag::nChw8c, {2, 1011, 32, 1},
                        2}));
459 | |
// Forward (training and inference) cases with f16 src and dst; the shapes,
// tags and axes mirror the f32 forward suite.
// NOTE(review): instantiated through GPU_INSTANTIATE_TEST_SUITE_P, which is
// defined outside this file - presumably it limits the suite to GPU engines;
// confirm against the macro definition.
GPU_INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_Half, softmax_test_t,
        ::testing::Values(
                tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nchw, tag::nchw, tag::undef, {2, 0, 5, 5}, 1},
                tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
                tp {training, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nchw, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nc, tag::nc, tag::undef, {2, 1000}, 0},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nc, tag::cn, tag::undef, {2, 1000}, 1},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nc, tag::any, tag::undef, {1, 13}, 1},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
                tp {inference, alg_logsoftmax, dt::f16, dt::f16, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nChw16c, tag::nChw16c, tag::undef,
                        {64, 1011, 1, 1}, 1},
                tp {inference, alg_softmax, dt::f16, dt::f16, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {3, 1011, 1, 1},
                        1},
                tp {inference, alg_logsoftmax, dt::f16, dt::f16, dt::undef,
                        tag::nChw8c, tag::nChw8c, tag::undef, {2, 1011, 32, 1},
                        2}));
487 | |
// Forward (training and inference) cases with f32 src and u8 dst.
// The fixture's SetUp() skips u8/s8 destinations on CUDA and HIP.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_U8, softmax_test_t,
        ::testing::Values(
                tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 0, 5, 5}, 1},
                tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
                tp {training, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
                        tag::nc, tag::undef, {2, 1000}, 0},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
                        tag::cn, tag::undef, {2, 1000}, 1},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef, tag::nc,
                        tag::any, tag::undef, {1, 13}, 1},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
                tp {inference, alg_logsoftmax, dt::f32, dt::u8, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {64, 1011, 1, 1}, 1},
                tp {inference, alg_softmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {3, 1011, 1, 1}, 1},
                tp {inference, alg_logsoftmax, dt::f32, dt::u8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 1011, 32, 1},
                        2}));
513 | |
// Forward (training and inference) cases with f32 src and s8 dst; the
// shapes, tags and axes mirror the u8 suite.
// The fixture's SetUp() skips u8/s8 destinations on CUDA and HIP.
INSTANTIATE_TEST_SUITE_P(Test_Softmax_Forward_S8, softmax_test_t,
        ::testing::Values(
                tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 0, 5, 5}, 1},
                tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 19, 16, 64}, 1},
                tp {training, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::any, tag::undef, {1, 8, 128, 1024}, 3},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
                        tag::nc, tag::undef, {2, 1000}, 0},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
                        tag::cn, tag::undef, {2, 1000}, 1},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef, tag::nc,
                        tag::any, tag::undef, {1, 13}, 1},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 1},
                tp {inference, alg_logsoftmax, dt::f32, dt::s8, dt::undef,
                        tag::ncw, tag::ncw, tag::undef, {16, 257, 32}, 2},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {64, 1011, 1, 1}, 1},
                tp {inference, alg_softmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {3, 1011, 1, 1}, 1},
                tp {inference, alg_logsoftmax, dt::f32, dt::s8, dt::undef,
                        tag::nhwc, tag::nhwc, tag::undef, {2, 1011, 32, 1},
                        2}));
539 | |
540 | } // namespace dnnl |
541 | |