/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace tensorflow {

class StepStatsCollector;

// Executor runs a graph computation.
// Example:
//   Graph* graph = ...;
//   ... construct graph ...
//   Executor* executor;
//   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
//   Rendezvous* rendezvous = NewNaiveRendezvous();
//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
//   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
//   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
//   ... ...
//
// Multiple threads can call Executor::Run concurrently.
class Executor {
 public:
  virtual ~Executor() {}

  // RunAsync() executes the graph computation. "done" is run when the
  // graph computation completes. If any error happens during the
  // computation, "done" is run and the error is passed to "done".
  //
  // RunAsync() is given a few arguments in Args. The caller must
  // ensure objects passed in Args (rendezvous, stats_collector, etc.)
  // are alive at least until "done" is invoked. All pointers to the
  // argument objects can be nullptr.
  //
  // "step_id" is a process-wide unique identifier for the step being
  // run. Executors on different devices may receive the same step_id
  // in the case that a step runs Ops on more than one device. The
  // step_id is used for tracking resource usage of a given step.
  //
  // RunAsync() uses the given "rendezvous", if not null, as the
  // mechanism to communicate inputs and outputs of the underlying
  // graph computation.
  //
  // RunAsync() calls "stats_collector", if not null, to keep track of
  // stats. This allows us to collect statistics and traces on demand.
  //
  // If the executor is used to run a function, RunAsync() uses the given
  // "call_frame", if not null, to pass arguments and return values
  // between the caller and the callee.
  //
  // RunAsync() uses "cancellation_manager", if not nullptr, to
  // register callbacks that should be called if the graph computation
  // is canceled. Note that the callbacks merely unblock any
  // long-running computation, and a canceled step will terminate by
  // returning/calling the DoneCallback as usual.
  //
  // RunAsync() dispatches closures to "runner". Typically, "runner"
  // is backed by a bounded threadpool.
  //
  // "start_time_usecs" is a timestamp for the start of RunAsync()
  // execution. Used for system-wide latency metrics.
  struct Args {
    int64_t step_id = 0;
    RendezvousInterface* rendezvous = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CallFrameInterface* call_frame = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    SessionState* session_state = nullptr;
    // Unique session identifier. Can be empty.
    string session_handle;
    TensorStore* tensor_store = nullptr;
    ScopedStepContainer* step_container = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
    CoordinationServiceAgent* coordination_service_agent = nullptr;
    int64_t start_time_usecs = 0;
    // The deadline for the kernel to complete by. Empty if unspecified.
    absl::optional<absl::Time> deadline;
    absl::optional<ManagedStackTrace> stack_trace = absl::nullopt;

    // If true, calls Sync() on the device.
    bool sync_on_finish = false;

    typedef std::function<void()> Closure;
    typedef std::function<void(Closure)> Runner;
    Runner runner = nullptr;

    // If true, all kernels will be treated as "inexpensive", and hence
    // executed on the scheduling thread.
    bool run_all_kernels_inline = false;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void RunAsync(const Args& args, DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync().
  virtual Status Run(const Args& args) {
    Status ret;
    Notification n;
    RunAsync(args, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};
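
// A minimal usage sketch (not part of the API above): wiring Args::runner to a
// bounded thread pool and calling RunAsync(). The "pool", "executor", and
// "rendezvous" objects are assumed to be created elsewhere by the caller.
//
//   thread::ThreadPool* pool = ...;  // caller-owned, bounded threadpool
//   Executor::Args args;
//   args.step_id = 1;
//   args.rendezvous = rendezvous;  // may be nullptr if the graph needs no I/O
//   args.runner = [pool](Executor::Args::Closure c) {
//     pool->Schedule(std::move(c));  // dispatch executor closures to the pool
//   };
//   Notification done;
//   Status run_status;
//   executor->RunAsync(args, [&run_status, &done](const Status& s) {
//     run_status = s;
//     done.Notify();
//   });
//   done.WaitForNotification();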

// Creates an Executor that computes the given "graph".
//
// If successful, returns the constructed executor in "*executor". Otherwise,
// returns an error status.
//
143// "params" provides a set of context for the executor. We expect that
144// different context would provide different implementations.
145::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
146 const Graph& graph, Executor** executor);
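
// A minimal sketch of constructing a local executor. The LocalExecutorParams
// fields used below (device, function_library, create_kernel, delete_kernel)
// are assumptions about the struct declared in local_executor_params.h and are
// shown for illustration only.
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = flib;
//   params.create_kernel =
//       [device, flib, graph_def_version](
//           const std::shared_ptr<const NodeProperties>& props,
//           OpKernel** kernel) {
//         return CreateNonCachedKernel(device, flib, props, graph_def_version,
//                                      kernel);
//       };
//   params.delete_kernel = [](OpKernel* kernel) {
//     DeleteNonCachedKernel(kernel);
//   };
//   Executor* executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph, &executor));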

// A class to help run multiple executors in parallel and wait until
// all of them are complete.
//
// ExecutorBarrier deletes itself after the function returned by Get()
// is called.
class ExecutorBarrier {
 public:
  typedef std::function<void(const Status&)> StatusCallback;

  // Create an ExecutorBarrier for 'num' different executors.
  //
  // 'r' is the shared Rendezvous object that is used to communicate
  // state. If any of the executors experiences an error, the
  // rendezvous object will be aborted exactly once.
  //
  // 'done' is called after the last executor completes, and
  // ExecutorBarrier is deleted.
  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
      : rendez_(r), done_cb_(done), pending_(num) {}

  ~ExecutorBarrier() {}

  // Returns a closure that Executors must call when they are done
  // computing, passing the status of their execution as an argument.
  StatusCallback Get() {
    return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
  }

 private:
  Rendezvous* rendez_ = nullptr;
  StatusCallback done_cb_ = nullptr;

  mutable mutex mu_;
  int pending_ TF_GUARDED_BY(mu_) = 0;
  StatusGroup status_group_ TF_GUARDED_BY(mu_);

  void WhenDone(const Status& s) {
    Rendezvous* error_rendez = nullptr;
    StatusCallback done = nullptr;
    Status status;

    {
      mutex_lock l(mu_);

      // If we are the first error encountered, trigger an abort of the
      // Rendezvous object by this thread only.
      if (status_group_.ok() && !s.ok()) {
        error_rendez = rendez_;
        error_rendez->Ref();
      }

      if (!s.ok() && !StatusGroup::IsDerived(s) &&
          !status_group_.HasLogMessages()) {
        status_group_.AttachLogMessages();
      }

      status_group_.Update(s);

      // If this is the last call to WhenDone, call the final callback
      // below.
      if (--pending_ == 0) {
        CHECK(done_cb_ != nullptr);
        std::swap(done, done_cb_);
        status = status_group_.as_summary_status();
      }
    }

    if (error_rendez != nullptr) {
      error_rendez->StartAbort(
          errors::Aborted("Stopping remaining executors."));
      error_rendez->Unref();
    }

    if (done != nullptr) {
      delete this;
      if (!status.ok()) {
        VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
      }
      done(status);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
};
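
// A minimal usage sketch: fanning several executors out through one
// ExecutorBarrier. The "executors" and "rendezvous" objects and the
// per-executor "MakeArgsFor()" helper are hypothetical, shown only to
// illustrate how Get() is used as the RunAsync() done callback.
//
//   ExecutorBarrier* barrier = new ExecutorBarrier(
//       executors.size(), rendezvous,
//       [](const Status& s) { /* all executors are done; s is the summary */ });
//   for (Executor* e : executors) {
//     Executor::Args args = MakeArgsFor(e);  // hypothetical per-executor setup
//     e->RunAsync(args, barrier->Get());     // barrier deletes itself after
//                                            // the last callback fires
//   }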

// A few helpers to facilitate creating and deleting kernels.

// Creates a kernel based on "props" on device "device". The kernel can
// access the functions in the "flib". The caller takes ownership of the
// returned "*kernel".
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const std::shared_ptr<const NodeProperties>& props,
                             int graph_def_version, OpKernel** kernel);

// Deletes "kernel" returned by CreateNonCachedKernel.
void DeleteNonCachedKernel(OpKernel* kernel);
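
// A minimal sketch of the create/delete pairing. The "device", "flib",
// "props", and "graph" objects are assumed to come from the caller's
// environment; this is illustration only, not part of the API contract.
//
//   OpKernel* kernel = nullptr;
//   TF_RETURN_IF_ERROR(CreateNonCachedKernel(device, flib, props,
//                                            graph->versions().producer(),
//                                            &kernel));
//   ... use the kernel to process a step ...
//   DeleteNonCachedKernel(kernel);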

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_