/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/single_threaded_executor.h"

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

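// Returns OkStatus() if `n` can be executed by the single-threaded executor:
// nodes with reference-typed outputs are always rejected, Switch nodes are
// always rejected (they require deadness propagation), and other control-flow
// nodes are rejected unless `allow_control_flow_sync_execution` is true.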
Status ValidateOpIsSafeForSyncExecution(
    const Node& n, bool allow_control_flow_sync_execution) {
  for (DataType dt : n.output_types()) {
    if (IsRefType(dt)) {
      return errors::Unimplemented(
          "Single-threaded executor does not support reference-typed "
          "edges, but saw type ",
          DataTypeString(dt), " in outputs of node ", n.name());
    }
  }
  // Executing Switch nodes requires propagating deadness, which is not
  // currently supported in the SingleThreadedExecutor.
  if (n.IsSwitch()) {
    return errors::FailedPrecondition(
        "Single-threaded executor does not support the Switch op, but saw "
        "node ",
        n.name(),
        ". Perhaps your graph contains old-style control flow primitives? "
        "Try using tf.compat.v1.enable_control_flow_v2().");
  }
  if (n.IsControlFlow() && !allow_control_flow_sync_execution) {
    return errors::FailedPrecondition(
        "Single-threaded executor does not support low-level control flow, "
        "but saw control flow node ",
        n.name(),
        ". Perhaps your graph contains old-style control flow primitives? "
        "Try using tf.compat.v1.enable_control_flow_v2().");
  }
  return OkStatus();
}

namespace {

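// Inline capacity of 4 keeps the per-kernel input vectors off the heap for
// the common case of kernels with at most four inputs.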
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

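// Heap-allocated and intentionally never deleted, so that the string remains
// valid during static teardown.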
static const string& kSingleThreadedExecutor =
    *new string("SINGLE_THREADED_EXECUTOR");

class SingleThreadedExecutorImpl : public Executor {
 public:
  explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
      : params_(params) {}

  ~SingleThreadedExecutorImpl() override {
    for (const KernelState& kernel_state : kernels_) {
      params_.delete_kernel(kernel_state.kernel);
    }
    for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) {
      params_.delete_kernel(kernel_state.kernel);
    }
  }

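  // Precomputes the execution schedule: topologically sorts `graph`, creates
  // an OpKernel for each non-argument node, and builds the flat input/output
  // location tables that `Run()` uses to propagate tensors between kernels.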
  Status Initialize(const Graph& graph) {
    // Topologically sort `graph` to get a sequence of OpKernels.
    std::vector<Node*> ordered_nodes;
    ordered_nodes.reserve(graph.num_nodes());
    GetReversePostOrder(graph, &ordered_nodes);
    int ordered_nodes_size = ordered_nodes.size();
    if (ordered_nodes_size != graph.num_nodes()) {
      return errors::InvalidArgument("Graph had ", graph.num_nodes(),
                                     " nodes, but reverse post-order had ",
                                     ordered_nodes.size());
    }

    // Reserve space for two fewer nodes than the graph contains, because we
    // do not need to create kernels for the _SOURCE and _SINK nodes.
    kernels_.reserve(ordered_nodes.size() - 2);
    std::vector<Node*> nodes_with_kernels;
    std::vector<Node*> nodes_with_const_tensor_kernels;
    nodes_with_kernels.reserve(ordered_nodes.size() - 2);

    std::map<size_t, Node*> arg_index_to_node_map;
    absl::flat_hash_map<Node*, size_t> node_to_index_map;
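    // NOTE: `arg_index_to_node_map` must be an ordered map, because its
    // largest key is used below to size `arg_output_locations_`.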

    // Create the kernel and input-related structures for each node in
    // `graph`.
    for (Node* n : ordered_nodes) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }
      TF_RETURN_IF_ERROR(ValidateOpIsSafeForSyncExecution(
          *n, params_.allow_control_flow_sync_execution));
      if (n->IsArg()) {
        int32_t arg_index;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &arg_index));
        if (arg_index < 0) {
          return errors::InvalidArgument("Invalid argument index ", arg_index,
                                         " in node ", n->name());
        }
        arg_index_to_node_map[arg_index] = n;
        // We do not create a kernel for Arg nodes, and instead inline the
        // argument handling directly in the executor code.
        continue;
      }

      OpKernel* kernel;
      TF_RETURN_IF_ERROR(params_.create_kernel(n->properties(), &kernel));

      const Tensor* const_tensor;
      if (n->num_outputs() == 1 && (const_tensor = kernel->const_tensor())) {
        // Nodes that produce a single constant tensor are handled specially:
        // we evaluate the tensor once, and propagate it to its consumers as
        // a `const Tensor*`, to avoid refcount manipulation.
        const size_t kernel_index = const_tensor_kernels_.size();
        const_tensor_kernels_.push_back({});
        nodes_with_const_tensor_kernels.push_back(n);
        ConstTensorKernelState& kernel_state =
            const_tensor_kernels_[kernel_index];
        kernel_state.kernel = kernel;
        kernel_state.const_tensor = *const_tensor;
      } else {
        const size_t kernel_index = kernels_.size();
        kernels_.push_back({});
        nodes_with_kernels.push_back(n);
        KernelState& kernel_state = kernels_[kernel_index];
        kernel_state.kernel = kernel;
        kernel_state.num_inputs = n->num_inputs();
        kernel_state.num_outputs = n->num_outputs();
        node_to_index_map[n] = kernel_index;
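        // Each kernel's inputs occupy a contiguous range of the flat `inputs`
        // vector; its start offset is the running sum of the input counts of
        // all preceding kernels.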
        if (kernel_index == 0) {
          kernel_state.input_start_index = 0;
        } else {
          const KernelState& previous_kernel_state =
              kernels_[kernel_index - 1];
          kernel_state.input_start_index =
              previous_kernel_state.input_start_index +
              previous_kernel_state.num_inputs;
        }
      }
    }

    // Build the mapping from each Arg node output to the input slot for the
    // corresponding destination node.
    if (!arg_index_to_node_map.empty()) {
      const size_t num_args = arg_index_to_node_map.rbegin()->first + 1;
      arg_output_locations_.resize(num_args);
      for (const auto& arg_index_node_pair : arg_index_to_node_map) {
        const size_t arg_index = arg_index_node_pair.first;
        const Node* arg_node = arg_index_node_pair.second;
        arg_output_locations_[arg_index].reserve(arg_node->out_edges().size());
        for (const Edge* e : arg_node->out_edges()) {
          if (e->src_output() == Graph::kControlSlot) {
            continue;
          } else if (e->src_output() != 0) {
            return errors::Internal("Invalid output index ", e->src_output(),
                                    " from argument node ", arg_index);
          }
          arg_output_locations_[arg_index].push_back(
              kernels_[node_to_index_map[e->dst()]].input_start_index +
              e->dst_input());
        }
      }
    }

    // Build the mapping from each const tensor kernel to the input slot for
    // the corresponding destination node.
    for (size_t i = 0; i < const_tensor_kernels_.size(); ++i) {
      Node* n = nodes_with_const_tensor_kernels[i];
      ConstTensorKernelState& kernel_state = const_tensor_kernels_[i];
      for (const Edge* e : n->out_edges()) {
        if (e->src_output() == Graph::kControlSlot) {
          continue;
        } else if (e->src_output() != 0) {
          return errors::Internal("Invalid output index ", e->src_output(),
                                  " from node ", n->DebugString());
        }
        kernel_state.output_locations.push_back(
            kernels_[node_to_index_map[e->dst()]].input_start_index +
            e->dst_input());
      }

      bool on_host =
          kernel_state.kernel->output_memory_types()[0] == HOST_MEMORY;
      kernel_state.output_alloc_attr.set_on_host(on_host);
    }

    // Build the mapping from each node output to the input slot for the
    // corresponding destination node.
    for (size_t i = 0; i < kernels_.size(); ++i) {
      Node* n = nodes_with_kernels[i];
      KernelState& kernel_state = kernels_[i];
      kernel_state.output_locations.resize(kernel_state.num_outputs);
      for (const Edge* e : n->out_edges()) {
        if (!e->IsControlEdge()) {
          kernel_state.output_locations[e->src_output()].push_back(
              kernels_[node_to_index_map[e->dst()]].input_start_index +
              e->dst_input());
        }
      }

      // Compute allocator attributes for each node output, and corresponding
      // node input.
      kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
      AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();

      OpKernel* op_kernel = kernel_state.kernel;
      for (int out = 0; out < n->num_outputs(); out++) {
        DCHECK_LT(out, op_kernel->output_memory_types().size());
        bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
        if (on_host) {
          AllocatorAttributes h;
          h.set_on_host(on_host);
          attrs[out].Merge(h);
        }
      }
    }

    if (!kernels_.empty()) {
      const KernelState& last_kernel_state = kernels_.back();
      total_num_inputs_ =
          last_kernel_state.input_start_index + last_kernel_state.num_inputs;
      input_alloc_attrs_.resize(total_num_inputs_);
      for (size_t i = 0; i < kernels_.size(); ++i) {
        for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
          for (size_t output_location : kernels_[i].output_locations[j]) {
            input_alloc_attrs_[output_location] =
                kernels_[i].output_alloc_attrs[j];
          }
        }
      }
    } else {
      total_num_inputs_ = 0;
    }
    return OkStatus();
  }

  Status Run(const Args& args) override {
    // The inputs to each kernel are stored contiguously in `inputs`.
    //
    // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
    // determine the range of elements in this vector that correspond to
    // the inputs of `kernels_[i]`.
    //
    // This vector has the following layout:
    //
    // * Kernel 0, input 0.
    // * Kernel 0, input 1.
    // * ...
    // * Kernel 0, input `kernels_[0].num_inputs - 1`.
    // * Kernel 1, input 0.
    // * ...
    // * Kernel 1, input `kernels_[1].num_inputs - 1`.
    // * ...
    // * Kernel `kernels_.size() - 1`, input 0.
    // * ...
    // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
    //
    // Note that kernels with zero inputs do not correspond to any elements in
    // this vector.
    //
    // We use `ManualConstructor<Tensor>` to avoid the overhead of
    // default-constructing an invalid `Tensor` for each slot at the beginning
    // of execution:
    // * Elements are initialized when the outputs of a kernel execution are
    //   propagated to the inputs of kernels that depend on them.
    // * The elements corresponding to the inputs for kernel `i` are destroyed
    //   after kernel `i` executes.
    // * In an error case (see below), we use the connectivity information in
    //   `KernelState::output_locations` to determine which locations have
    //   been initialized, and manually destroy them.
    std::vector<Entry> inputs(total_num_inputs_);

    // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
    // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
    TensorValueVec node_inputs;
    AllocatorAttributeVec input_alloc_attrs;

    // Override the intra-op thread pool if requested.
    Device* device = params_.device;
    std::unique_ptr<Device> user_device;
    if (args.user_intra_op_threadpool != nullptr) {
      user_device = RenamedDevice::NewRenamedDevice(
          device->name(), device, /*owns_underlying=*/false,
          /*isolate_session_state=*/false, args.user_intra_op_threadpool);
      device = user_device.get();
    }

    // Prepare the parameters that will be the same for all kernels.
    OpKernelContext::Params params;
    params.step_id = args.step_id;
    params.device = device;
    params.log_memory = false;  // TODO(mrry): Too severe?
    params.rendezvous = args.rendezvous;
    params.session_state = args.session_state;
    params.session_metadata = params_.session_metadata;
    params.tensor_store = args.tensor_store;
    params.cancellation_manager = args.cancellation_manager;
    params.call_frame = args.call_frame;
    params.function_library = params_.function_library;
    params.resource_manager = device->resource_manager();
    params.step_container = args.step_container;
    params.collective_executor = args.collective_executor;
    params.stack_trace = args.stack_trace;
    params.slice_reader_cache = nullptr;  // TODO(mrry): Too severe?

    Args::Runner runner_copy = args.runner;
    params.runner = &runner_copy;
    params.run_all_kernels_inline = args.run_all_kernels_inline;
    params.stats_collector = args.stats_collector;
    params.executor_type = &kSingleThreadedExecutor;

    // NOTE(mrry): We are assuming that the graph contains no loops or
    // conditionals, so every kernel executes in the root frame and no input
    // is dead.
    params.frame_iter = FrameAndIter(0, 0);
    params.is_input_dead = false;

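    // `TryGetDeviceContext()` acquires a reference on the device context (if
    // one exists), which must be released before this function returns.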
    device->TryGetDeviceContext(&params.op_device_context).IgnoreError();
    auto context_cleanup = gtl::MakeCleanup([&params] {
      if (params.op_device_context != nullptr) {
        params.op_device_context->Unref();
      }
    });

    // TODO(mrry): Consider implementing forwarding.
    params.forward_from_array = nullptr;

    const size_t received_args =
        args.call_frame ? args.call_frame->num_args() : 0;
    if (TF_PREDICT_FALSE(arg_output_locations_.size() > received_args)) {
      return errors::InvalidArgument("Expected ", arg_output_locations_.size(),
                                     " arguments, but only received ",
                                     received_args, ".");
    }


    // ArgOp is a relatively expensive OpKernel due to the Tensor allocations
    // that it performs. Therefore we specialize its implementation and
    // forward arguments directly to the inputs of kernels that consume them.
    for (size_t i = 0; i < arg_output_locations_.size(); ++i) {
      const size_t num_destinations = arg_output_locations_[i].size();
      if (num_destinations > 0) {
        if (args.call_frame->CanConsumeArg(i)) {
          // The first destination input can consume the argument.
          Entry& first_input = inputs[arg_output_locations_[i][0]];
          first_input.state = Entry::State::HAS_VALUE;
          first_input.val.Init();
          args.call_frame->ConsumeArg(i, first_input.val.get());
          // All subsequent destination inputs get a shallow copy of the first
          // destination input.
          //
          // NOTE: If we had metadata about which kernels might attempt to
          // forward their input, we could arrange the kernel order so that
          // one of those kernels was executed last.
          for (size_t j = 1; j < num_destinations; ++j) {
            Entry& input = inputs[arg_output_locations_[i][j]];
            input.state = Entry::State::HAS_VALUE;
            input.val.Init(*first_input.val);
          }
        } else {
          const Tensor* arg;
          TF_RETURN_IF_ERROR(args.call_frame->GetArg(i, &arg));
          for (size_t j = 0; j < num_destinations; ++j) {
            Entry& input = inputs[arg_output_locations_[i][j]];
            // NOTE: We must make at least one shallow copy of the argument
            // tensor that remains live until all consuming kernels have
            // executed, to keep the reference count > 1, and inhibit buffer
            // forwarding. For simplicity, we shallow copy into the input
            // entry for each consuming kernel.
            input.state = Entry::State::HAS_VALUE;
            input.val.Init(*arg);
          }
        }
      }
    }

    // Kernels that return a constant value (e.g. ConstOp) are relatively
    // expensive due to the Tensor allocations that they perform. Therefore we
    // specialize their implementation and forward their constant value
    // directly to the inputs of kernels that consume them.
    for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) {
      for (size_t i = 0; i < kernel_state.output_locations.size(); ++i) {
        Entry& input = inputs[kernel_state.output_locations[i]];
        input.state = Entry::State::HAS_CONST_TENSOR;
        input.const_tensor = &kernel_state.const_tensor;
      }
    }

    // Execute the kernels one at a time in topological order.
    for (size_t i = 0; i < kernels_.size(); ++i) {
      const KernelState& kernel_state = kernels_[i];

      // Prepare the per-kernel parameters.
      const size_t input_start_index = kernel_state.input_start_index;
      const size_t num_inputs = kernel_state.num_inputs;
      const size_t num_outputs = kernel_state.num_outputs;

      node_inputs.clear();
      node_inputs.resize(num_inputs);
      input_alloc_attrs.clear();
      input_alloc_attrs.resize(num_inputs);
      for (size_t j = 0; j < num_inputs; ++j) {
        Entry& input = inputs[input_start_index + j];
        switch (input.state) {
          case Entry::State::HAS_CONST_TENSOR:
            // NOTE(mrry): This `const_cast` is necessary because
            // `TensorValue` stores a non-const `Tensor*`, and relies on the
            // `OpKernelContext` accessors making dynamic checks that prevent
            // using an immutable tensor as a mutable tensor.
            node_inputs[j].tensor = const_cast<Tensor*>(input.const_tensor);
            break;
          case Entry::State::HAS_VALUE:
            node_inputs[j].tensor = input.val.get();
            break;
          default:
            DCHECK(false) << "Input did not have a valid value.";
        }
        input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
      }
      params.inputs = node_inputs;
      params.input_alloc_attrs = input_alloc_attrs;
      params.op_kernel = kernel_state.kernel;
      params.output_attr_array = kernel_state.output_alloc_attrs.data();
      OpKernelContext ctx(&params, num_outputs);

      // Actually execute the kernel.
      device->Compute(kernel_state.kernel, &ctx);
      TF_RETURN_IF_ERROR(ctx.status());

      // Free the inputs to the current kernel.
      for (size_t j = 0; j < num_inputs; ++j) {
        inputs[input_start_index + j].ClearVal();
      }

      // Forward the outputs of the kernel to the inputs of subsequent
      // kernels.
      for (size_t j = 0; j < num_outputs; ++j) {
        TensorValue val = ctx.release_output(j);
        const size_t num_destinations =
            kernel_state.output_locations[j].size();
        if (num_destinations > 0) {
          // TODO(mrry): Consider flattening the `output_locations` vector
          // to improve the cache-friendliness of this loop.
          for (size_t k = 0; k < num_destinations - 1; ++k) {
            // TODO(mrry): Validate that the types match the expected values
            // or ensure that the necessary validation has already happened.
            Entry& input = inputs[kernel_state.output_locations[j][k]];
            input.state = Entry::State::HAS_VALUE;
            if (val.tensor != nullptr) {
              input.val.Init(*val.tensor);
            } else {
              input.val.Init(Tensor(kernel_state.kernel->output_type(j)));
            }
          }
          // Move `val` to the last consumer to avoid the cost of copying it.
          Entry& input =
              inputs[kernel_state.output_locations[j][num_destinations - 1]];
          input.state = Entry::State::HAS_VALUE;
          if (val.tensor != nullptr) {
            input.val.Init(std::move(*val.tensor));
          } else {
            input.val.Init(Tensor(kernel_state.kernel->output_type(j)));
          }
        }
        delete val.tensor;
      }
    }
    return OkStatus();
  }

  // When asynchronous execution is requested, defer the whole synchronous
  // `Run()` call to `args.runner` instead of running it inline: callers of
  // `RunAsync()` may expect to keep performing expensive work in the calling
  // thread, even though the execution itself is single-threaded.
  //
  // This also avoids stack-overflow issues with functional control flow.
  void RunAsync(const Args& args, DoneCallback done) override {
    args.runner([this, args, done]() { done(Run(args)); });
  }

 private:
  const LocalExecutorParams params_;

  // All following members are read-only after Initialize().

  // The sum of the number of inputs for each node in the graph. This
  // determines the length of the flat `inputs` vector. See the comment at
  // the beginning of `Run()` for details.
  size_t total_num_inputs_;

  // Represents cached graph structure state for each kernel.
  struct KernelState {
    // The kernel object. Not owned.
    //
    // This pointer is managed by `params_.create_kernel()` and
    // `params_.delete_kernel()`.
    OpKernel* kernel;

    // These fields determine the range of elements in `inputs` that
    // corresponds to the inputs of `kernel`.
    size_t input_start_index;
    size_t num_inputs;

    size_t num_outputs;

    // For the `j`th output of `kernel`, `output_locations[j]` contains the
    // locations in the flat `inputs` vector to which that output must be
    // copied. See the comment at the beginning of `Run()` for details.
    std::vector<std::vector<size_t>>
        output_locations;  // Length = `num_outputs`.

    // Memory space information for each output of `kernel`.
    std::vector<AllocatorAttributes>
        output_alloc_attrs;  // Length = `num_outputs`.
  };
  std::vector<KernelState> kernels_;

  // For the `i`th argument, `arg_output_locations_[i]` contains the locations
  // in the flat `inputs` vector to which that argument must be copied.
  std::vector<std::vector<size_t>>
      arg_output_locations_;  // Length = `num_args`.

  // Represents cached graph structure state for each kernel that produces
  // a single constant-valued tensor.
  struct ConstTensorKernelState {
    // The kernel object. Not owned.
    //
    // This pointer is managed by `params_.create_kernel()` and
    // `params_.delete_kernel()`.
    OpKernel* kernel;

    // The cached value of `kernel->const_tensor()`.
    //
    // NOTE: We keep a `Tensor` rather than a `const Tensor*` here in order to
    // keep the reference count on the underlying buffer above 1. Otherwise, a
    // kernel could interpret the input as a forwardable tensor, and mutate
    // the underlying constant tensor.
    Tensor const_tensor;

    // For the single output of `kernel`, `output_locations` contains the
    // locations in the flat `inputs` vector to which that output must be
    // copied. See the comment at the beginning of `Run()` for details.
    std::vector<size_t> output_locations;  // One entry per consuming input.

    // Memory space information for the single output of `kernel`.
    AllocatorAttributes output_alloc_attr;
  };
  std::vector<ConstTensorKernelState> const_tensor_kernels_;

  // Memory space information for each input. This information is stored in
  // the same order as the flat `inputs` vector. See the comment at the
  // beginning of `Run()` for details.
  std::vector<AllocatorAttributes>
      input_alloc_attrs_;  // Length = `total_num_inputs_`.
};

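// Registers the single-threaded executor with the global executor factory
// registry under the name "SINGLE_THREADED_EXECUTOR", so that it can be
// selected by executor type at session setup time.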
class SingleThreadedExecutorRegistrar {
 public:
  SingleThreadedExecutorRegistrar() {
    ExecutorFactory::Register(kSingleThreadedExecutor, new Factory());
  }

 private:
  class Factory : public ExecutorFactory {
    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                       std::unique_ptr<Executor>* out_executor) override {
      Executor* ret;
      TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret));
      out_executor->reset(ret);
      return OkStatus();
    }
  };
};
static SingleThreadedExecutorRegistrar registrar;

}  // namespace

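// Creates a single-threaded executor for `graph` and returns it in
// `*executor`. On success, the caller assumes ownership of `*executor`.
//
// A sketch of typical usage, assuming the caller has fully populated a
// `LocalExecutorParams params` (device, function library, kernel
// create/delete callbacks) and a per-step `Executor::Args args` (step ID,
// call frame, runner):
//
//   Executor* executor;
//   TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &executor));
//   TF_RETURN_IF_ERROR(executor->Run(args));  // Synchronous wrapper.
//   delete executor;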
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
                                 const Graph& graph, Executor** executor) {
  auto impl = std::make_unique<SingleThreadedExecutorImpl>(params);
  TF_RETURN_IF_ERROR(impl->Initialize(graph));
  *executor = impl.release();
  return OkStatus();
}

}  // namespace tensorflow