1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BACKENDS_EXECUTIONCONTEXT_H
17#define GLOW_BACKENDS_EXECUTIONCONTEXT_H
18
19#include "glow/ExecutionContext/TraceEvents.h"
20#include "glow/Graph/PlaceholderBindings.h"
21
22#include "llvm/ADT/STLExtras.h"
23
24namespace glow {
namespace runtime {
/// Forward declaration only: ExecutionContext below stores a non-owning
/// `runtime::DeviceManager *`, so the complete type is not needed here.
class DeviceManager;
} // namespace runtime
28
29/// Sub-classed per backend, this holds Device specific per-function information
30/// if that is necessary on that particular backend.
31class DeviceBindings {
32 const std::string backend_;
33
34public:
35 DeviceBindings(llvm::StringRef backend) : backend_{backend} {}
36 virtual ~DeviceBindings() {}
37
38 virtual std::unique_ptr<DeviceBindings> clone() {
39 return glow::make_unique<DeviceBindings>(backend_);
40 }
41};
42
43/// The runtime context for a single execution (Inferance or Training) in the
44/// the Glow Execution Engine or HostManager. This class includes the mapping
45/// between Input/Output Placeholders and the materialized Tensors used for this
46/// run, the set of Device specific details required to execute the function,
47/// and stores TraceEvents that were generated as a result of the run.
48class ExecutionContext {
49 std::unique_ptr<PlaceholderBindings> placeholderBindings_;
50 std::unique_ptr<DeviceBindings> deviceBindings_;
51
52 /// Pointer to DeviceManager this context is bound to, use for P2P/DRT
53 /// enablement. Unused otherwise.
54 runtime::DeviceManager *boundDeviceManager_{nullptr};
55
56 /// Trace Events recorded during this run.
57 std::unique_ptr<TraceContext> traceContext_;
58
59 /// Positional bindings for external inputs/outputs
60 std::vector<std::pair<Placeholder *, Tensor>> externalIOBindings_;
61
62public:
63 ExecutionContext()
64 : placeholderBindings_(glow::make_unique<PlaceholderBindings>()) {}
65
66 ExecutionContext(std::unique_ptr<PlaceholderBindings> bindings)
67 : placeholderBindings_(std::move(bindings)) {}
68
69 ExecutionContext(std::unique_ptr<PlaceholderBindings> bindings,
70 std::unique_ptr<DeviceBindings> devices)
71 : placeholderBindings_(std::move(bindings)),
72 deviceBindings_(std::move(devices)) {}
73
74 /// \returns positional bindings for external inputs
75 std::vector<std::pair<Placeholder *, Tensor>> &getExternalIOBindings() {
76 return externalIOBindings_;
77 }
78
79 /// \returns positional bindings for external inputs
80 const std::vector<std::pair<Placeholder *, Tensor>> &
81 getExternalIOBindings() const {
82 return externalIOBindings_;
83 }
84
85 /// \returns a non-owning pointer to the PlaceholderBindings.
86 PlaceholderBindings *getPlaceholderBindings() {
87 return placeholderBindings_.get();
88 }
89
90 /// \returns a const non-owning pointer to the PlaceholderBindings.
91 const PlaceholderBindings *getPlaceholderBindings() const {
92 return placeholderBindings_.get();
93 }
94
95 /// \returns an owning pointer to the PlaceholderBindings.
96 std::unique_ptr<PlaceholderBindings> movePlaceholderBindings() {
97 return std::move(placeholderBindings_);
98 }
99
100 /// \returns a non-owning pointer to the DeviceBindings.
101 DeviceBindings *getDeviceBindings() { return deviceBindings_.get(); }
102
103 /// \returns a const non-owning pointer to the DeviceBindings.
104 const DeviceBindings *getDeviceBindings() const {
105 return deviceBindings_.get();
106 }
107
108 /// \returns a non-owning pointer the the deviceManager this context is bound
109 /// to.
110 runtime::DeviceManager *getBoundDeviceManager() {
111 return boundDeviceManager_;
112 }
113
114 /// Sets which device this context is bound to. NOTE this should not be
115 /// changed once set.
116 void setBoundDeviceManager(runtime::DeviceManager *device) {
117 DCHECK(boundDeviceManager_ == nullptr);
118 boundDeviceManager_ = device;
119 }
120
121 /// Sets the DeviceBindings and \returns the existing value.
122 std::unique_ptr<DeviceBindings>
123 setDeviceBindings(std::unique_ptr<DeviceBindings> bindings) {
124 std::swap(deviceBindings_, bindings);
125 return bindings;
126 }
127
128 /// \returns a non-owning pointer to the TraceContext.
129 TraceContext *getTraceContext() { return traceContext_.get(); }
130
131 /// \returns a const non-owning pointer to the TraceContext.
132 const TraceContext *getTraceContext() const { return traceContext_.get(); }
133
134 /// Sets the TraceContext and \returns the existing value.
135 std::unique_ptr<TraceContext>
136 setTraceContext(std::unique_ptr<TraceContext> traceContext) {
137 std::swap(traceContext_, traceContext);
138 return traceContext;
139 }
140
141 /// Clones this ExecutionContext, but does not clone underlying Tensors.
142 ExecutionContext clone() {
143 if (deviceBindings_) {
144 return ExecutionContext(
145 glow::make_unique<PlaceholderBindings>(placeholderBindings_->clone()),
146 deviceBindings_->clone());
147 } else {
148 return ExecutionContext(glow::make_unique<PlaceholderBindings>(
149 placeholderBindings_->clone()));
150 }
151 }
152
153 /// A helper function to create a scoped TraceEvent builder.
154 /// If there is no TraceContext, this will still create an object, but it will
155 /// do nothing.
156 ScopedTraceBlock scopedEvent(llvm::StringRef name, TraceLevel level) {
157 return ScopedTraceBlock(getTraceContext(), level, name);
158 }
159
160 /// A helper function to log a TraceEvent at the current time, if there is a
161 /// TraceContext available.
162 void logTraceEvent(llvm::StringRef name, TraceLevel level,
163 char type = TraceEvent::InstantType,
164 std::map<std::string, std::string> args = {}) {
165 TraceContext *traceContext = getTraceContext();
166 if (traceContext) {
167 traceContext->logTraceEvent(name, level, type, std::move(args));
168 }
169 }
170};
171
172} // namespace glow
173
174#endif // GLOW_BACKENDS_EXECUTIONCONTEXT_H
175