1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/grappler/utils.h" |
17 | |
18 | #include <iterator> |
19 | #include <memory> |
20 | #include <queue> |
21 | #include <vector> |
22 | |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "absl/strings/match.h" |
25 | #include "absl/strings/str_cat.h" |
26 | #include "tensorflow/core/framework/attr_value.pb.h" |
27 | #include "tensorflow/core/framework/function.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/framework/op.h" |
30 | #include "tensorflow/core/framework/op_def.pb.h" |
31 | #include "tensorflow/core/framework/op_kernel.h" |
32 | #include "tensorflow/core/framework/types.h" |
33 | #include "tensorflow/core/lib/core/stringpiece.h" |
34 | #include "tensorflow/core/lib/strings/numbers.h" |
35 | #include "tensorflow/core/lib/strings/scanner.h" |
36 | #include "tensorflow/core/lib/strings/strcat.h" |
37 | #include "tensorflow/core/platform/notification.h" |
38 | #include "tensorflow/core/util/device_name_utils.h" |
39 | |
40 | namespace tensorflow { |
41 | namespace grappler { |
42 | namespace { |
43 | template <typename T> |
44 | bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) { |
45 | using RealType = typename Eigen::NumTraits<T>::Real; |
46 | if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) || |
47 | value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) { |
48 | return false; |
49 | } |
50 | tensor->flat<T>()(0) = static_cast<T>(value); |
51 | return true; |
52 | } |
53 | |
54 | template <typename T> |
55 | bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) { |
56 | using RealType = typename Eigen::NumTraits<T>::Real; |
57 | if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) || |
58 | value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) { |
59 | return false; |
60 | } |
61 | tensor->flat<T>()(0) = static_cast<T>(value); |
62 | return true; |
63 | } |
64 | |
65 | // Is 'node' an operator that consumes only the shape of its input, not the |
66 | // data itself? |
67 | // TODO(ezhulenev): move to op_types.h. Requires to break circular dependency. |
68 | // TODO(ezhulenev): what about Identity passing tensor to Shape consumer? |
69 | bool IsShapeConsumer(const NodeDef& node) { |
70 | const string& op = node.op(); |
71 | return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size" ; |
72 | } |
73 | |
74 | } // namespace |
75 | |
76 | string TensorIdToString(const TensorId& tensor_id) { |
77 | return tensor_id.index() == 0 ? string(tensor_id.node()) |
78 | : tensor_id.ToString(); |
79 | } |
80 | |
81 | string SafeTensorIdToString(const SafeTensorId& tensor_id) { |
82 | return tensor_id.index() == 0 ? tensor_id.node() : tensor_id.ToString(); |
83 | } |
84 | |
85 | bool IsSameInput(const string& name1, const string& name2) { |
86 | if (name1 == name2) return true; |
87 | TensorId tensor1 = ParseTensorName(name1); |
88 | TensorId tensor2 = ParseTensorName(name2); |
89 | return tensor1 == tensor2; |
90 | } |
91 | |
92 | bool IsControlInput(absl::string_view name) { |
93 | return !name.empty() && name[0] == '^'; |
94 | } |
95 | |
// A TensorId with a negative port index denotes a control dependency.
bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
97 | |
98 | string AddPrefixToNodeName(const string& name, const string& prefix, |
99 | const string& delimiter) { |
100 | if (!name.empty()) { |
101 | if (name[0] == '^') { |
102 | return absl::StrCat("^" , prefix, delimiter, name.substr(1)); |
103 | } |
104 | } |
105 | return absl::StrCat(prefix, delimiter, name); |
106 | } |
107 | |
108 | string AddPrefixToNodeName(const string& name, const string& prefix) { |
109 | return AddPrefixToNodeName(name, prefix, "/" ); |
110 | } |
111 | |
112 | bool ExecuteWithTimeout(std::function<void()> fn, const int64_t timeout_in_ms, |
113 | thread::ThreadPool* const thread_pool) { |
114 | if (timeout_in_ms <= 0) { |
115 | fn(); |
116 | return true; |
117 | } |
118 | auto done = std::make_shared<Notification>(); |
119 | thread_pool->Schedule([done, fn]() { |
120 | fn(); |
121 | done->Notify(); |
122 | }); |
123 | const bool notified = |
124 | WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000); |
125 | return notified; |
126 | } |
127 | |
128 | string AsControlDependency(const NodeDef& node) { |
129 | return absl::StrCat("^" , node.name()); |
130 | } |
131 | |
132 | string AsControlDependency(const string& node_name) { |
133 | CHECK(!node_name.empty()); |
134 | return (!node_name.empty() && node_name[0] == '^') |
135 | ? node_name |
136 | : absl::StrCat("^" , node_name); |
137 | } |
138 | |
139 | bool NodeIsOnCpu(const NodeDef* node) { |
140 | string task, device; |
141 | return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && |
142 | absl::StartsWith(device, DEVICE_CPU); |
143 | } |
144 | |
145 | bool NodeIsOnGpu(const NodeDef* node) { |
146 | string task, device; |
147 | return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && |
148 | absl::StartsWith(device, DEVICE_GPU); |
149 | } |
150 | |
151 | int NumOutputs(const NodeDef& node, GraphDef* graph) { |
152 | int num_outputs = 0; |
153 | const OpDef* op_def = nullptr; |
154 | auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); |
155 | if (status.ok()) { |
156 | for (const auto& output : op_def->output_arg()) { |
157 | if (!output.type_list_attr().empty()) { |
158 | num_outputs += |
159 | node.attr().at(output.type_list_attr()).list().type_size(); |
160 | } else if (!output.number_attr().empty()) { |
161 | num_outputs += node.attr().at(output.number_attr()).i(); |
162 | } else { |
163 | num_outputs++; |
164 | } |
165 | } |
166 | } else { |
167 | FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library()); |
168 | auto status = fdef.LookUpOpDef(node.op(), &op_def); |
169 | if (status.ok()) { |
170 | num_outputs = op_def->output_arg_size(); |
171 | } |
172 | } |
173 | return num_outputs; |
174 | } |
175 | |
176 | bool HasControlInputs(const NodeDef& node) { |
177 | const int num_inputs = node.input_size(); |
178 | if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) { |
179 | return true; |
180 | } |
181 | return false; |
182 | } |
183 | |
184 | bool HasRegularInputs(const NodeDef& node) { |
185 | const int num_inputs = node.input_size(); |
186 | if (num_inputs > 0 && !IsControlInput(node.input(0))) { |
187 | return true; |
188 | } |
189 | return false; |
190 | } |
191 | |
192 | int NumNonControlInputs(const NodeDef& node) { |
193 | int num_inputs = 0; |
194 | for (; num_inputs < node.input_size(); ++num_inputs) { |
195 | const string& input = node.input(num_inputs); |
196 | if (IsControlInput(input)) { |
197 | return num_inputs; |
198 | } |
199 | } |
200 | return num_inputs; |
201 | } |
202 | |
203 | int NumControlInputs(const NodeDef& node) { |
204 | int num_inputs = 0; |
205 | for (; num_inputs < node.input_size(); ++num_inputs) { |
206 | const string& input = node.input(node.input_size() - num_inputs - 1); |
207 | if (!IsControlInput(input)) { |
208 | return num_inputs; |
209 | } |
210 | } |
211 | return num_inputs; |
212 | } |
213 | |
214 | bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) { |
215 | for (const NodeDef* output : node_map.GetOutputs(node.name())) { |
216 | for (const string& node_as_input : output->input()) { |
217 | if (IsControlInput(node_as_input)) break; |
218 | |
219 | TensorId tensor = ParseTensorName(node_as_input); |
220 | if (tensor.node() == node.name()) { |
221 | return true; |
222 | } |
223 | } |
224 | } |
225 | return false; |
226 | } |
227 | |
228 | bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) { |
229 | for (const NodeDef* output : node_map.GetOutputs(node.name())) { |
230 | for (int idx = output->input_size() - 1; idx >= 0; --idx) { |
231 | const string& node_as_input = output->input(idx); |
232 | if (!IsControlInput(node_as_input)) break; |
233 | |
234 | TensorId tensor = ParseTensorName(node_as_input); |
235 | if (tensor.node() == node.name()) { |
236 | return true; |
237 | } |
238 | } |
239 | } |
240 | return false; |
241 | } |
242 | |
243 | int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) { |
244 | int num_outputs = 0; |
245 | for (const NodeDef* output : node_map.GetOutputs(node.name())) { |
246 | for (int idx = output->input_size() - 1; idx >= 0; --idx) { |
247 | const string& node_as_input = output->input(idx); |
248 | if (!IsControlInput(node_as_input)) break; |
249 | |
250 | TensorId tensor = ParseTensorName(node_as_input); |
251 | if (tensor.node() == node.name()) { |
252 | ++num_outputs; |
253 | } |
254 | } |
255 | } |
256 | return num_outputs; |
257 | } |
258 | |
259 | int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) { |
260 | int num_outputs = 0; |
261 | for (const NodeDef* output : node_map.GetOutputs(node.name())) { |
262 | for (const string& node_as_input : output->input()) { |
263 | if (IsControlInput(node_as_input)) { |
264 | break; |
265 | } |
266 | if (node_as_input == node.name()) { |
267 | ++num_outputs; |
268 | } else { |
269 | const TensorId tensor = ParseTensorName(node_as_input); |
270 | if (tensor.node() == node.name()) { |
271 | ++num_outputs; |
272 | } |
273 | } |
274 | } |
275 | } |
276 | return num_outputs; |
277 | } |
278 | |
279 | int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) { |
280 | int num_data_outputs = 0; |
281 | for (const NodeDef* output : node_map.GetOutputs(node.name())) { |
282 | if (IsShapeConsumer(*output)) continue; |
283 | |
284 | for (int i = 0; i < output->input_size(); ++i) { |
285 | const string& input = output->input(i); |
286 | if (!IsControlInput(input) && NodeName(input) == node.name()) { |
287 | ++num_data_outputs; |
288 | break; |
289 | } |
290 | } |
291 | } |
292 | return num_data_outputs; |
293 | } |
294 | |
295 | // Returns the data type in attribute `attr_name` of `node`. If that attribute |
296 | // doesn't exist, returns DT_INVALID. |
297 | DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) { |
298 | if (!node.attr().count(type_attr)) { |
299 | return DT_INVALID; |
300 | } |
301 | const auto& attr = node.attr().at(type_attr); |
302 | if (attr.value_case() != AttrValue::kType) { |
303 | return DT_INVALID; |
304 | } |
305 | return attr.type(); |
306 | } |
307 | |
// Walks the chain of first (input(0)) fanins starting at `source`, following
// each node's first input for as long as `pred_fn` accepts the next node, and
// returns the last accepted node. `source` itself is accepted unconditionally.
// Traversal stops at a node with no inputs, at a control input when
// `follow_control_input` is false, at a node failing `pred_fn`, or at a
// dangling input reference (which is also logged as an error).
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  // First iteration always runs (next == &source); afterwards pred_fn gates
  // each step.
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      // Dangling reference: the loop condition will then exit, returning
      // `current` as the tail.
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  // Callers receive a mutable pointer into the graph owned elsewhere.
  return const_cast<NodeDef*>(current);
}
326 | |
// Every permutation is a product of one or more cycles. Iterate over the cycles
// in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
//
// Applies `permutation` (or its inverse, when `invert_permutation` is true)
// to graph->node() in place. `permutation` is used as scratch space and is
// clobbered. Sizes must match (CHECK-failure otherwise).
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    // Build the inverse mapping, then operate on it instead.
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  // Realize each cycle as a sequence of element swaps; the inner loop runs
  // until position n holds its final value, updating the permutation to
  // record progress.
  for (int n = 0, end = permutation->size(); n + 1 < end; ++n) {
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}
348 | |
// Removes duplicate control inputs from `node` in place. Only control inputs
// are ever removed; a control input is dropped when its node name was already
// seen earlier in the input list (whether as a data or control input).
void DedupControlInputs(NodeDef* node) {
  absl::flat_hash_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    // The insert runs for every input (data or control): duplicates are
    // detected by node name, but removal is gated on IsControlInput.
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      // O(1) removal: swap with the last element, drop it, and revisit `pos`
      // since the swapped-in input has not been examined yet.
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
    } else {
      ++pos;
    }
  }
}
362 | |
363 | namespace { |
364 | |
// Deletes the nodes at the indices in `nodes_to_delete` from `graph`.
// `nodes_to_delete` must contain unique indices in ascending iteration order
// (it is walked in reverse here). Deletion swaps doomed nodes into the tail
// of the node array and truncates, so the relative order of the surviving
// nodes is not preserved.
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  // Reverse order guarantees that `last` never points at an index that is
  // itself scheduled for deletion but not yet swapped.
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  // All doomed nodes now occupy the tail [last + 1, node_size); drop them.
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}
379 | |
// Sorts `v` in place and erases duplicates, leaving each value exactly once
// in ascending order.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  auto unique_end = std::unique(v->begin(), v->end());
  v->erase(unique_end, v->end());
}
385 | |
386 | } // namespace |
387 | |
// Deletes the nodes at the given indices. std::set iterates in ascending
// order with unique elements, which satisfies the impl's precondition.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
392 | |
// Deletes the nodes at the given indices. The vector is taken by rvalue so
// it can be sorted and deduplicated in place to satisfy the impl's
// ascending-unique precondition.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
397 | |
398 | void EraseNodesFromGraph(const std::set<string>& nodes_to_delete, |
399 | GraphDef* graph) { |
400 | std::vector<int> nodes_idx_to_delete; |
401 | nodes_idx_to_delete.reserve(nodes_to_delete.size()); |
402 | for (int i = 0; i < graph->node_size(); ++i) { |
403 | if (nodes_to_delete.count(graph->node(i).name())) |
404 | nodes_idx_to_delete.push_back(i); |
405 | } |
406 | EraseNodesFromGraphImpl(nodes_idx_to_delete, graph); |
407 | } |
408 | |
409 | #define HANDLE_DOUBLE_CASE(DTYPE) \ |
410 | case DTYPE: \ |
411 | if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \ |
412 | static_cast<double>(value), tensor)) { \ |
413 | return errors::InvalidArgument("Cannot store value ", value, \ |
414 | " in tensor of type " #DTYPE); \ |
415 | } \ |
416 | break |
417 | |
418 | #define HANDLE_INT_CASE(DTYPE) \ |
419 | case DTYPE: \ |
420 | if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value, \ |
421 | tensor)) { \ |
422 | return errors::InvalidArgument("Cannot store value ", value, \ |
423 | " in tensor of type " #DTYPE); \ |
424 | } \ |
425 | break |
426 | |
427 | Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { |
428 | // TODO(rmlarsen): Support more general shapes. |
429 | // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported. |
430 | if (tensor->NumElements() != 1) { |
431 | return errors::InvalidArgument( |
432 | "Expected scalar tensor, got num_elements = " , tensor->NumElements()); |
433 | } |
434 | switch (dtype) { |
435 | HANDLE_DOUBLE_CASE(DT_HALF); |
436 | HANDLE_DOUBLE_CASE(DT_BFLOAT16); |
437 | HANDLE_DOUBLE_CASE(DT_BOOL); |
438 | HANDLE_DOUBLE_CASE(DT_FLOAT); |
439 | HANDLE_DOUBLE_CASE(DT_DOUBLE); |
440 | HANDLE_DOUBLE_CASE(DT_UINT8); |
441 | HANDLE_DOUBLE_CASE(DT_INT8); |
442 | HANDLE_DOUBLE_CASE(DT_UINT16); |
443 | HANDLE_DOUBLE_CASE(DT_INT16); |
444 | HANDLE_DOUBLE_CASE(DT_INT32); |
445 | HANDLE_DOUBLE_CASE(DT_INT64); |
446 | HANDLE_DOUBLE_CASE(DT_COMPLEX64); |
447 | HANDLE_DOUBLE_CASE(DT_COMPLEX128); |
448 | HANDLE_INT_CASE(DT_QINT8); |
449 | HANDLE_INT_CASE(DT_QUINT8); |
450 | HANDLE_INT_CASE(DT_QINT16); |
451 | HANDLE_INT_CASE(DT_QUINT16); |
452 | HANDLE_INT_CASE(DT_QINT32); |
453 | default: |
454 | return errors::InvalidArgument("Unsupported type " , |
455 | DataTypeString(dtype)); |
456 | } |
457 | return OkStatus(); |
458 | } |
459 | |
460 | #undef HANDLE_CASE |
461 | |
462 | Status CheckAttrExists(const NodeDef& node, const string& key) { |
463 | if (!HasNodeAttr(node, key)) { |
464 | return errors::InvalidArgument("Node '" , node.name(), "' lacks '" , key, |
465 | "' attr: " , node.ShortDebugString()); |
466 | } |
467 | return OkStatus(); |
468 | } |
469 | |
470 | Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) { |
471 | for (const string& key : keys) { |
472 | TF_RETURN_IF_ERROR(CheckAttrExists(node, key)); |
473 | } |
474 | return OkStatus(); |
475 | } |
476 | |
477 | Status IsKernelRegisteredForNode( |
478 | absl::string_view node_name, bool has_experimental_debug_info, |
479 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info, |
480 | absl::string_view node_op, absl::string_view node_device, |
481 | AttrSlice node_attrs) { |
482 | DeviceNameUtils::ParsedName parsed_name; |
483 | if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) { |
484 | return errors::InvalidArgument("Could not parse device name: " , |
485 | node_device); |
486 | } |
487 | return FindKernelDef(DeviceType(parsed_name.type), node_name, |
488 | has_experimental_debug_info, experimental_debug_info, |
489 | node_op, node_device, node_attrs, nullptr, nullptr); |
490 | } |
491 | |
// Convenience overload: checks kernel registration using the fields of
// `node` directly.
Status IsKernelRegisteredForNode(const NodeDef& node) {
  return IsKernelRegisteredForNode(node.name(),
                                   node.has_experimental_debug_info(),
                                   node.experimental_debug_info(), node.op(),
                                   node.device(), AttrSlice(&node.attr()));
}
498 | |
499 | namespace { |
500 | void RemoveAttributes(const std::vector<absl::string_view>& to_remove, |
501 | NodeDef* node) { |
502 | if (to_remove.size() == node->attr_size()) { |
503 | node->clear_attr(); |
504 | } else { |
505 | for (const auto& key : to_remove) { |
506 | node->mutable_attr()->erase(string(key)); |
507 | } |
508 | } |
509 | } |
510 | } // namespace |
511 | |
512 | int EraseRegularNodeAttributes(NodeDef* node) { |
513 | std::vector<absl::string_view> to_remove; |
514 | for (const auto& attr : node->attr()) { |
515 | if (!attr.first.empty() && (attr.first)[0] != '_') { |
516 | to_remove.push_back(attr.first); |
517 | } |
518 | } |
519 | RemoveAttributes(to_remove, node); |
520 | return to_remove.size(); |
521 | } |
522 | |
523 | int EraseNodeOutputAttributes(NodeDef* node) { |
524 | std::vector<absl::string_view> to_remove; |
525 | for (const auto& attr : node->attr()) { |
526 | const string& attr_name = attr.first; |
527 | if (attr_name == "_xla_inferred_shapes" || |
528 | absl::StartsWith(attr_name, "_output_" )) { |
529 | to_remove.push_back(attr_name); |
530 | } |
531 | } |
532 | RemoveAttributes(to_remove, node); |
533 | return to_remove.size(); |
534 | } |
535 | |
536 | } // end namespace grappler |
537 | } // end namespace tensorflow |
538 | |