/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

// Node
Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
  static const absl::flat_hash_map<std::string, Node::NodeClass>*
      kNodeClassTable =
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
          new absl::flat_hash_map<std::string, Node::NodeClass>({
              // Keep in same order as NodeClass values
              REF_CLASS("Switch", NC_SWITCH),
              REF_CLASS("_SwitchN", NC_SWITCH),
              REF_CLASS("Merge", NC_MERGE),
              REF_CLASS("Enter", NC_ENTER),
              REF_CLASS("Exit", NC_EXIT),
              REF_CLASS("NextIteration", NC_NEXT_ITERATION),
              {"LoopCond", NC_LOOP_COND},
              {"ControlTrigger", NC_CONTROL_TRIGGER},
              {"_Send", NC_SEND},
              {"_HostSend", NC_HOST_SEND},
              {"_Recv", NC_RECV},
              {"_HostRecv", NC_HOST_RECV},
              {"Const", NC_CONSTANT},
              {"HostConst", NC_CONSTANT},
              {"Variable", NC_VARIABLE},
              {"VariableV2", NC_VARIABLE},
              REF_CLASS("Identity", NC_IDENTITY),
              {"GetSessionHandle", NC_GET_SESSION_HANDLE},
              {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
              {"GetSessionTensor", NC_GET_SESSION_TENSOR},
              {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
              {"Size", NC_METADATA},
              {"Shape", NC_METADATA},
              {"Rank", NC_METADATA},
              {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
              {"CollectiveReduce", NC_COLLECTIVE},
              {"CollectiveBcastSend", NC_COLLECTIVE},
              {"CollectiveBcastRecv", NC_COLLECTIVE},
              {"CollectiveGather", NC_COLLECTIVE},
              {"FakeParam", NC_FAKE_PARAM},
              {"PartitionedCall", NC_PARTITIONED_CALL},
              {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
              {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
              {"If", NC_IF},
              {"StatelessIf", NC_IF},
              {"While", NC_WHILE},
              {"StatelessWhile", NC_WHILE},
              {"Case", NC_CASE},
              {"StatelessCase", NC_CASE},
              // Not using the constants defined in FunctionLibraryDefinition
              // for the four ops below, because the Android inference library
              // does not link in the tf.function-related files.
              {"_Arg", NC_ARG},
              {"_DeviceArg", NC_ARG},
              {"_Retval", NC_RETVAL},
              {"_DeviceRetval", NC_RETVAL},
              {"_XlaMerge", NC_MERGE},
          });
#undef REF_CLASS

  auto it = kNodeClassTable->find(ts);
  if (it != kNodeClassTable->end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}
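
// Illustrative behavior of the table above: the REF_CLASS macro registers
// each op under both its plain and "Ref"-prefixed names, and any op name
// absent from the table falls back to NC_OTHER, e.g.
//
//   "Switch"     -> NC_SWITCH
//   "RefSwitch"  -> NC_SWITCH  (added by REF_CLASS)
//   "MyCustomOp" -> NC_OTHER   (not in the table)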

std::string Node::DebugString() const {
  std::string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:", "{requested: '", requested_device(),
                       "', assigned: '", assigned_device_name(), "'}", " def:{",
                       SummarizeNode(*this), "}}");
  }
  return ret;
}

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props,
                      Node::NodeClass node_class) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  class_ = node_class;
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to update node: " << status;
    return;
  }
  if (props_->input_types != inputs || props_->output_types != outputs) {
    if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
      props_->input_types = inputs;
      props_->input_types_slice = props_->input_types;
      props_->output_types = outputs;
      props_->output_types_slice = props_->output_types;
    } else {
      props_ = std::make_shared<NodeProperties>(
          props_->op_def, std::move(props_->node_def), inputs, outputs);
    }
  }
}
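
// Note on the branch above: NodeProperties may be shared with copies of this
// node (see Graph::CopyNode). When this node is the sole owner
// (use_count() == 1), the type vectors are updated in place; otherwise a
// replacement NodeProperties is allocated for this node.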

void Node::ClearTypeInfo() {
  if (props_->node_def.has_experimental_type()) {
    MaybeCopyOnWrite();
    props_->node_def.clear_experimental_type();
  }
}

const std::string& Node::name() const { return props_->node_def.name(); }
const std::string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

NodeDef* Node::mutable_def() { return &props_->node_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
  return def().input();
}

const std::string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // TODO(mdan): As nodes become more dynamic, this may not be worth the cost.
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (props_.use_count() != 1) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}
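
// All NodeDef-mutating setters below (AddAttrHelper, ClearAttr, set_name,
// set_requested_device, etc.) call MaybeCopyOnWrite() first, so a mutation
// never leaks into another Node that shares the same NodeProperties.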

AttrValue* Node::AddAttrHelper(const std::string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const std::string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(std::string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const std::string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

void Node::set_original_node_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

void Node::set_original_func_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_func_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_func_names() = {names.begin(), names.end()};
  }
}

Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges. In the common case,
  // the number of elements is small enough that this search isn't
  // expensive. Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer. This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return OkStatus();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}

// Returns a vector of the non-control input edges to a node, indexed by input
// number (dst_input).
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return OkStatus();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return OkStatus();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return OkStatus();
}

Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return OkStatus();
}

// NodeDebugInfo

NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  if (has_experimental_debug_info) {
    const auto& node_names = experimental_debug_info.original_node_names();
    original_node_names.assign(node_names.begin(), node_names.end());
    const auto& func_names = experimental_debug_info.original_func_names();
    original_func_names.assign(func_names.begin(), func_names.end());
  }
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}
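
// A minimal construction sketch (illustrative): a freshly constructed Graph
// already contains the two implicit nodes wired together.
//
//   Graph g(OpRegistry::Global());
//   // g.num_nodes() == 2: just _SOURCE (id 0) and _SINK (id 1), joined by a
//   // single control edge.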

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}

std::unique_ptr<Graph> Graph::Clone() {
  std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
  new_graph->Copy(*this);
  return new_graph;
}

void Graph::Clear() {
  // Clear all nodes via the RemoveNode helper. If this ever becomes
  // performance sensitive, we could bypass the helper and clear the
  // underlying data structures directly.
  for (Node* n : nodes()) {
    if (!n->IsSource() && !n->IsSink()) RemoveNode(n);
  }
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

void Graph::Copy(const Graph& src) {
  SetConstructionContext(src.GetConstructionContextInternal());
  for (Node* n : nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy GraphDef versions
  set_versions(src.versions());

  // Copy the nodes.
  // "Node in src" -> "Node in *dest"
  gtl::FlatMap<const Node*, Node*> node_map;
  node_map.reserve(src.num_nodes());
  node_map[src.source_node()] = source_node();
  node_map[src.sink_node()] = sink_node();
  for (Node* n : src.op_nodes()) {
    auto copy = CopyNode(n);
    copy->in_edges_.reserve(n->in_edges().size());
    copy->out_edges_.reserve(n->out_edges().size());
    node_map[n] = copy;
  }

  // Copy the edges
  edges_.reserve(src.num_edges());
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

StatusOr<Node*> Graph::AddNode(NodeDef node_def) {
  Status s;
  Node* out = AddNode(std::move(node_def), &s);
  TF_RETURN_IF_ERROR(s);
  return out;
}

Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  if (node_def.has_experimental_type()) {
    VLOG(3) << "AddNode: node has type set, skipping type constructor "
            << node_def.name();
  } else {
    if (op_reg_data->type_ctor != nullptr) {
      VLOG(3) << "AddNode: found type constructor for " << node_def.name();
      Status s =
          full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
                                    *(node_def.mutable_experimental_type()));
      if (!s.ok()) {
        *status = errors::InvalidArgument("type error: ", s.ToString());
        VLOG(3) << "AddNode: type inference failed for " << node_def.name()
                << ": " << s;
        return nullptr;
      }
    } else {
      VLOG(3) << "AddNode: no type constructor for " << node_def.name();
    }
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, node_class);
  return node;
}
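
// A usage sketch (names are illustrative, not taken from this file): building
// a NodeDef by hand and adding it via the StatusOr overload above.
//
//   NodeDef def;
//   def.set_name("my_placeholder");
//   def.set_op("Placeholder");
//   (*def.mutable_attr())["dtype"].set_type(DT_FLOAT);
//   TF_ASSIGN_OR_RETURN(Node* n, graph.AddNode(std::move(def)));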

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  for (const Edge* e : node->in_edges_) {
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}

const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;

  return e;
}
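
// Invariant worth noting: for every live edge, edges_[e->id_] == e. Removal
// nulls out the slot rather than compacting the vector, and a recycled Edge
// object is always assigned a fresh id at the end of edges_, so edge ids are
// never reused even though the arena-allocated Edge storage is.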

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}

void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}

const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const std::string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const std::string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}
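
// Control dependencies are mirrored into the NodeDef's input list using the
// "^<node name>" syntax; e.g. a control edge from a node named "init" shows
// up as the input string "^init" on the destination's NodeDef. This is the
// same encoding RemoveControlEdge searches for below.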

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    std::string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}

namespace {
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return OkStatus();
}

Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (!dst->IsWhileNode()) {
    return errors::Internal(
        "dst argument to AddWhileInputHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return OkStatus();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

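// Appends one input to `dst` in NodeDef string form: "^<name>" for a control
// input, bare "<name>" for output slot 0, and "<name>:<slot>" otherwise.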
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

}  // namespace

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node. We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
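    // Data edges land in the first num_inputs() slots (keyed by dst_input);
    // control edges are appended after them. This layout is what lets the
    // sort below start at inputs.begin() + node->num_inputs().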
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}

std::string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument("node id ", id,
                                   " is >= the number of nodes in the graph, ",
                                   nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return OkStatus();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have output ", idx);
  }
  return OkStatus();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have input ", idx);
  }
  return OkStatus();
}

Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const std::string& device_name) {
  // Special case, very common. Also, this allows us to use a single map
  // lookup below, instead of two. The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}
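
// How the interning works: index 0 is permanently reserved for the empty
// device name (see the constructor), and operator[] on device_names_map_
// value-initializes a missing entry to 0, so 'index_cell > 0' means exactly
// "this name has been seen before". After inserting a new name, the map's
// size equals the next free index in device_names_, whose slot 0 holds "".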

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
      std::string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return OkStatus();
}

std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const {
  std::unordered_map<std::string, Node*> result;
  for (Node* n : nodes()) {
    result[n->name()] = n;
  }
  return result;
}

std::string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow