1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "math_utils.hpp"
19#include "oneapi/dnnl/dnnl.hpp"
20#include "test_convolution_eltwise_forward_common.hpp"
21#include "gtest/gtest.h"
22
23namespace dnnl {
24
25using convolution_test = convolution_eltwise_test<float, float, float, float>;
26
27TEST_P(convolution_test, TestConvolutionEltwise) {}
28
29#define EXPAND_FORMATS(src, weights, bias, dst) \
30 { \
31 dnnl::memory::format_tag::src, dnnl::memory::format_tag::weights, \
32 dnnl::memory::format_tag::bias, dnnl::memory::format_tag::dst \
33 }
34
35#define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b
36#define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b)
37
38#define CPU_INST_TEST_CASE_(str, ...) \
39 CPU_INSTANTIATE_TEST_SUITE_P( \
40 str, convolution_test, ::testing::Values(__VA_ARGS__))
41
42#define CPU_INST_TEST_CASE(str, ...) \
43 CPU_INST_TEST_CASE_( \
44 CONCAT_WITH_UNDERSCORE( \
45 CONCAT_WITH_UNDERSCORE(Convolution, str), eltwise), \
46 __VA_ARGS__)
47
48#define INST_TEST_CASE_(str, ...) \
49 INSTANTIATE_TEST_SUITE_P( \
50 str, convolution_test, ::testing::Values(__VA_ARGS__))
51
52#define INST_TEST_CASE(str, ...) \
53 INST_TEST_CASE_( \
54 CONCAT_WITH_UNDERSCORE( \
55 CONCAT_WITH_UNDERSCORE(Convolution, str), eltwise), \
56 __VA_ARGS__)
57
58#define EXPAND_ARGS(args) args
59
60#define PARAMS(...) \
61 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_relu, __VA_ARGS__)), \
62 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_elu, __VA_ARGS__)), \
63 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_tanh, __VA_ARGS__)), \
64 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_square, __VA_ARGS__)), \
65 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_abs, __VA_ARGS__)), \
66 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_linear, __VA_ARGS__)), \
67 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_clip, __VA_ARGS__)), \
68 EXPAND_ARGS( \
69 PARAMS_CONV(algorithm::eltwise_soft_relu, __VA_ARGS__)), \
70 EXPAND_ARGS( \
71 PARAMS_CONV(algorithm::eltwise_logistic, __VA_ARGS__)), \
72 EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_swish, __VA_ARGS__))
73// Not testing due to not scaled output
74// EXPAND_ARGS(PARAMS_CONV(algorithm::eltwise_exp, __VA_ARGS__))
75#define ELTWISE_ALPHA 0.5f
76#define ELTWISE_BETA 1.5f
77
78#define PARAMS_CONV(alg, src, weights, bias, dst, ...) \
79 test_convolution_eltwise_params_t { \
80 alg, dnnl::algorithm::convolution_direct, ELTWISE_ALPHA, ELTWISE_BETA, \
81 EXPAND_FORMATS(src, weights, bias, dst), \
82 /* empty attributes */ {}, { \
83 __VA_ARGS__ \
84 } \
85 }
86
87INST_TEST_CASE(SimpleSmall,
88 PARAMS(nchw, oihw, x, nchw, 2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1,
89 1),
90 PARAMS(nchw, oihw, x, nchw, 2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1,
91 1),
92 PARAMS(nchw, goihw, x, nchw, 2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0,
93 1, 1),
94 PARAMS(nchw, goihw, x, nchw, 2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1,
95 1));
96
97CPU_INST_TEST_CASE(SimpleSmall_Blocked,
98 PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 48, 48, 20, 20, 48, 20, 20, 3, 3,
99 1, 1, 1, 1),
100 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 1, 1,
101 0, 0, 1, 1),
102 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 3, 3,
103 0, 0, 1, 1));
104
105CPU_INST_TEST_CASE(SimpleSmall_Blocked_Tail,
106 PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 47, 47, 20, 20, 47, 20, 20, 3, 3,
107 1, 1, 1, 1),
108 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 1, 1,
109 0, 0, 1, 1),
110 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 3, 3,
111 0, 0, 1, 1));
112
113INST_TEST_CASE(SimpleSmall_Blocked16,
114 PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 48, 48, 20, 20, 48, 20, 20, 3,
115 3, 1, 1, 1, 1),
116 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 1,
117 1, 0, 0, 1, 1),
118 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 3,
119 3, 0, 0, 1, 1),
120 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3,
121 3, 0, 0, 1, 1));
122
123CPU_INST_TEST_CASE(SimpleSmall_Blocked16_Tail,
124 PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 47, 47, 20, 20, 47, 20, 20, 3,
125 3, 1, 1, 1, 1),
126 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 1,
127 1, 0, 0, 1, 1),
128 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 3,
129 3, 0, 0, 1, 1),
130 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3,
131 3, 0, 0, 1, 1));
132
133} // namespace dnnl
134