1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
* You may obtain a copy of the License at
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
// Extent used for a src1 dimension that is broadcast.
#define BCAST 1
// Extent used for a dimension that is not broadcast; the test below also
// uses it for every dimension of the full-size tensors.
#define NO_BCAST 8

// Expands to a `case ndims:` label that returns memory::format_tag::<tag>.
// Intended for use inside a switch statement.
#define CASE(ndims, tag) \
    case ndims: return memory::format_tag::tag;
27 | |
28 | namespace dnnl { |
29 | |
30 | memory::format_tag plain_format_tag(size_t ndims) { |
31 | assert(ndims <= 12); |
32 | switch (ndims) { |
33 | CASE(1, a) |
34 | CASE(2, ab) |
35 | CASE(3, abc) |
36 | CASE(4, abcd) |
37 | CASE(5, abcde) |
38 | CASE(6, abcdef) |
39 | CASE(7, abcdefg) |
40 | CASE(8, abcdefgh) |
41 | CASE(9, abcdefghi) |
42 | CASE(10, abcdefghij) |
43 | CASE(11, abcdefghijk) |
44 | CASE(12, abcdefghijkl) |
45 | default: return memory::format_tag::any; |
46 | } |
47 | } |
48 | |
49 | struct binary_bcast_test_t |
50 | : public ::testing::TestWithParam< |
51 | std::tuple<engine::kind, memory::dims, bool>> {}; |
52 | |
53 | HANDLE_EXCEPTIONS_FOR_TEST_P( |
54 | binary_bcast_test_t, TestBinaryOptimizedBroadcast) { |
55 | auto engine_kind = std::get<0>(GetParam()); |
56 | SKIP_IF(engine_kind != get_test_engine_kind(), |
57 | "Test prepared for a different engine kind" ); |
58 | SKIP_IF(!IMPLICATION(engine_kind == engine::kind::cpu, DNNL_X64), |
59 | "Binary impl_info_str should be same only on x64 CPU" ); |
60 | engine e {engine_kind, 0}; |
61 | |
62 | const auto &src1_bcast_dims = std::get<1>(GetParam()); |
63 | const size_t ndims = src1_bcast_dims.size(); |
64 | |
65 | constexpr auto defualt_dt = memory::data_type::f32; |
66 | const auto default_format = plain_format_tag(ndims); |
67 | memory::dims default_dims; |
68 | for (size_t d = 0; d < ndims; d++) |
69 | default_dims.push_back(NO_BCAST); |
70 | |
71 | std::string impl_info_no_bcast, impl_info_bcast; |
72 | |
73 | auto src0_md = memory::desc(default_dims, defualt_dt, default_format, true); |
74 | auto src1_md = memory::desc(default_dims, defualt_dt, default_format, true); |
75 | auto dst_md = memory::desc(default_dims, defualt_dt, default_format, true); |
76 | |
77 | auto binary_pd = binary::primitive_desc( |
78 | e, algorithm::binary_add, src0_md, src1_md, dst_md); |
79 | ASSERT_NO_THROW(impl_info_no_bcast = binary_pd.impl_info_str();); |
80 | |
81 | memory::desc src1_bcast_md( |
82 | src1_bcast_dims, defualt_dt, default_format, true); |
83 | |
84 | binary_pd = binary::primitive_desc( |
85 | e, algorithm::binary_add, src0_md, src1_bcast_md, dst_md); |
86 | |
87 | ASSERT_NO_THROW(impl_info_bcast = binary_pd.impl_info_str();); |
88 | |
89 | const auto expect_jit = std::get<2>(GetParam()); |
90 | if (expect_jit) |
91 | ASSERT_EQ(impl_info_no_bcast, impl_info_bcast); |
92 | else |
93 | ASSERT_NE(impl_info_no_bcast, impl_info_bcast); |
94 | } |
95 | |
// src1 shapes for which the test expects the broadcast and non-broadcast
// primitive descriptors to report the same impl_info_str (third tuple
// element is `true`). BCAST is a dimension of extent 1, NO_BCAST a dimension
// of extent 8. Every entry targets the CPU engine kind.
INSTANTIATE_TEST_SUITE_P(CPUOptimizedDims, binary_bcast_test_t,
        ::testing::Values(
                // 5d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST, BCAST, BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST, BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                NO_BCAST, BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                NO_BCAST, BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                BCAST, NO_BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                BCAST, BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST, BCAST}, true),
                // 4d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST}, true),
                // 3d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST}, true),
                // 2d cases
                std::make_tuple(
                        engine::kind::cpu, memory::dims {BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST}, true),
                // 1d case
                std::make_tuple(
                        engine::kind::cpu, memory::dims {BCAST}, true)));
174 | |
// src1 shapes for which the test expects the broadcast primitive descriptor
// to report a different impl_info_str than the non-broadcast one (third
// tuple element is `false`). Every entry targets the CPU engine kind.
INSTANTIATE_TEST_SUITE_P(CPUNotOptimizedDims, binary_bcast_test_t,
        ::testing::Values(
                // selected unoptimized cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, BCAST, NO_BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, NO_BCAST}, false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST, NO_BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, BCAST, BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, BCAST}, false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST}, false)));
193 | |
194 | } // namespace dnnl |
195 | |