1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H
17#define GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H
18
#include "glow/Backend/Backend.h"
#include "glow/Backend/CompiledFunction.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/Base/Train.h"
#include "glow/Base/Traits.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Runtime/HostManager/HostManager.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
32
33namespace glow {
34
35/// This is the ExecutionEngine. It encapsulates the Glow Runtime. It handles
36/// compilation and execution of a network.
37class ExecutionEngine final {
38 /// Module containing the function and supporting information. This is reset
39 /// if the backend type is changed.
40 std::unique_ptr<Module> module_;
41
42 /// Raw pointer to module_ this is to support module access after the module
43 /// has been added to hostManager_.
44 Module *rawModule_;
45
46 /// Name of the backend being used for compilation and execution. Changing
47 /// this resets the ExecutionEngine.
48 std::string backendName_ = "";
49
50 /// Size of device memory in bytes, if 0 device default is used.
51 uint64_t deviceMemory_{0};
52
53 /// Whether to ignore the user-specified DeviceConfig.
54 bool ignoreUserDeviceConfig_{false};
55
56 /// The HostManager for executing the compiled functions.
57 std::unique_ptr<runtime::HostManager> hostManager_;
58
59 /// Glow functions compiled for this ExecutionEngine's backend.
60 std::set<std::string> compiledFunctions_;
61
62 /// Whether to move all Device Resident Tensors on to the host at the end of
63 /// the run.
64 bool ensureOutputsOnHost_{true};
65
66 /// Whether to override the cctx's skipModuleStrip setting and skip stripping
67 /// the module. Used for testing purposes.
68 bool skipModuleStrip_{false};
69
70 /// Whether to allow multiple functions when running. This is usually due to
71 /// running a pre-partitioned model.
72 bool allowMultiFunction_{false};
73
74 /// Single execution of the given function, \p name with the given context
75 /// \bindings.
76 void runInternal(ExecutionContext &context, llvm::StringRef name);
77
78public:
79 /// Constructor for an ExecutionEngine with \p backend and memory \p
80 /// deviceMemory in bytes. If \p ignoreUserDeviceConfig then user device
81 /// configs will be ignored. \p numDevices controls how many devices to create
82 /// for the EE.
83 ExecutionEngine(llvm::StringRef backend = "Interpreter",
84 uint64_t deviceMemory = 0,
85 bool ignoreUserDeviceConfig = false, unsigned numDevices = 1);
86
87 ~ExecutionEngine();
88
89 /// Set the code generator to \p backend. New code will be generated
90 /// using this backend. This clears all previously loaded functions and resets
91 /// the Module. \p numDevices controls how many devices to create for the EE.
92 void setBackendName(llvm::StringRef backend, size_t numDevices = 1);
93
94 /// Set the device memory to \p mem. This will reset the existing device,
95 /// clearing all existing functions and resetting the module.
96 void setDeviceMemory(uint64_t mem) {
97 deviceMemory_ = mem;
98 setBackendName(backendName_);
99 }
100
101 // Set whether or not to ensure outputs are in host memory.
102 void ensureOutputsOnHost(bool should) { ensureOutputsOnHost_ = should; }
103
104 /// Get the name of the current backend in use.
105 llvm::StringRef getBackendName() const;
106
107 /// \returns the internal graph. Note: After compilation the contents of the
108 /// module will have been altered and raw pointers to elements of the graph
109 /// may no longer be valid.
110 Module &getModule() const { return *rawModule_; }
111
112 /// Clears the ExecutionEngine and all CompiledFunctions.
113 void clear();
114
115 /// \returns the DAG for the specified \p network.
116 Expected<runtime::DAG *> getDAG(llvm::StringRef network) {
117 return hostManager_->getNetworkDAG(network);
118 }
119
120 /// Compiles all functions in the Module with the given \p cctx. This method
121 /// should be invoked before the run method and can only be called once
122 /// without resetting the backend.
123 void compile(CompilationContext &cctx);
124
125 /// A convenience function for the most common type of compile. Can only be
126 /// called once without resetting the backend.
127 void compile(CompilationMode mode);
128
129 /// Context aware single execution of a function. If more than one
130 /// function has been compiled by this ExecutionEngine then a name must be
131 /// supplied to specify which function to run.
132 void run(ExecutionContext &context);
133
134 /// Context aware single execution of a function with the given \p
135 /// name.
136 void run(ExecutionContext &context, llvm::StringRef name);
137
138 /// Context aware single execution of a function. If more than one
139 /// function has been compiled by this ExecutionEngine then a name must be
140 /// supplied to specify which function to run.
141 void run(PlaceholderBindings &bindings);
142
143 /// Context aware single execution of a function with the given \p
144 /// name.
145 void run(PlaceholderBindings &bindings, llvm::StringRef name);
146
147 /// \returns a reference to the backend with name \p backendName owned by the
148 /// Provisioner inside of \ref hostManager_.
149 Backend &getBackend(llvm::StringRef backendName) const;
150
151 /// \returns a reference to the backend with name of the current backend in
152 /// use by the EE.
153 Backend &getBackend() const;
154
155 /// \returns the single Function contained in this Module.
156 /// \pre Must be a single Function in the Module.
157 Function *getSingleFunctionFromModule() const;
158
159 /// Setter for \ref skipModuleStrip_ to \p b.
160 void setSkipModuleStrip(bool b) { skipModuleStrip_ = b; }
161};
162
163//===----------------------------------------------------------------------===//
164// Helper methods for running the execution engine.
165//===----------------------------------------------------------------------===//
166
167/// This method updates the placeholders in \p ph with the tensor content
168/// values \p inputs, in \p bindings.
169void updateInputPlaceholders(PlaceholderBindings &bindings,
170 llvm::ArrayRef<Placeholder *> ph,
171 llvm::ArrayRef<Tensor *> inputs);
172
173/// This method updates the placeholders in the module. The placeholders are
174/// found by name
175/// in \p ph with the tensor content values \p inputs.
176void updateInputPlaceholdersByName(PlaceholderBindings &bindings, Module *mod,
177 llvm::ArrayRef<llvm::StringRef> ph,
178 llvm::ArrayRef<Tensor *> inputs);
179
180/// Runs \p iterations iterations of the compiled function. The method updates a
181/// global counter and future invocations of this method continue running
182/// iterations of the batch at the next available slice.
183///
184/// The method updates the placeholder in \p ph with the tensors \p inputs. The
185/// shape of the slice has to be identical to the shape of slices in the batch.
186/// All dimensions, except for the first (batch) dimension must be identical.
187///
188/// The variable \p sampleCounter is consumed and updated by the function. This
189/// variable records the number of samples that were consumed by the network in
190/// previous iterations. The next input to be loaded is
191/// (sampleCounter % batchsize). If there is more than one compiledFunction \p
192/// name must be provided to specify the desired function.
193void runBatch(ExecutionEngine &EE, PlaceholderBindings &bindings,
194 size_t iterations, size_t &sampleCounter,
195 llvm::ArrayRef<Placeholder *> ph, llvm::ArrayRef<Tensor *> inputs,
196 llvm::StringRef name = "");
197
198/// Runs \p numMinibatchRuns iterations of the compiled function called \p name.
199/// The method updates a global counter and future invocations of this method
200/// continue running iterations of the batch at the next available slice.
201/// The provided callback function \p cb is invoked on each sample.
202void evalBatch(
203 ExecutionEngine &EE, PlaceholderBindings &bindings, size_t numMinibatchRuns,
204 size_t &sampleCounter, Placeholder *inputPH, Placeholder *outputPH,
205 Tensor &samplesInput, Tensor &labelsInput, llvm::StringRef name,
206 std::function<void(const Tensor &sampleIn, const Tensor &sampleOut,
207 const Tensor &label, size_t sampleIndex)> &&cb);
208} // namespace glow
209
210#endif // GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H
211