1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "GraphScheduler.h"
17
18#include "llvm/Support/CommandLine.h"
19
20using namespace glow;
21
22namespace {
23llvm::cl::OptionCategory graphSchedulerCat("Graph Scheduler Options");
24
25llvm::cl::opt<SchedulerKind> graphScheduler(
26 llvm::cl::desc("Scheduler to use:"),
27 llvm::cl::values(clEnumValN(SchedulerKind::ChildMemSizeBased,
28 "child-mem-size-based",
29 "Use ChildMemSizeBased"),
30 clEnumValN(SchedulerKind::TopologicalSortBased,
31 "topological-sort-based",
32 "Use TopologicalSortBased")),
33 llvm::cl::init(SchedulerKind::ChildMemSizeBased),
34 llvm::cl::cat(graphSchedulerCat));
35} // namespace
36
37namespace glow {
38Scheduler *createScheduler(SchedulerKind schedulerKind, Function &G,
39 NodesPtrList &scheduled) {
40 switch (schedulerKind) {
41 case SchedulerKind::ChildMemSizeBased:
42 return new ChildMemSizeBasedScheduler(G, scheduled);
43 case SchedulerKind::TopologicalSortBased:
44 return new TopologicalSortBasedScheduler(G, scheduled);
45 }
46 llvm_unreachable("unreachable");
47}
48
49void IRFunction::scheduleGraph(NodesPtrList &Schedule) {
50 Schedule.clear();
51 auto constants = getGraph()->findConstants();
52 auto placeholders = getGraph()->findPlaceholders();
53 for (auto &N : constants) {
54 Schedule.push_back(N);
55 }
56 for (auto &N : placeholders) {
57 Schedule.push_back(N);
58 }
59 for (auto &N : getGraph()->getMetadataPlaceholders()) {
60 Schedule.push_back(N);
61 }
62 auto numVars = constants.size();
63 auto numPlaceholders =
64 placeholders.size() + getGraph()->getMetadataPlaceholders().size();
65 (void)numVars;
66 (void)numPlaceholders;
67 std::unique_ptr<Scheduler> scheduler{
68 createScheduler(graphScheduler, *getGraph(), Schedule)};
69 scheduler->schedule();
70 assert(scheduler->getSchedule().size() ==
71 getGraph()->getNodes().size() + numPlaceholders + numVars &&
72 "All graph nodes have to be scheduled");
73}
74} // namespace glow
75