1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_BACKENDS_EXECUTIONCONTEXT_H |
17 | #define GLOW_BACKENDS_EXECUTIONCONTEXT_H |
18 | |
19 | #include "glow/ExecutionContext/TraceEvents.h" |
20 | #include "glow/Graph/PlaceholderBindings.h" |
21 | |
22 | #include "llvm/ADT/STLExtras.h" |
23 | |
24 | namespace glow { |
namespace runtime {
/// Forward declaration only: this header stores and returns DeviceManager
/// pointers but never uses the complete type, so the full definition is not
/// needed here.
class DeviceManager;
} // namespace runtime
28 | |
29 | /// Sub-classed per backend, this holds Device specific per-function information |
30 | /// if that is necessary on that particular backend. |
31 | class DeviceBindings { |
32 | const std::string backend_; |
33 | |
34 | public: |
35 | DeviceBindings(llvm::StringRef backend) : backend_{backend} {} |
36 | virtual ~DeviceBindings() {} |
37 | |
38 | virtual std::unique_ptr<DeviceBindings> clone() { |
39 | return glow::make_unique<DeviceBindings>(backend_); |
40 | } |
41 | }; |
42 | |
43 | /// The runtime context for a single execution (Inferance or Training) in the |
44 | /// the Glow Execution Engine or HostManager. This class includes the mapping |
45 | /// between Input/Output Placeholders and the materialized Tensors used for this |
46 | /// run, the set of Device specific details required to execute the function, |
47 | /// and stores TraceEvents that were generated as a result of the run. |
48 | class ExecutionContext { |
49 | std::unique_ptr<PlaceholderBindings> placeholderBindings_; |
50 | std::unique_ptr<DeviceBindings> deviceBindings_; |
51 | |
52 | /// Pointer to DeviceManager this context is bound to, use for P2P/DRT |
53 | /// enablement. Unused otherwise. |
54 | runtime::DeviceManager *boundDeviceManager_{nullptr}; |
55 | |
56 | /// Trace Events recorded during this run. |
57 | std::unique_ptr<TraceContext> traceContext_; |
58 | |
59 | /// Positional bindings for external inputs/outputs |
60 | std::vector<std::pair<Placeholder *, Tensor>> externalIOBindings_; |
61 | |
62 | public: |
63 | ExecutionContext() |
64 | : placeholderBindings_(glow::make_unique<PlaceholderBindings>()) {} |
65 | |
66 | ExecutionContext(std::unique_ptr<PlaceholderBindings> bindings) |
67 | : placeholderBindings_(std::move(bindings)) {} |
68 | |
69 | ExecutionContext(std::unique_ptr<PlaceholderBindings> bindings, |
70 | std::unique_ptr<DeviceBindings> devices) |
71 | : placeholderBindings_(std::move(bindings)), |
72 | deviceBindings_(std::move(devices)) {} |
73 | |
74 | /// \returns positional bindings for external inputs |
75 | std::vector<std::pair<Placeholder *, Tensor>> &getExternalIOBindings() { |
76 | return externalIOBindings_; |
77 | } |
78 | |
79 | /// \returns positional bindings for external inputs |
80 | const std::vector<std::pair<Placeholder *, Tensor>> & |
81 | getExternalIOBindings() const { |
82 | return externalIOBindings_; |
83 | } |
84 | |
85 | /// \returns a non-owning pointer to the PlaceholderBindings. |
86 | PlaceholderBindings *getPlaceholderBindings() { |
87 | return placeholderBindings_.get(); |
88 | } |
89 | |
90 | /// \returns a const non-owning pointer to the PlaceholderBindings. |
91 | const PlaceholderBindings *getPlaceholderBindings() const { |
92 | return placeholderBindings_.get(); |
93 | } |
94 | |
95 | /// \returns an owning pointer to the PlaceholderBindings. |
96 | std::unique_ptr<PlaceholderBindings> movePlaceholderBindings() { |
97 | return std::move(placeholderBindings_); |
98 | } |
99 | |
100 | /// \returns a non-owning pointer to the DeviceBindings. |
101 | DeviceBindings *getDeviceBindings() { return deviceBindings_.get(); } |
102 | |
103 | /// \returns a const non-owning pointer to the DeviceBindings. |
104 | const DeviceBindings *getDeviceBindings() const { |
105 | return deviceBindings_.get(); |
106 | } |
107 | |
108 | /// \returns a non-owning pointer the the deviceManager this context is bound |
109 | /// to. |
110 | runtime::DeviceManager *getBoundDeviceManager() { |
111 | return boundDeviceManager_; |
112 | } |
113 | |
114 | /// Sets which device this context is bound to. NOTE this should not be |
115 | /// changed once set. |
116 | void setBoundDeviceManager(runtime::DeviceManager *device) { |
117 | DCHECK(boundDeviceManager_ == nullptr); |
118 | boundDeviceManager_ = device; |
119 | } |
120 | |
121 | /// Sets the DeviceBindings and \returns the existing value. |
122 | std::unique_ptr<DeviceBindings> |
123 | setDeviceBindings(std::unique_ptr<DeviceBindings> bindings) { |
124 | std::swap(deviceBindings_, bindings); |
125 | return bindings; |
126 | } |
127 | |
128 | /// \returns a non-owning pointer to the TraceContext. |
129 | TraceContext *getTraceContext() { return traceContext_.get(); } |
130 | |
131 | /// \returns a const non-owning pointer to the TraceContext. |
132 | const TraceContext *getTraceContext() const { return traceContext_.get(); } |
133 | |
134 | /// Sets the TraceContext and \returns the existing value. |
135 | std::unique_ptr<TraceContext> |
136 | setTraceContext(std::unique_ptr<TraceContext> traceContext) { |
137 | std::swap(traceContext_, traceContext); |
138 | return traceContext; |
139 | } |
140 | |
141 | /// Clones this ExecutionContext, but does not clone underlying Tensors. |
142 | ExecutionContext clone() { |
143 | if (deviceBindings_) { |
144 | return ExecutionContext( |
145 | glow::make_unique<PlaceholderBindings>(placeholderBindings_->clone()), |
146 | deviceBindings_->clone()); |
147 | } else { |
148 | return ExecutionContext(glow::make_unique<PlaceholderBindings>( |
149 | placeholderBindings_->clone())); |
150 | } |
151 | } |
152 | |
153 | /// A helper function to create a scoped TraceEvent builder. |
154 | /// If there is no TraceContext, this will still create an object, but it will |
155 | /// do nothing. |
156 | ScopedTraceBlock scopedEvent(llvm::StringRef name, TraceLevel level) { |
157 | return ScopedTraceBlock(getTraceContext(), level, name); |
158 | } |
159 | |
160 | /// A helper function to log a TraceEvent at the current time, if there is a |
161 | /// TraceContext available. |
162 | void logTraceEvent(llvm::StringRef name, TraceLevel level, |
163 | char type = TraceEvent::InstantType, |
164 | std::map<std::string, std::string> args = {}) { |
165 | TraceContext *traceContext = getTraceContext(); |
166 | if (traceContext) { |
167 | traceContext->logTraceEvent(name, level, type, std::move(args)); |
168 | } |
169 | } |
170 | }; |
171 | |
172 | } // namespace glow |
173 | |
174 | #endif // GLOW_BACKENDS_EXECUTIONCONTEXT_H |
175 | |