1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_TESTS_IMPORTERTESTUTILS_H |
17 | #define GLOW_TESTS_IMPORTERTESTUTILS_H |
18 | |
19 | #include "glow/Base/Tensor.h" |
20 | #include "glow/Graph/Graph.h" |
21 | |
22 | namespace glow { |
23 | |
24 | /// Populate \p result with increasing number following the NCHW format. |
25 | /// E.g.: |
26 | /// N=1: |
27 | /// W=3 |
28 | /// <----> |
29 | /// ^ +---+---+---+ |
30 | /// C=1/ | 0 | 1 | 2 | ^ |
31 | /// / +---+---+---+ | H=3 |
32 | /// v | 3 | 4 | 5 | v |
33 | /// +---+---+---+ |
34 | /// | 6 | 7 | 8 | |
35 | /// +---+---+---+ |
36 | void getNCHWData(Tensor *result, dim_t n, dim_t c, dim_t h, dim_t w) { |
37 | result->reset(ElemKind::FloatTy, {n, c, h, w}); |
38 | auto RH = result->getHandle<>(); |
39 | for (size_t i = 0, e = n * c * h * w; i < e; i++) |
40 | RH.raw(i) = i; |
41 | } |
42 | |
43 | /// Populate \p result with increasing number following the NCTHW format. |
44 | /// E.g.: |
45 | /// N=1, T=2: |
46 | /// T=0 T=1 |
47 | /// W=3 |
48 | /// <----> |
49 | /// ^ +---+---+---+ +----+----+----+ |
50 | /// C=1/ | 0 | 1 | 2 | | 9 | 10 | 11 | ^ |
51 | /// / +---+---+---+ +----+----+----+ | H=3 |
52 | /// v | 3 | 4 | 5 | | 12 | 13 | 14 | v |
53 | /// +---+---+---+ +----+----+----+ |
54 | /// | 6 | 7 | 8 | | 15 | 16 | 17 | |
55 | /// +---+---+---+ +----+----+----+ |
56 | void getNCTHWData(Tensor *result, dim_t n, dim_t c, dim_t t, dim_t h, dim_t w) { |
57 | result->reset(ElemKind::FloatTy, {n, c, t, h, w}); |
58 | auto RH = result->getHandle<>(); |
59 | for (size_t i = 0, e = n * c * t * h * w; i < e; i++) |
60 | RH.raw(i) = i; |
61 | } |
62 | |
63 | /// \returns the number of nodes in \p F of kind \p kind. |
64 | unsigned countNodeKind(Function *F, Kinded::Kind kind) { |
65 | unsigned count = 0; |
66 | for (auto &n : F->getNodes()) { |
67 | if (n.getKind() == kind) { |
68 | count++; |
69 | } |
70 | } |
71 | return count; |
72 | } |
73 | |
74 | /// Helper function to get the save node from a Variable \p var. |
75 | /// \pre (var->getUsers().size() == 1) |
76 | SaveNode *getSaveNodeFromDest(Storage *var) { |
77 | auto &varUsers = var->getUsers(); |
78 | assert(varUsers.size() == 1); |
79 | auto *saveNode = llvm::dyn_cast<SaveNode>(varUsers.begin()->getUser()); |
80 | assert(saveNode != nullptr); |
81 | return saveNode; |
82 | } |
83 | |
84 | } // namespace glow |
85 | |
86 | #endif // GLOW_TESTS_IMPORTERTESTUTILS_H |
87 | |