1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "../../lib/IR/GraphScheduler.h"
18
19#include "glow/Graph/Graph.h"
20#include "glow/Graph/Node.h"
21#include "glow/Graph/Nodes.h"
22#include "glow/Graph/PlaceholderBindings.h"
23
24#include "gtest/gtest.h"
25
26using namespace glow;
27
28/// Tests a case in which the memory required to store a node's
29/// output is greater than the memory required to store its input.
30/// This node uses more memory after it executes, and should be
31/// scheduled after its siblings that free up memory after
32/// they execute.
33TEST(GraphScheduler, testMaxSizeLessThanResultSize) {
34 Module MD;
35 PlaceholderBindings bindings;
36 auto *smallTensorA =
37 MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "small_1", false);
38 bindings.allocate(smallTensorA);
39 auto *smallTensorB =
40 MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "small_2", false);
41 bindings.allocate(smallTensorB);
42 auto *bigTensor =
43 MD.createPlaceholder(ElemKind::FloatTy, {100, 4, 4}, "big", false);
44 bindings.allocate(bigTensor);
45 Function *F = MD.createFunction("F");
46 Node *transposeBig = F->createTranspose("transposeBig", bigTensor, {0, 2, 1});
47 Node *sliceBig =
48 F->createSlice("sliceBig", transposeBig, {0, 0, 0}, {1, 4, 4});
49 Node *concatSmall =
50 F->createConcat("concatSmall", {smallTensorA, smallTensorB}, 0);
51 F->createConcat("concat", {concatSmall, sliceBig}, 0);
52
53 // The graph created above looks like this:
54 //
55 // bigTensor smallTensorA smallTensorB
56 // {100, 4, 4} {1, 4, 4} {1, 4, 4}
57 // | \ /
58 // v v v
59 // transposeBig {0, 2, 1} concatSmall {0}
60 // {100, 4, 4} {2, 4, 4}
61 // | |
62 // v |
63 // sliceBig {0, 0, 0}, {1, 4, 4} |
64 // {1, 4, 4} |
65 // | |
66 // | |
67 // | |
68 // --------> concat {0} <-------
69 // {3, 4, 4}
70 //
71
72 {
73 // Since all of the tensors are Variables, they don't need
74 // memory for storing their outputs. Consequently, sliceBig
75 // should be scheduled before concatSmall in this example
76 // because the former frees up some memory while the latter
77 // uses up more memory after execution.
78 NodesPtrList schedule;
79 ChildMemSizeBasedScheduler scheduler(*F, schedule);
80 scheduler.schedule();
81
82 // Find the positions of sliceBig and concatSmall in
83 // the schedule.
84 auto concatSmallIt =
85 std::find(schedule.begin(), schedule.end(), concatSmall);
86 auto sliceBigIt = std::find(schedule.begin(), schedule.end(), sliceBig);
87
88 // For the reason given above, sliceBig should be scheduled
89 // before concatSmall.
90 EXPECT_LT(std::distance(schedule.begin(), sliceBigIt),
91 std::distance(schedule.begin(), concatSmallIt));
92 }
93
94 {
95 // The graph will be traversed in post order. The root
96 // node is concat node in this case. Then, concatSmall node
97 // will be visited, since it's the left operand of concat node.
98 // Consequently, sliceBig should be scheduled after concatSmall.
99
100 NodesPtrList schedule;
101 TopologicalSortBasedScheduler scheduler(*F, schedule);
102 scheduler.schedule();
103
104 // Find the positions of sliceBig and concatSmall in
105 // the schedule.
106 auto concatSmallIt =
107 std::find(schedule.begin(), schedule.end(), concatSmall);
108 auto sliceBigIt = std::find(schedule.begin(), schedule.end(), sliceBig);
109
110 // For the reason given above, sliceBig should be scheduled
111 // after concatSmall.
112 EXPECT_GT(std::distance(schedule.begin(), sliceBigIt),
113 std::distance(schedule.begin(), concatSmallIt));
114 }
115}
116
117TEST(GraphScheduler, ScheduleQuantizationProfileRightAfterNodeBeingProfiled) {
118 Module MD;
119 PlaceholderBindings bindings;
120 auto *input1 =
121 MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input1", false);
122 bindings.allocate(input1);
123 auto *input2 =
124 MD.createPlaceholder(ElemKind::FloatTy, {1, 4, 4}, "input2", false);
125 bindings.allocate(input2);
126 Function *F = MD.createFunction("F");
127 Node *add = F->createAdd("add", input1, input2);
128 Node *sub = F->createSub("sub", input1, input2);
129 Node *mul = F->createMul("mul", add, sub);
130 Node *save = F->createSave("save", mul);
131 Node *quantizationProfileAdd =
132 F->createQuantizationProfile(bindings, "qpAdd", add);
133 Node *quantizationProfileSub =
134 F->createQuantizationProfile(bindings, "qpSub", sub);
135
136 // Since all of the tensors are Variables, they don't need
137 // memory for storing their outputs. Consequently, sliceBig
138 // should be scheduled before concatSmall in this example
139 // because the former frees up some memory while the latter
140 // uses up more memory after execution.
141 NodesPtrList schedule;
142 ChildMemSizeBasedScheduler scheduler(*F, schedule);
143 scheduler.schedule();
144
145 // Find the positions of add and quantizationProfileAdd in the schedule.
146 auto addIt = std::find(schedule.begin(), schedule.end(), add);
147 auto qpAddIt =
148 std::find(schedule.begin(), schedule.end(), quantizationProfileAdd);
149 // Expect the quantization profiling node to be scheduled right after the node
150 // being profiled.
151 EXPECT_EQ(++addIt, qpAddIt);
152
153 // Find the positions of sub and quantizationProfileSub in the schedule.
154 auto subIt = std::find(schedule.begin(), schedule.end(), sub);
155 auto qpSubIt =
156 std::find(schedule.begin(), schedule.end(), quantizationProfileSub);
157 // Expect the quantization profiling node to be scheduled right after the node
158 // being profiled.
159 EXPECT_EQ(++subIt, qpSubIt);
160
161 // Expect the save node to be the last in the schedule.
162 EXPECT_EQ(save, schedule.back());
163}
164