1 | /******************************************************************************* |
2 | * Copyright 2018-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <initializer_list>
#include <vector>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"
21 | |
22 | namespace dnnl { |
23 | |
24 | TEST(test_parallel, Test) { |
25 | impl::parallel(0, [&](int ithr, int nthr) { |
26 | ASSERT_LE(0, ithr); |
27 | ASSERT_LT(ithr, nthr); |
28 | ASSERT_LE(nthr, dnnl_get_max_threads()); |
29 | }); |
30 | } |
31 | |
// Element type of the buffer filled by the parallel_nd tests: after a correct
// run every element holds its own row-major linear offset.
using data_t = ptrdiff_t;

// Parameter of the N-dimensional tests: the extent of each dimension of the
// iteration space, outermost dimension first.
struct nd_params_t {
    std::vector<ptrdiff_t> dims;
};

// Short alias that keeps the parameter lists compact.
using np_t = nd_params_t;
38 | |
39 | class test_nd_t : public ::testing::TestWithParam<nd_params_t> { |
40 | protected: |
41 | void SetUp() override { |
42 | p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
43 | size = 1; |
44 | for (auto &d : p.dims) |
45 | size *= d; |
46 | data.resize((size_t)size); |
47 | } |
48 | |
49 | void CheckID() { |
50 | for (ptrdiff_t i = 0; i < size; ++i) |
51 | ASSERT_EQ(data[i], i); |
52 | } |
53 | |
54 | nd_params_t p; |
55 | ptrdiff_t size; |
56 | std::vector<data_t> data; |
57 | }; |
58 | |
59 | class test_parallel_nd_t : public test_nd_t { |
60 | protected: |
61 | void emit_parallel_nd() { |
62 | switch ((int)p.dims.size()) { |
63 | case 1: |
64 | impl::parallel_nd(p.dims[0], [&](ptrdiff_t d0) { |
65 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
66 | data[d0] = d0; |
67 | }); |
68 | break; |
69 | case 2: |
70 | impl::parallel_nd( |
71 | p.dims[0], p.dims[1], [&](ptrdiff_t d0, ptrdiff_t d1) { |
72 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
73 | ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]); |
74 | const ptrdiff_t idx = d0 * p.dims[1] + d1; |
75 | data[idx] = idx; |
76 | }); |
77 | break; |
78 | case 3: |
79 | impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], |
80 | [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2) { |
81 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
82 | ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]); |
83 | ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]); |
84 | const ptrdiff_t idx |
85 | = (d0 * p.dims[1] + d1) * p.dims[2] + d2; |
86 | data[idx] = idx; |
87 | }); |
88 | break; |
89 | case 4: |
90 | impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3], |
91 | [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2, |
92 | ptrdiff_t d3) { |
93 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
94 | ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]); |
95 | ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]); |
96 | ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]); |
97 | const ptrdiff_t idx |
98 | = ((d0 * p.dims[1] + d1) * p.dims[2] + d2) |
99 | * p.dims[3] |
100 | + d3; |
101 | data[idx] = idx; |
102 | }); |
103 | break; |
104 | case 5: |
105 | impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3], |
106 | p.dims[4], |
107 | [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2, |
108 | ptrdiff_t d3, ptrdiff_t d4) { |
109 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
110 | ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]); |
111 | ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]); |
112 | ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]); |
113 | ASSERT_TRUE(0 <= d4 && d4 < p.dims[4]); |
114 | const ptrdiff_t idx |
115 | = (((d0 * p.dims[1] + d1) * p.dims[2] + d2) |
116 | * p.dims[3] |
117 | + d3) |
118 | * p.dims[4] |
119 | + d4; |
120 | data[idx] = idx; |
121 | }); |
122 | break; |
123 | case 6: |
124 | impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3], |
125 | p.dims[4], p.dims[5], |
126 | [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2, |
127 | ptrdiff_t d3, ptrdiff_t d4, ptrdiff_t d5) { |
128 | ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]); |
129 | ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]); |
130 | ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]); |
131 | ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]); |
132 | ASSERT_TRUE(0 <= d4 && d4 < p.dims[4]); |
133 | ASSERT_TRUE(0 <= d5 && d5 < p.dims[5]); |
134 | const ptrdiff_t idx |
135 | = ((((d0 * p.dims[1] + d1) * p.dims[2] + d2) |
136 | * p.dims[3] |
137 | + d3) * p.dims[4] |
138 | + d4) |
139 | * p.dims[5] |
140 | + d5; |
141 | data[idx] = idx; |
142 | }); |
143 | break; |
144 | default: ASSERT_TRUE(false); |
145 | } |
146 | } |
147 | }; |
148 | |
149 | TEST_P(test_parallel_nd_t, Test) { |
150 | emit_parallel_nd(); |
151 | CheckID(); |
152 | } |
153 | |
154 | CPU_INSTANTIATE_TEST_SUITE_P(Case, test_parallel_nd_t, |
155 | ::testing::Values(np_t {{0}}, np_t {{1}}, np_t {{100}}, np_t {{0, 0}}, |
156 | np_t {{1, 2}}, np_t {{10, 10}}, np_t {{0, 1, 0}}, |
157 | np_t {{1, 2, 1}}, np_t {{4, 4, 10}}, np_t {{0, 3, 0, 1}}, |
158 | np_t {{1, 1, 2, 1}}, np_t {{4, 4, 5, 2}}, |
159 | np_t {{3, 0, 3, 0, 1}}, np_t {{2, 1, 1, 2, 1}}, |
160 | np_t {{4, 1, 4, 5, 2}}, np_t {{4, 3, 0, 3, 0, 1}}, |
161 | np_t {{2, 1, 3, 1, 2, 1}}, np_t {{4, 1, 4, 3, 2, 2}})); |
162 | |
163 | } // namespace dnnl |
164 | |