1/*******************************************************************************
2* Copyright 2018-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <vector>
18
19#include "dnnl_test_common.hpp"
20#include "gtest/gtest.h"
21
22namespace dnnl {
23
24TEST(test_parallel, Test) {
25 impl::parallel(0, [&](int ithr, int nthr) {
26 ASSERT_LE(0, ithr);
27 ASSERT_LT(ithr, nthr);
28 ASSERT_LE(nthr, dnnl_get_max_threads());
29 });
30}
31
// Element type of the test buffer; it equals the index type so that each
// element can hold its own linear offset.
using data_t = ptrdiff_t;

// One parallel_nd test case: the extent of every dimension, outermost first.
struct nd_params_t {
    std::vector<data_t> dims;
};

// Short alias used when spelling out the test-case list.
using np_t = nd_params_t;
38
39class test_nd_t : public ::testing::TestWithParam<nd_params_t> {
40protected:
41 void SetUp() override {
42 p = ::testing::TestWithParam<decltype(p)>::GetParam();
43 size = 1;
44 for (auto &d : p.dims)
45 size *= d;
46 data.resize((size_t)size);
47 }
48
49 void CheckID() {
50 for (ptrdiff_t i = 0; i < size; ++i)
51 ASSERT_EQ(data[i], i);
52 }
53
54 nd_params_t p;
55 ptrdiff_t size;
56 std::vector<data_t> data;
57};
58
59class test_parallel_nd_t : public test_nd_t {
60protected:
61 void emit_parallel_nd() {
62 switch ((int)p.dims.size()) {
63 case 1:
64 impl::parallel_nd(p.dims[0], [&](ptrdiff_t d0) {
65 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
66 data[d0] = d0;
67 });
68 break;
69 case 2:
70 impl::parallel_nd(
71 p.dims[0], p.dims[1], [&](ptrdiff_t d0, ptrdiff_t d1) {
72 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
73 ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]);
74 const ptrdiff_t idx = d0 * p.dims[1] + d1;
75 data[idx] = idx;
76 });
77 break;
78 case 3:
79 impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2],
80 [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2) {
81 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
82 ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]);
83 ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]);
84 const ptrdiff_t idx
85 = (d0 * p.dims[1] + d1) * p.dims[2] + d2;
86 data[idx] = idx;
87 });
88 break;
89 case 4:
90 impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3],
91 [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2,
92 ptrdiff_t d3) {
93 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
94 ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]);
95 ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]);
96 ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]);
97 const ptrdiff_t idx
98 = ((d0 * p.dims[1] + d1) * p.dims[2] + d2)
99 * p.dims[3]
100 + d3;
101 data[idx] = idx;
102 });
103 break;
104 case 5:
105 impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3],
106 p.dims[4],
107 [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2,
108 ptrdiff_t d3, ptrdiff_t d4) {
109 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
110 ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]);
111 ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]);
112 ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]);
113 ASSERT_TRUE(0 <= d4 && d4 < p.dims[4]);
114 const ptrdiff_t idx
115 = (((d0 * p.dims[1] + d1) * p.dims[2] + d2)
116 * p.dims[3]
117 + d3)
118 * p.dims[4]
119 + d4;
120 data[idx] = idx;
121 });
122 break;
123 case 6:
124 impl::parallel_nd(p.dims[0], p.dims[1], p.dims[2], p.dims[3],
125 p.dims[4], p.dims[5],
126 [&](ptrdiff_t d0, ptrdiff_t d1, ptrdiff_t d2,
127 ptrdiff_t d3, ptrdiff_t d4, ptrdiff_t d5) {
128 ASSERT_TRUE(0 <= d0 && d0 < p.dims[0]);
129 ASSERT_TRUE(0 <= d1 && d1 < p.dims[1]);
130 ASSERT_TRUE(0 <= d2 && d2 < p.dims[2]);
131 ASSERT_TRUE(0 <= d3 && d3 < p.dims[3]);
132 ASSERT_TRUE(0 <= d4 && d4 < p.dims[4]);
133 ASSERT_TRUE(0 <= d5 && d5 < p.dims[5]);
134 const ptrdiff_t idx
135 = ((((d0 * p.dims[1] + d1) * p.dims[2] + d2)
136 * p.dims[3]
137 + d3) * p.dims[4]
138 + d4)
139 * p.dims[5]
140 + d5;
141 data[idx] = idx;
142 });
143 break;
144 default: ASSERT_TRUE(false);
145 }
146 }
147};
148
149TEST_P(test_parallel_nd_t, Test) {
150 emit_parallel_nd();
151 CheckID();
152}
153
154CPU_INSTANTIATE_TEST_SUITE_P(Case, test_parallel_nd_t,
155 ::testing::Values(np_t {{0}}, np_t {{1}}, np_t {{100}}, np_t {{0, 0}},
156 np_t {{1, 2}}, np_t {{10, 10}}, np_t {{0, 1, 0}},
157 np_t {{1, 2, 1}}, np_t {{4, 4, 10}}, np_t {{0, 3, 0, 1}},
158 np_t {{1, 1, 2, 1}}, np_t {{4, 4, 5, 2}},
159 np_t {{3, 0, 3, 0, 1}}, np_t {{2, 1, 1, 2, 1}},
160 np_t {{4, 1, 4, 5, 2}}, np_t {{4, 3, 0, 3, 0, 1}},
161 np_t {{2, 1, 3, 1, 2, 1}}, np_t {{4, 1, 4, 3, 2, 2}}));
162
163} // namespace dnnl
164