1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h
17// include work. Some other include must be including eigen without defining
18// this. Consider defining in this in a BUILD rule.
19#define EIGEN_USE_THREADS
20
21#include "tensorflow/core/common_runtime/graph_runner.h"
22
23#include "tensorflow/core/common_runtime/device.h"
24#include "tensorflow/core/common_runtime/device_factory.h"
25#include "tensorflow/core/common_runtime/executor.h"
26#include "tensorflow/core/common_runtime/graph_constructor.h"
27#include "tensorflow/core/common_runtime/memory_types.h"
28#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"
30#include "tensorflow/core/framework/log_memory.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/framework/tensor_util.h"
33#include "tensorflow/core/framework/versions.pb.h"
34#include "tensorflow/core/graph/algorithm.h"
35#include "tensorflow/core/graph/graph.h"
36#include "tensorflow/core/graph/node_builder.h"
37#include "tensorflow/core/graph/subgraph.h"
38#include "tensorflow/core/lib/core/threadpool.h"
39#include "tensorflow/core/lib/strings/strcat.h"
40#include "tensorflow/core/platform/env.h"
41#include "tensorflow/core/public/session_options.h"
42
43namespace tensorflow {
44
45namespace {
46
47// A simple rendezvous class.
48// Assumes a single sender and a single receiver, no duplicate sends, and no
49// sends of dead tensors.
50class SimpleRendezvous : public RendezvousInterface {
51 public:
52 explicit SimpleRendezvous() {}
53
54 Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val,
55 const bool is_dead) override {
56 if (is_dead) {
57 return errors::Internal("Send of a dead tensor");
58 }
59
60 mutex_lock l(mu_);
61 string edge_name(parsed.edge_name);
62 if (table_.count(edge_name) > 0) {
63 return errors::Internal("Send of an already sent tensor");
64 }
65 table_[edge_name] = val;
66 return OkStatus();
67 }
68
69 void RecvAsync(const ParsedKey& parsed, const Args& recv_args,
70 DoneCallback done) override {
71 Tensor tensor;
72 Status status = OkStatus();
73 {
74 string key(parsed.edge_name);
75 mutex_lock l(mu_);
76 if (table_.count(key) <= 0) {
77 status = errors::Internal("Did not find key ", key);
78 } else {
79 tensor = table_[key];
80 }
81 }
82 done(status, Args{}, recv_args, tensor, false);
83 }
84
85 void StartAbort(const Status& status) override {}
86
87 private:
88 typedef std::unordered_map<string, Tensor> Table;
89
90 mutex mu_;
91 Table table_ TF_GUARDED_BY(mu_);
92};
93
94} // namespace
95
// Creates a GraphRunner that owns a private single-threaded CPU device; the
// device is destroyed with this GraphRunner (device_deleter_ holds ownership,
// device_ is the non-owning view used by Run).
GraphRunner::GraphRunner(Env* env)
    : device_deleter_(NewSingleThreadedCpuDevice(env)),
      device_(device_deleter_.get()) {}
// Creates a GraphRunner that runs on a caller-owned `device`; the caller must
// keep `device` alive for the lifetime of this GraphRunner.
GraphRunner::GraphRunner(Device* device) : device_(device) {}
100
101GraphRunner::~GraphRunner() {}
102
// Runs `graph` synchronously on `device_`, feeding `inputs` and fetching the
// tensors named in `output_names` into `outputs`. The input graph is copied
// and rewritten for execution, so `graph` itself is not mutated. Returns
// NotFound if no device is available, or any error from graph rewriting,
// kernel creation, execution, or fetching the outputs.
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
                        const NamedTensorList& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  if (device_ == nullptr) {
    return errors::NotFound("Cannot find a device for GraphRunner.");
  }

  if (function_library && function_library->device() &&
      function_library->device()->device_type() != device_->device_type()) {
    // Mismatch between function_library's device_type and device_'s
    // device_type. Drop the function library rather than fail: kernels are
    // then created without function support.
    // TODO(matthewmurray) Can we create a new FunctionLibraryRuntime that is
    // identical to function_library except that it uses the given 'device_'?
    VLOG(1) << "Cannot run on: " << device_->device_type()
            << " with a function library for a "
            << function_library->device()->device_type() << " device.";
    function_library = nullptr;
  }

  // TODO(vrv): Instead of copying the entire graph, consider modifying
  // the existing graph, and then removing those removed edges.
  // prior to returning.
  std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
  CopyGraph(*graph, graph_to_run.get());

  SimpleRendezvous rendez;

  // Extract the input names and keys, and feed in the inputs. All sends
  // happen up front, before the executor runs, so receives never block.
  std::vector<string> input_names;
  for (const auto& in : inputs) {
    const string& tensor_name = in.first;
    input_names.emplace_back(tensor_name);
    // The src/dst device names in the key are placeholders: SimpleRendezvous
    // keys purely on the edge (tensor) name, so any well-formed pair works.
    string full_key = Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                                            tensor_name, FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed));
    TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second,
                                   false /* is_dead */));
  }

  // Call RewriteGraphForExecution: replaces fed tensors with _Recv nodes and
  // fetched tensors with _Send nodes, and prunes unreachable nodes.
  subgraph::RewriteGraphMetadata metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph_to_run.get(), input_names, output_names, {} /* target nodes */,
      device_->attributes(), false /* use_function_convention */, &metadata));

  // Create the local executor and the Rendezvous for fetching back the
  // constants.

  // Run operators on the local thread. We should not need concurrency here; we
  // should not be running expensive operators.
  auto runner = [](Executor::Args::Closure c) { c(); };

  LocalExecutorParams params;
  // The ownership of the output tensors are bound to this device's lifetime.
  params.device = device_;
  params.function_library = function_library;
  const int producer = graph_to_run->versions().producer();
  // Kernels are created fresh (non-cached) per executor; delete_kernel below
  // reclaims them when the executor is destroyed.
  params.create_kernel = [this, function_library, producer](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_, function_library, props, producer,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) { delete kernel; };

  Executor* executor;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
  std::unique_ptr<Executor> executor_unref(executor);

  Executor::Args args;
  // NOTE: we could take a step id as an argument, but currently
  // there is no need since we never trace the running of a graph
  // called via this method.
  args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
  args.runner = runner;
  args.rendezvous = &rendez;
  // NOTE: Use of graph runner is limited to single-device executions
  // so a CollectiveExecutor should never be required.
  args.collective_executor = nullptr;

  CancellationManager cancellation_manager;
  args.cancellation_manager = &cancellation_manager;

  // Run the graph. This blocks until execution completes, so the fetched
  // tensors are already in the rendezvous when we Recv below.
  TF_RETURN_IF_ERROR(executor->Run(args));

  outputs->resize(output_names.size());
  for (size_t i = 0; i < output_names.size(); ++i) {
    // Same placeholder-device key scheme as for the inputs above.
    const string& output_key =
        Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                              output_names[i], FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
    bool is_dead;
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
    // Does a deep copy so that ownership of the tensor isn't tied to the
    // allocator of the cpu device we created above. The allocator could be
    // deleted along with the device.
    (*outputs)[i] = tensor::DeepCopy(output_tensor);
  }

  return OkStatus();
}
210
211} // namespace tensorflow
212