1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/grappler/utils.h"
17
18#include <iterator>
19#include <memory>
20#include <queue>
21#include <vector>
22
23#include "absl/container/flat_hash_set.h"
24#include "absl/strings/match.h"
25#include "absl/strings/str_cat.h"
26#include "tensorflow/core/framework/attr_value.pb.h"
27#include "tensorflow/core/framework/function.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/op.h"
30#include "tensorflow/core/framework/op_def.pb.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/framework/types.h"
33#include "tensorflow/core/lib/core/stringpiece.h"
34#include "tensorflow/core/lib/strings/numbers.h"
35#include "tensorflow/core/lib/strings/scanner.h"
36#include "tensorflow/core/lib/strings/strcat.h"
37#include "tensorflow/core/platform/notification.h"
38#include "tensorflow/core/util/device_name_utils.h"
39
40namespace tensorflow {
41namespace grappler {
42namespace {
43template <typename T>
44bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
45 using RealType = typename Eigen::NumTraits<T>::Real;
46 if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
47 value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
48 return false;
49 }
50 tensor->flat<T>()(0) = static_cast<T>(value);
51 return true;
52}
53
54template <typename T>
55bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
56 using RealType = typename Eigen::NumTraits<T>::Real;
57 if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
58 value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
59 return false;
60 }
61 tensor->flat<T>()(0) = static_cast<T>(value);
62 return true;
63}
64
65// Is 'node' an operator that consumes only the shape of its input, not the
66// data itself?
67// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
68// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
69bool IsShapeConsumer(const NodeDef& node) {
70 const string& op = node.op();
71 return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
72}
73
74} // namespace
75
76string TensorIdToString(const TensorId& tensor_id) {
77 return tensor_id.index() == 0 ? string(tensor_id.node())
78 : tensor_id.ToString();
79}
80
81string SafeTensorIdToString(const SafeTensorId& tensor_id) {
82 return tensor_id.index() == 0 ? tensor_id.node() : tensor_id.ToString();
83}
84
85bool IsSameInput(const string& name1, const string& name2) {
86 if (name1 == name2) return true;
87 TensorId tensor1 = ParseTensorName(name1);
88 TensorId tensor2 = ParseTensorName(name2);
89 return tensor1 == tensor2;
90}
91
92bool IsControlInput(absl::string_view name) {
93 return !name.empty() && name[0] == '^';
94}
95
96bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
97
98string AddPrefixToNodeName(const string& name, const string& prefix,
99 const string& delimiter) {
100 if (!name.empty()) {
101 if (name[0] == '^') {
102 return absl::StrCat("^", prefix, delimiter, name.substr(1));
103 }
104 }
105 return absl::StrCat(prefix, delimiter, name);
106}
107
108string AddPrefixToNodeName(const string& name, const string& prefix) {
109 return AddPrefixToNodeName(name, prefix, "/");
110}
111
112bool ExecuteWithTimeout(std::function<void()> fn, const int64_t timeout_in_ms,
113 thread::ThreadPool* const thread_pool) {
114 if (timeout_in_ms <= 0) {
115 fn();
116 return true;
117 }
118 auto done = std::make_shared<Notification>();
119 thread_pool->Schedule([done, fn]() {
120 fn();
121 done->Notify();
122 });
123 const bool notified =
124 WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
125 return notified;
126}
127
128string AsControlDependency(const NodeDef& node) {
129 return absl::StrCat("^", node.name());
130}
131
132string AsControlDependency(const string& node_name) {
133 CHECK(!node_name.empty());
134 return (!node_name.empty() && node_name[0] == '^')
135 ? node_name
136 : absl::StrCat("^", node_name);
137}
138
139bool NodeIsOnCpu(const NodeDef* node) {
140 string task, device;
141 return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
142 absl::StartsWith(device, DEVICE_CPU);
143}
144
145bool NodeIsOnGpu(const NodeDef* node) {
146 string task, device;
147 return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
148 absl::StartsWith(device, DEVICE_GPU);
149}
150
151int NumOutputs(const NodeDef& node, GraphDef* graph) {
152 int num_outputs = 0;
153 const OpDef* op_def = nullptr;
154 auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
155 if (status.ok()) {
156 for (const auto& output : op_def->output_arg()) {
157 if (!output.type_list_attr().empty()) {
158 num_outputs +=
159 node.attr().at(output.type_list_attr()).list().type_size();
160 } else if (!output.number_attr().empty()) {
161 num_outputs += node.attr().at(output.number_attr()).i();
162 } else {
163 num_outputs++;
164 }
165 }
166 } else {
167 FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
168 auto status = fdef.LookUpOpDef(node.op(), &op_def);
169 if (status.ok()) {
170 num_outputs = op_def->output_arg_size();
171 }
172 }
173 return num_outputs;
174}
175
176bool HasControlInputs(const NodeDef& node) {
177 const int num_inputs = node.input_size();
178 if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
179 return true;
180 }
181 return false;
182}
183
184bool HasRegularInputs(const NodeDef& node) {
185 const int num_inputs = node.input_size();
186 if (num_inputs > 0 && !IsControlInput(node.input(0))) {
187 return true;
188 }
189 return false;
190}
191
192int NumNonControlInputs(const NodeDef& node) {
193 int num_inputs = 0;
194 for (; num_inputs < node.input_size(); ++num_inputs) {
195 const string& input = node.input(num_inputs);
196 if (IsControlInput(input)) {
197 return num_inputs;
198 }
199 }
200 return num_inputs;
201}
202
203int NumControlInputs(const NodeDef& node) {
204 int num_inputs = 0;
205 for (; num_inputs < node.input_size(); ++num_inputs) {
206 const string& input = node.input(node.input_size() - num_inputs - 1);
207 if (!IsControlInput(input)) {
208 return num_inputs;
209 }
210 }
211 return num_inputs;
212}
213
214bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
215 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
216 for (const string& node_as_input : output->input()) {
217 if (IsControlInput(node_as_input)) break;
218
219 TensorId tensor = ParseTensorName(node_as_input);
220 if (tensor.node() == node.name()) {
221 return true;
222 }
223 }
224 }
225 return false;
226}
227
228bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
229 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
230 for (int idx = output->input_size() - 1; idx >= 0; --idx) {
231 const string& node_as_input = output->input(idx);
232 if (!IsControlInput(node_as_input)) break;
233
234 TensorId tensor = ParseTensorName(node_as_input);
235 if (tensor.node() == node.name()) {
236 return true;
237 }
238 }
239 }
240 return false;
241}
242
243int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
244 int num_outputs = 0;
245 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
246 for (int idx = output->input_size() - 1; idx >= 0; --idx) {
247 const string& node_as_input = output->input(idx);
248 if (!IsControlInput(node_as_input)) break;
249
250 TensorId tensor = ParseTensorName(node_as_input);
251 if (tensor.node() == node.name()) {
252 ++num_outputs;
253 }
254 }
255 }
256 return num_outputs;
257}
258
259int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
260 int num_outputs = 0;
261 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
262 for (const string& node_as_input : output->input()) {
263 if (IsControlInput(node_as_input)) {
264 break;
265 }
266 if (node_as_input == node.name()) {
267 ++num_outputs;
268 } else {
269 const TensorId tensor = ParseTensorName(node_as_input);
270 if (tensor.node() == node.name()) {
271 ++num_outputs;
272 }
273 }
274 }
275 }
276 return num_outputs;
277}
278
279int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
280 int num_data_outputs = 0;
281 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
282 if (IsShapeConsumer(*output)) continue;
283
284 for (int i = 0; i < output->input_size(); ++i) {
285 const string& input = output->input(i);
286 if (!IsControlInput(input) && NodeName(input) == node.name()) {
287 ++num_data_outputs;
288 break;
289 }
290 }
291 }
292 return num_data_outputs;
293}
294
295// Returns the data type in attribute `attr_name` of `node`. If that attribute
296// doesn't exist, returns DT_INVALID.
297DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
298 if (!node.attr().count(type_attr)) {
299 return DT_INVALID;
300 }
301 const auto& attr = node.attr().at(type_attr);
302 if (attr.value_case() != AttrValue::kType) {
303 return DT_INVALID;
304 }
305 return attr.type();
306}
307
308NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
309 bool follow_control_input,
310 const std::function<bool(const NodeDef&)>& pred_fn) {
311 const NodeDef* current = &source;
312 const NodeDef* next = current;
313 while (next == &source || (next != nullptr && pred_fn(*next))) {
314 current = next;
315 if (current->input_size() == 0 ||
316 (!follow_control_input && IsControlInput(current->input(0)))) {
317 break;
318 }
319 next = node_map.GetNode(current->input(0));
320 if (next == nullptr) {
321 LOG(ERROR) << "Node not found: " << current->input(0);
322 }
323 }
324 return const_cast<NodeDef*>(current);
325}
326
327// Every permutation is a product of one or more cycles. Iterate over the cycles
328// in the permutation, and convert each of those into a product of
329// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
330void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
331 bool invert_permutation) {
332 CHECK_EQ(graph->node_size(), permutation->size());
333 std::vector<int> inv_perm(permutation->size(), 0);
334 if (invert_permutation) {
335 for (size_t n = 0; n < permutation->size(); ++n) {
336 inv_perm[(*permutation)[n]] = n;
337 }
338 permutation->swap(inv_perm);
339 }
340 for (int n = 0, end = permutation->size(); n + 1 < end; ++n) {
341 while (n != (*permutation)[n]) {
342 std::size_t r = (*permutation)[n];
343 graph->mutable_node()->SwapElements(n, r);
344 std::swap((*permutation)[n], (*permutation)[r]);
345 }
346 }
347}
348
349void DedupControlInputs(NodeDef* node) {
350 absl::flat_hash_set<string> inputs;
351 int pos = 0;
352 while (pos < node->input_size()) {
353 const string& input = node->input(pos);
354 if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
355 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
356 node->mutable_input()->RemoveLast();
357 } else {
358 ++pos;
359 }
360 }
361}
362
363namespace {
364
365template <typename UniqueContainer>
366void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
367 GraphDef* graph) {
368 static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
369 "Need to pass container of ints");
370
371 int last = graph->node_size() - 1;
372 for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
373 const int index = *it;
374 graph->mutable_node()->SwapElements(index, last);
375 last--;
376 }
377 graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
378}
379
// Sorts the container pointed to by `v` in ascending order and erases
// duplicate elements, leaving each distinct value exactly once.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  T& values = *v;
  std::sort(values.begin(), values.end());
  const auto first_duplicate = std::unique(values.begin(), values.end());
  values.erase(first_duplicate, values.end());
}
385
386} // namespace
387
388void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
389 GraphDef* graph) {
390 EraseNodesFromGraphImpl(nodes_to_delete, graph);
391}
392
393void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
394 STLSortAndRemoveDuplicates(&nodes_to_delete);
395 EraseNodesFromGraphImpl(nodes_to_delete, graph);
396}
397
398void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
399 GraphDef* graph) {
400 std::vector<int> nodes_idx_to_delete;
401 nodes_idx_to_delete.reserve(nodes_to_delete.size());
402 for (int i = 0; i < graph->node_size(); ++i) {
403 if (nodes_to_delete.count(graph->node(i).name()))
404 nodes_idx_to_delete.push_back(i);
405 }
406 EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
407}
408
409#define HANDLE_DOUBLE_CASE(DTYPE) \
410 case DTYPE: \
411 if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
412 static_cast<double>(value), tensor)) { \
413 return errors::InvalidArgument("Cannot store value ", value, \
414 " in tensor of type " #DTYPE); \
415 } \
416 break
417
418#define HANDLE_INT_CASE(DTYPE) \
419 case DTYPE: \
420 if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value, \
421 tensor)) { \
422 return errors::InvalidArgument("Cannot store value ", value, \
423 " in tensor of type " #DTYPE); \
424 } \
425 break
426
427Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
428 // TODO(rmlarsen): Support more general shapes.
429 // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
430 if (tensor->NumElements() != 1) {
431 return errors::InvalidArgument(
432 "Expected scalar tensor, got num_elements = ", tensor->NumElements());
433 }
434 switch (dtype) {
435 HANDLE_DOUBLE_CASE(DT_HALF);
436 HANDLE_DOUBLE_CASE(DT_BFLOAT16);
437 HANDLE_DOUBLE_CASE(DT_BOOL);
438 HANDLE_DOUBLE_CASE(DT_FLOAT);
439 HANDLE_DOUBLE_CASE(DT_DOUBLE);
440 HANDLE_DOUBLE_CASE(DT_UINT8);
441 HANDLE_DOUBLE_CASE(DT_INT8);
442 HANDLE_DOUBLE_CASE(DT_UINT16);
443 HANDLE_DOUBLE_CASE(DT_INT16);
444 HANDLE_DOUBLE_CASE(DT_INT32);
445 HANDLE_DOUBLE_CASE(DT_INT64);
446 HANDLE_DOUBLE_CASE(DT_COMPLEX64);
447 HANDLE_DOUBLE_CASE(DT_COMPLEX128);
448 HANDLE_INT_CASE(DT_QINT8);
449 HANDLE_INT_CASE(DT_QUINT8);
450 HANDLE_INT_CASE(DT_QINT16);
451 HANDLE_INT_CASE(DT_QUINT16);
452 HANDLE_INT_CASE(DT_QINT32);
453 default:
454 return errors::InvalidArgument("Unsupported type ",
455 DataTypeString(dtype));
456 }
457 return OkStatus();
458}
459
460#undef HANDLE_CASE
461
462Status CheckAttrExists(const NodeDef& node, const string& key) {
463 if (!HasNodeAttr(node, key)) {
464 return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
465 "' attr: ", node.ShortDebugString());
466 }
467 return OkStatus();
468}
469
470Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
471 for (const string& key : keys) {
472 TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
473 }
474 return OkStatus();
475}
476
477Status IsKernelRegisteredForNode(
478 absl::string_view node_name, bool has_experimental_debug_info,
479 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
480 absl::string_view node_op, absl::string_view node_device,
481 AttrSlice node_attrs) {
482 DeviceNameUtils::ParsedName parsed_name;
483 if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
484 return errors::InvalidArgument("Could not parse device name: ",
485 node_device);
486 }
487 return FindKernelDef(DeviceType(parsed_name.type), node_name,
488 has_experimental_debug_info, experimental_debug_info,
489 node_op, node_device, node_attrs, nullptr, nullptr);
490}
491
492Status IsKernelRegisteredForNode(const NodeDef& node) {
493 return IsKernelRegisteredForNode(node.name(),
494 node.has_experimental_debug_info(),
495 node.experimental_debug_info(), node.op(),
496 node.device(), AttrSlice(&node.attr()));
497}
498
499namespace {
500void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
501 NodeDef* node) {
502 if (to_remove.size() == node->attr_size()) {
503 node->clear_attr();
504 } else {
505 for (const auto& key : to_remove) {
506 node->mutable_attr()->erase(string(key));
507 }
508 }
509}
510} // namespace
511
512int EraseRegularNodeAttributes(NodeDef* node) {
513 std::vector<absl::string_view> to_remove;
514 for (const auto& attr : node->attr()) {
515 if (!attr.first.empty() && (attr.first)[0] != '_') {
516 to_remove.push_back(attr.first);
517 }
518 }
519 RemoveAttributes(to_remove, node);
520 return to_remove.size();
521}
522
523int EraseNodeOutputAttributes(NodeDef* node) {
524 std::vector<absl::string_view> to_remove;
525 for (const auto& attr : node->attr()) {
526 const string& attr_name = attr.first;
527 if (attr_name == "_xla_inferred_shapes" ||
528 absl::StartsWith(attr_name, "_output_")) {
529 to_remove.push_back(attr_name);
530 }
531 }
532 RemoveAttributes(to_remove, node);
533 return to_remove.size();
534}
535
536} // end namespace grappler
537} // end namespace tensorflow
538