1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h |
17 | // include work. Some other include must be including eigen without defining |
// this. Consider defining this in a BUILD rule.
19 | #define EIGEN_USE_THREADS |
20 | |
21 | #include "tensorflow/core/common_runtime/graph_runner.h" |
22 | |
23 | #include "tensorflow/core/common_runtime/device.h" |
24 | #include "tensorflow/core/common_runtime/device_factory.h" |
25 | #include "tensorflow/core/common_runtime/executor.h" |
26 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
27 | #include "tensorflow/core/common_runtime/memory_types.h" |
28 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
29 | #include "tensorflow/core/common_runtime/single_threaded_cpu_device.h" |
30 | #include "tensorflow/core/framework/log_memory.h" |
31 | #include "tensorflow/core/framework/op_kernel.h" |
32 | #include "tensorflow/core/framework/tensor_util.h" |
33 | #include "tensorflow/core/framework/versions.pb.h" |
34 | #include "tensorflow/core/graph/algorithm.h" |
35 | #include "tensorflow/core/graph/graph.h" |
36 | #include "tensorflow/core/graph/node_builder.h" |
37 | #include "tensorflow/core/graph/subgraph.h" |
38 | #include "tensorflow/core/lib/core/threadpool.h" |
39 | #include "tensorflow/core/lib/strings/strcat.h" |
40 | #include "tensorflow/core/platform/env.h" |
41 | #include "tensorflow/core/public/session_options.h" |
42 | |
43 | namespace tensorflow { |
44 | |
45 | namespace { |
46 | |
47 | // A simple rendezvous class. |
48 | // Assumes a single sender and a single receiver, no duplicate sends, and no |
49 | // sends of dead tensors. |
50 | class SimpleRendezvous : public RendezvousInterface { |
51 | public: |
52 | explicit SimpleRendezvous() {} |
53 | |
54 | Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val, |
55 | const bool is_dead) override { |
56 | if (is_dead) { |
57 | return errors::Internal("Send of a dead tensor" ); |
58 | } |
59 | |
60 | mutex_lock l(mu_); |
61 | string edge_name(parsed.edge_name); |
62 | if (table_.count(edge_name) > 0) { |
63 | return errors::Internal("Send of an already sent tensor" ); |
64 | } |
65 | table_[edge_name] = val; |
66 | return OkStatus(); |
67 | } |
68 | |
69 | void RecvAsync(const ParsedKey& parsed, const Args& recv_args, |
70 | DoneCallback done) override { |
71 | Tensor tensor; |
72 | Status status = OkStatus(); |
73 | { |
74 | string key(parsed.edge_name); |
75 | mutex_lock l(mu_); |
76 | if (table_.count(key) <= 0) { |
77 | status = errors::Internal("Did not find key " , key); |
78 | } else { |
79 | tensor = table_[key]; |
80 | } |
81 | } |
82 | done(status, Args{}, recv_args, tensor, false); |
83 | } |
84 | |
85 | void StartAbort(const Status& status) override {} |
86 | |
87 | private: |
88 | typedef std::unordered_map<string, Tensor> Table; |
89 | |
90 | mutex mu_; |
91 | Table table_ TF_GUARDED_BY(mu_); |
92 | }; |
93 | |
94 | } // namespace |
95 | |
// Constructs a GraphRunner that owns a freshly created single-threaded CPU
// device; the device (and hence tensors it allocated) lives exactly as long
// as this runner (see the deep copy of outputs in Run()).
GraphRunner::GraphRunner(Env* env)
    : device_deleter_(NewSingleThreadedCpuDevice(env)),
      device_(device_deleter_.get()) {}
99 | GraphRunner::GraphRunner(Device* device) : device_(device) {} |
100 | |
101 | GraphRunner::~GraphRunner() {} |
102 | |
// Runs `graph` to completion on this runner's device and returns the
// requested outputs.
//
// Args:
//   graph: the graph to execute. It is copied internally, so the caller's
//     graph is left unmodified.
//   function_library: used to instantiate function-call kernels. May be
//     null; it is also dropped (with a VLOG) when its device type does not
//     match `device_`'s.
//   inputs: (tensor name, tensor) pairs fed into the graph.
//   output_names: names of the tensors to fetch.
//   outputs: filled with deep copies of the fetched tensors, in the same
//     order as `output_names`.
//
// Returns OkStatus() on success, or the first error encountered while
// feeding, rewriting, executing, or fetching.
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
                        const NamedTensorList& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  if (device_ == nullptr) {
    return errors::NotFound("Cannot find a device for GraphRunner.");
  }

  if (function_library && function_library->device() &&
      function_library->device()->device_type() != device_->device_type()) {
    // Mismatch between function_library's device_type and device_'s
    // device_type.
    // TODO(matthewmurray) Can we create a new FunctionLibraryRuntime that is
    // identical to function_library except that it uses the given 'device_'?
    VLOG(1) << "Cannot run on: " << device_->device_type()
            << " with a function library for a "
            << function_library->device()->device_type() << " device.";
    function_library = nullptr;
  }

  // TODO(vrv): Instead of copying the entire graph, consider modifying
  // the existing graph, and then removing those removed edges.
  // prior to returning.
  std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
  CopyGraph(*graph, graph_to_run.get());

  SimpleRendezvous rendez;

  // Extract the input names and keys, and feed in the inputs.
  std::vector<string> input_names;
  for (const auto& in : inputs) {
    const string& tensor_name = in.first;
    input_names.emplace_back(tensor_name);
    // The device names and incarnation in the key are placeholders:
    // SimpleRendezvous keys purely on the edge name, so any well-formed
    // sender/receiver pair works here.
    string full_key = Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                                            tensor_name, FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed));
    TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second,
                                   false /* is_dead */));
  }

  // Call RewriteGraphForExecution: replaces the fed nodes with recv ops and
  // the fetched nodes with send ops wired to the rendezvous above.
  subgraph::RewriteGraphMetadata metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph_to_run.get(), input_names, output_names, {} /* target nodes */,
      device_->attributes(), false /* use_function_convention */, &metadata));

  // Create the local executor and the Rendezvous for fetching back the
  // constants.

  // Run operators on the local thread. We should not need concurrency here; we
  // should not be running expensive operators.
  auto runner = [](Executor::Args::Closure c) { c(); };

  LocalExecutorParams params;
  // The ownership of the output tensors is bound to this device's lifetime.
  params.device = device_;
  params.function_library = function_library;
  // Pin the graph-def producer version so kernels are created with the
  // semantics the copied graph was built against.
  const int producer = graph_to_run->versions().producer();
  params.create_kernel = [this, function_library, producer](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_, function_library, props, producer,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) { delete kernel; };

  Executor* executor;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
  std::unique_ptr<Executor> executor_unref(executor);

  Executor::Args args;
  // NOTE: we could take a step id as an argument, but currently
  // there is no need since we never trace the running of a graph
  // called via this method.
  args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
  args.runner = runner;
  args.rendezvous = &rendez;
  // NOTE: Use of graph runner is limited to single-device executions
  // so a CollectiveExecutor should never be required.
  args.collective_executor = nullptr;

  CancellationManager cancellation_manager;
  args.cancellation_manager = &cancellation_manager;

  // Run the graph. This is a synchronous call: it returns only once all
  // fetched tensors have been sent to the rendezvous (or an error occurred).
  TF_RETURN_IF_ERROR(executor->Run(args));

  // Fetch each requested output back out of the rendezvous, using the same
  // placeholder key scheme as the feed loop above.
  outputs->resize(output_names.size());
  for (size_t i = 0; i < output_names.size(); ++i) {
    const string& output_key =
        Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                              output_names[i], FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
    bool is_dead;
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
    // Does a deep copy so that ownership of the tensor isn't tied to the
    // allocator of the cpu device we created above. The allocator could be
    // deleted along with the device.
    (*outputs)[i] = tensor::DeepCopy(output_tensor);
  }

  return OkStatus();
}
210 | |
211 | } // namespace tensorflow |
212 | |