1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "math_utils.hpp" |
19 | #include "oneapi/dnnl/dnnl.hpp" |
20 | #include "test_convolution_eltwise_forward_common.hpp" |
21 | #include "gtest/gtest.h" |
22 | |
23 | namespace dnnl { |
24 | |
25 | using convolution_test = convolution_eltwise_test<float, float, float, float>; |
26 | |
27 | TEST_P(convolution_test, TestConvolutionEltwise) {} |
28 | |
29 | #define EXPAND_FORMATS(src, weights, bias, dst) \ |
30 | { \ |
31 | dnnl::memory::format_tag::src, dnnl::memory::format_tag::weights, \ |
32 | dnnl::memory::format_tag::bias, dnnl::memory::format_tag::dst \ |
33 | } |
34 | |
35 | #define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b |
36 | #define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b) |
37 | |
38 | #define CPU_INST_TEST_CASE_(str, ...) \ |
39 | CPU_INSTANTIATE_TEST_SUITE_P( \ |
40 | str, convolution_test, ::testing::Values(__VA_ARGS__)) |
41 | |
42 | #define CPU_INST_TEST_CASE(str, ...) \ |
43 | CPU_INST_TEST_CASE_( \ |
44 | CONCAT_WITH_UNDERSCORE( \ |
45 | CONCAT_WITH_UNDERSCORE(Convolution, str), eltwise), \ |
46 | __VA_ARGS__) |
47 | |
48 | #define INST_TEST_CASE_(str, ...) \ |
49 | INSTANTIATE_TEST_SUITE_P( \ |
50 | str, convolution_test, ::testing::Values(__VA_ARGS__)) |
51 | |
52 | #define INST_TEST_CASE(str, ...) \ |
53 | INST_TEST_CASE_( \ |
54 | CONCAT_WITH_UNDERSCORE( \ |
55 | CONCAT_WITH_UNDERSCORE(Convolution, str), eltwise), \ |
56 | __VA_ARGS__) |
57 | |
58 | #define EXPAND_ARGS(args) args |
59 | |
60 | #define PARAMS(...) \ |
61 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_relu, __VA_ARGS__)), \ |
62 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_elu, __VA_ARGS__)), \ |
63 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_tanh, __VA_ARGS__)), \ |
64 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_square, __VA_ARGS__)), \ |
65 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_abs, __VA_ARGS__)), \ |
66 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_linear, __VA_ARGS__)), \ |
67 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_clip, __VA_ARGS__)), \ |
68 | EXPAND_ARGS( \ |
69 | PARAMS_CONV(algorithm::eltwise_soft_relu, __VA_ARGS__)), \ |
70 | EXPAND_ARGS( \ |
71 | PARAMS_CONV(algorithm::eltwise_logistic, __VA_ARGS__)), \ |
72 | EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_swish, __VA_ARGS__)) |
73 | // Not testing due to not scaled output |
74 | // EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_exp, __VA_ARGS__)) |
75 | #define ELTWISE_ALPHA 0.5f |
76 | #define ELTWISE_BETA 1.5f |
77 | |
78 | #define PARAMS_CONV(alg, src, weights, bias, dst, ...) \ |
79 | test_convolution_eltwise_params_t { \ |
80 | alg, dnnl::algorithm::convolution_direct, ELTWISE_ALPHA, ELTWISE_BETA, \ |
81 | EXPAND_FORMATS(src, weights, bias, dst), \ |
82 | /* empty attributes */ {}, { \ |
83 | __VA_ARGS__ \ |
84 | } \ |
85 | } |
86 | |
87 | INST_TEST_CASE(SimpleSmall, |
88 | PARAMS(nchw, oihw, x, nchw, 2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, |
89 | 1), |
90 | PARAMS(nchw, oihw, x, nchw, 2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1, |
91 | 1), |
92 | PARAMS(nchw, goihw, x, nchw, 2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0, |
93 | 1, 1), |
94 | PARAMS(nchw, goihw, x, nchw, 2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1, |
95 | 1)); |
96 | |
97 | CPU_INST_TEST_CASE(SimpleSmall_Blocked, |
98 | PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 48, 48, 20, 20, 48, 20, 20, 3, 3, |
99 | 1, 1, 1, 1), |
100 | PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 1, 1, |
101 | 0, 0, 1, 1), |
102 | PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 3, 3, |
103 | 0, 0, 1, 1)); |
104 | |
105 | CPU_INST_TEST_CASE(SimpleSmall_Blocked_Tail, |
106 | PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 47, 47, 20, 20, 47, 20, 20, 3, 3, |
107 | 1, 1, 1, 1), |
108 | PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 1, 1, |
109 | 0, 0, 1, 1), |
110 | PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 3, 3, |
111 | 0, 0, 1, 1)); |
112 | |
113 | INST_TEST_CASE(SimpleSmall_Blocked16, |
114 | PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 48, 48, 20, 20, 48, 20, 20, 3, |
115 | 3, 1, 1, 1, 1), |
116 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 1, |
117 | 1, 0, 0, 1, 1), |
118 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 3, |
119 | 3, 0, 0, 1, 1), |
120 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3, |
121 | 3, 0, 0, 1, 1)); |
122 | |
123 | CPU_INST_TEST_CASE(SimpleSmall_Blocked16_Tail, |
124 | PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 47, 47, 20, 20, 47, 20, 20, 3, |
125 | 3, 1, 1, 1, 1), |
126 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 1, |
127 | 1, 0, 0, 1, 1), |
128 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 3, |
129 | 3, 0, 0, 1, 1), |
130 | PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3, |
131 | 3, 0, 0, 1, 1)); |
132 | |
133 | } // namespace dnnl |
134 | |