1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_TESTS_IMPORTERTESTUTILS_H
17#define GLOW_TESTS_IMPORTERTESTUTILS_H
18
19#include "glow/Base/Tensor.h"
20#include "glow/Graph/Graph.h"
21
22namespace glow {
23
24/// Populate \p result with increasing number following the NCHW format.
25/// E.g.:
26/// N=1:
27/// W=3
28/// <---->
29/// ^ +---+---+---+
30/// C=1/ | 0 | 1 | 2 | ^
31/// / +---+---+---+ | H=3
32/// v | 3 | 4 | 5 | v
33/// +---+---+---+
34/// | 6 | 7 | 8 |
35/// +---+---+---+
36void getNCHWData(Tensor *result, dim_t n, dim_t c, dim_t h, dim_t w) {
37 result->reset(ElemKind::FloatTy, {n, c, h, w});
38 auto RH = result->getHandle<>();
39 for (size_t i = 0, e = n * c * h * w; i < e; i++)
40 RH.raw(i) = i;
41}
42
43/// Populate \p result with increasing number following the NCTHW format.
44/// E.g.:
45/// N=1, T=2:
46/// T=0 T=1
47/// W=3
48/// <---->
49/// ^ +---+---+---+ +----+----+----+
50/// C=1/ | 0 | 1 | 2 | | 9 | 10 | 11 | ^
51/// / +---+---+---+ +----+----+----+ | H=3
52/// v | 3 | 4 | 5 | | 12 | 13 | 14 | v
53/// +---+---+---+ +----+----+----+
54/// | 6 | 7 | 8 | | 15 | 16 | 17 |
55/// +---+---+---+ +----+----+----+
56void getNCTHWData(Tensor *result, dim_t n, dim_t c, dim_t t, dim_t h, dim_t w) {
57 result->reset(ElemKind::FloatTy, {n, c, t, h, w});
58 auto RH = result->getHandle<>();
59 for (size_t i = 0, e = n * c * t * h * w; i < e; i++)
60 RH.raw(i) = i;
61}
62
63/// \returns the number of nodes in \p F of kind \p kind.
64unsigned countNodeKind(Function *F, Kinded::Kind kind) {
65 unsigned count = 0;
66 for (auto &n : F->getNodes()) {
67 if (n.getKind() == kind) {
68 count++;
69 }
70 }
71 return count;
72}
73
74/// Helper function to get the save node from a Variable \p var.
75/// \pre (var->getUsers().size() == 1)
76SaveNode *getSaveNodeFromDest(Storage *var) {
77 auto &varUsers = var->getUsers();
78 assert(varUsers.size() == 1);
79 auto *saveNode = llvm::dyn_cast<SaveNode>(varUsers.begin()->getUser());
80 assert(saveNode != nullptr);
81 return saveNode;
82}
83
84} // namespace glow
85
86#endif // GLOW_TESTS_IMPORTERTESTUTILS_H
87