1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/graph/graph.h" |
17 | |
18 | #include <memory> |
19 | #include <vector> |
20 | |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "tensorflow/core/framework/full_type.pb.h" |
23 | #include "tensorflow/core/framework/graph.pb.h" |
24 | #include "tensorflow/core/framework/node_def.pb.h" |
25 | #include "tensorflow/core/framework/node_properties.h" |
26 | #include "tensorflow/core/framework/op_def_builder.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/framework/versions.pb.h" |
29 | #include "tensorflow/core/graph/graph_node_util.h" |
30 | #include "tensorflow/core/graph/while_context.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/lib/gtl/map_util.h" |
33 | #include "tensorflow/core/lib/hash/hash.h" |
34 | #include "tensorflow/core/lib/strings/strcat.h" |
35 | #include "tensorflow/core/lib/strings/stringprintf.h" |
36 | #include "tensorflow/core/platform/errors.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | #include "tensorflow/core/public/version.h" |
39 | |
40 | namespace tensorflow { |
41 | |
// Slot number used on both endpoints of a control edge; never a valid data
// slot (data slots are >= 0).
const int Graph::kControlSlot = -1;
43 | |
44 | // Node |
// Maps an op type name to the NodeClass enum used for the fast IsSwitch() /
// IsMerge() / IsConstant() etc. predicates. Any op name not present in the
// table is classified as NC_OTHER. The table is allocated once and
// intentionally leaked (function-local static pointer).
Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
  static const absl::flat_hash_map<std::string, Node::NodeClass>*
      kNodeClassTable =
// REF_CLASS registers both the plain op and its reference-typed "Ref" variant
// (e.g. "Switch" and "RefSwitch") under the same NodeClass.
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
          new absl::flat_hash_map<std::string, Node::NodeClass>({
              // Keep in same order as NodeClass values
              REF_CLASS("Switch" , NC_SWITCH),
              REF_CLASS("_SwitchN" , NC_SWITCH),
              REF_CLASS("Merge" , NC_MERGE),
              REF_CLASS("Enter" , NC_ENTER),
              REF_CLASS("Exit" , NC_EXIT),
              REF_CLASS("NextIteration" , NC_NEXT_ITERATION),
              {"LoopCond" , NC_LOOP_COND},
              {"ControlTrigger" , NC_CONTROL_TRIGGER},
              {"_Send" , NC_SEND},
              {"_HostSend" , NC_HOST_SEND},
              {"_Recv" , NC_RECV},
              {"_HostRecv" , NC_HOST_RECV},
              {"Const" , NC_CONSTANT},
              {"HostConst" , NC_CONSTANT},
              {"Variable" , NC_VARIABLE},
              {"VariableV2" , NC_VARIABLE},
              REF_CLASS("Identity" , NC_IDENTITY),
              {"GetSessionHandle" , NC_GET_SESSION_HANDLE},
              {"GetSessionHandleV2" , NC_GET_SESSION_HANDLE},
              {"GetSessionTensor" , NC_GET_SESSION_TENSOR},
              {"DeleteSessionTensor" , NC_DELETE_SESSION_TENSOR},
              {"Size" , NC_METADATA},
              {"Shape" , NC_METADATA},
              {"Rank" , NC_METADATA},
              {"_ScopedAllocator" , NC_SCOPED_ALLOCATOR},
              {"CollectiveReduce" , NC_COLLECTIVE},
              {"CollectiveBcastSend" , NC_COLLECTIVE},
              {"CollectiveBcastRecv" , NC_COLLECTIVE},
              {"CollectiveGather" , NC_COLLECTIVE},
              {"FakeParam" , NC_FAKE_PARAM},
              {"PartitionedCall" , NC_PARTITIONED_CALL},
              {"StatefulPartitionedCall" , NC_PARTITIONED_CALL},
              {"SymbolicGradient" , NC_SYMBOLIC_GRADIENT},
              {"If" , NC_IF},
              {"StatelessIf" , NC_IF},
              {"While" , NC_WHILE},
              {"StatelessWhile" , NC_WHILE},
              {"Case" , NC_CASE},
              {"StatelessCase" , NC_CASE},
              // Not using the constants defined in FunctionLibraryDefinition
              // for the 4 ops below because the android inference library
              // does not link tf.function related files.
              {"_Arg" , NC_ARG},
              {"_DeviceArg" , NC_ARG},
              {"_Retval" , NC_RETVAL},
              {"_DeviceRetval" , NC_RETVAL},
              {"_XlaMerge" , NC_MERGE},
          });
#undef REF_CLASS

  auto it = kNodeClassTable->find(ts);
  if (it != kNodeClassTable->end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}
110 | |
111 | std::string Node::DebugString() const { |
112 | std::string ret = strings::StrCat("{name:'" , name(), "' id:" , id_); |
113 | if (IsSource()) { |
114 | strings::StrAppend(&ret, " source}" ); |
115 | } else if (IsSink()) { |
116 | strings::StrAppend(&ret, " sink}" ); |
117 | } else { |
118 | strings::StrAppend(&ret, " op device:" , "{requested: '" , requested_device(), |
119 | "', assigned: '" , assigned_device_name(), "'}" , " def:{" , |
120 | SummarizeNode(*this), "}}" ); |
121 | } |
122 | return ret; |
123 | } |
124 | |
// Constructs a node in the uninitialized state (id -1, NC_UNINITIALIZED);
// Initialize() must be called before the node is used.
Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

// Binds this (freshly constructed or recycled) node to its graph id, cost id,
// shared NodeProperties, and precomputed NodeClass. The DCHECKs enforce that
// the node is currently in the cleared/uninitialized state.
void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props,
                      Node::NodeClass node_class) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  class_ = node_class;
}

// Returns the node to the uninitialized state so it can be recycled via the
// owning Graph's free list (see Graph::ReleaseNode / Graph::AllocateNode).
void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}
155 | |
// Recomputes this node's input/output types from its current NodeDef and
// OpDef. On type-inference failure the error is logged and the node is left
// unchanged. If the types changed, the update is applied copy-on-write:
// in place when this node is the sole owner of its NodeProperties, otherwise
// via a fresh NodeProperties so sharing nodes are unaffected.
void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed at updating node: " << status;
    return;
  }
  if (props_->input_types != inputs || props_->output_types != outputs) {
    if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
      // Sole owner: mutate in place, refreshing the slices that alias the
      // owned vectors.
      props_->input_types = inputs;
      props_->input_types_slice = props_->input_types;
      props_->output_types = outputs;
      props_->output_types_slice = props_->output_types;
    } else {
      props_ = std::make_shared<NodeProperties>(
          props_->op_def, std::move(props_->node_def), inputs, outputs);
    }
  }
}

// Drops any full-type (experimental_type) information from the NodeDef,
// cloning shared properties first (copy-on-write).
void Node::ClearTypeInfo() {
  if (props_->node_def.has_experimental_type()) {
    MaybeCopyOnWrite();
    props_->node_def.clear_experimental_type();
  }
}
184 | |
// Trivial read-only accessors, all backed by the shared NodeProperties.
const std::string& Node::name() const { return props_->node_def.name(); }
const std::string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

// NOTE: returns a pointer into possibly-shared properties; callers that
// mutate through it bypass the copy-on-write machinery.
NodeDef* Node::mutable_def() { return &props_->node_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
  return def().input();
}

const std::string& Node::requested_device() const { return def().device(); }

// Iterates the nodes on the far end of this node's out-edges. A neighbor
// appears once per connecting edge, so duplicates are possible.
gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

// Iterates the nodes on the far end of this node's in-edges (see out_nodes()
// for duplicate semantics).
gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}
219 | |
220 | void Node::MaybeCopyOnWrite() { |
221 | // TODO(mdan): As nodes become more dynamic, this may not be worth the cost. |
222 | // NodeProperties may be shared between Nodes. Make a copy if so. |
223 | if (!props_.unique()) { |
224 | props_ = std::make_shared<NodeProperties>(*props_); |
225 | } |
226 | } |
227 | |
// Returns a mutable AttrValue slot for `name`, inserting it if absent.
// All mutators below go through MaybeCopyOnWrite() so shared NodeProperties
// are cloned before being modified.
AttrValue* Node::AddAttrHelper(const std::string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

// Removes attribute `name` if present; no-op otherwise.
void Node::ClearAttr(const std::string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(std::string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const std::string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

// Replaces the experimental_debug_info original_node_names list wholesale
// (cleared first, then repopulated only if `names` is non-empty).
void Node::set_original_node_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

// Same as set_original_node_names(), but for original_func_names.
void Node::set_original_func_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_func_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_func_names() = {names.begin(), names.end()};
  }
}
267 | |
// Sets *e to the edge feeding input slot `idx`. Returns InvalidArgument for
// an out-of-range index and NotFound if no edge targets that slot.
Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: " , idx, ", Node " ,
                                   name(), " only has " , num_inputs(),
                                   " inputs." );
  }

  // This does a linear search over the edges. In the common case,
  // the number of elements is small enough that this search isn't
  // expensive. Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer. This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return OkStatus();
    }
  }

  return errors::NotFound("Could not find input edge " , idx, " for " , name());
}

// Returns a vector of the non-control input edges to a node, indexed by ID.
// Errors with Internal on malformed slot numbers (out of range or duplicate)
// and InvalidArgument if any data input slot has no incoming edge.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number " , edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: " ,
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  // Every data slot must be fed by exactly one edge.
  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: " , i);
    }
  }
  return OkStatus();
}
318 | |
// Sets *n to the source node of the edge feeding input slot `idx`.
// Propagates input_edge() errors; *n is set to nullptr if the edge is null.
Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return OkStatus();
}

// Const overload; delegates to the non-const version above.
Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return OkStatus();
}

// Sets *t to the (source node, output slot) pair feeding input slot `idx`.
Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return OkStatus();
}
344 | |
345 | // NodeDebugInfo |
346 | |
// Convenience constructors that extract debug info from a Node or NodeDef;
// both forward to the three-argument form below.
NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
// Captures the node name plus, when present, the original node/function names
// recorded in the NodeDef's experimental_debug_info.
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  if (has_experimental_debug_info) {
    const auto& node_names = experimental_debug_info.original_node_names();
    original_node_names.assign(node_names.begin(), node_names.end());
    const auto& func_names = experimental_debug_info.original_func_names();
    original_func_names.assign(func_names.begin(), func_names.end());
  }
}
362 | // InputTensor |
363 | |
// Two InputTensors are equal iff they name the same slot of the same node
// (pointer identity on the node).
bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

// Hash consistent with operator== above: combines node pointer and slot.
uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

// Same equality/hash contract as InputTensor, for output slots.
bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}
383 | |
384 | // Graph |
385 | |
// Constructs an empty graph over the given op registry: sets GraphDef
// version bounds, interns the empty device name at index 0, and creates the
// _SOURCE and _SINK NoOp nodes (ids kSourceId/kSinkId) joined by a control
// edge.
Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("" );
  DCHECK_EQ(0, InternDeviceName("" ));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE" );
  def.set_op("NoOp" );
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK" );
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

// Constructs an empty graph and copies in the given function library.
// CHECK-fails if the library cannot be added.
Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}
423 | |
Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new (both live nodes and those parked on the free list).
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}

// Returns a deep copy of this graph (nodes, edges, versions, construction
// context) sharing the same function library definition.
std::unique_ptr<Graph> Graph::Clone() {
  std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
  new_graph->Copy(*this);
  return new_graph;
}

// Removes every op node, leaving only _SOURCE and _SINK.
void Graph::Clear() {
  // Do a direct iteration clearing nodes removing the RemoveNode helper method.
  // This could avoid this helper and clear directly if it becomes performance
  // sensitive.
  for (Node* n : nodes()) {
    if (!n->IsSource() && !n->IsSink()) RemoveNode(n);
  }
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
456 | |
// Copies all nodes and edges of `src` into *this. Precondition (CHECKed):
// *this contains only the _SOURCE and _SINK nodes.
void Graph::Copy(const Graph& src) {
  SetConstructionContext(src.GetConstructionContextInternal());
  for (Node* n : nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty" ;
  }

  // Copy GraphDef versions
  set_versions(src.versions());

  // Copy the nodes.
  // "Node in src" -> "Node in *dest"
  gtl::FlatMap<const Node*, Node*> node_map;
  node_map.reserve(src.num_nodes());
  node_map[src.source_node()] = source_node();
  node_map[src.sink_node()] = sink_node();
  for (Node* n : src.op_nodes()) {
    auto copy = CopyNode(n);
    // Pre-size the edge sets to avoid rehashing while edges are copied below.
    copy->in_edges_.reserve(n->in_edges().size());
    copy->out_edges_.reserve(n->out_edges().size());
    node_map[n] = copy;
  }

  // Copy the edges
  edges_.reserve(src.num_edges());
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}
487 | |
// StatusOr-returning wrapper around the (NodeDef, Status*) overload below.
StatusOr<Node*> Graph::AddNode(NodeDef node_def) {
  Status s;
  Node* out = AddNode(std::move(node_def), &s);
  TF_RETURN_IF_ERROR(s);
  return out;
}

// Adds a node for `node_def` to this graph. Looks up the op in the graph's
// registry, derives input/output types, optionally runs the op's full-type
// constructor, and allocates the Node. On failure sets *status and returns
// nullptr.
Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    // Attach the offending NodeDef to the error for easier debugging.
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  // Only run the op's type constructor when the NodeDef does not already
  // carry full-type information.
  if (node_def.has_experimental_type()) {
    VLOG(3) << "AddNode: node has type set, skipping type constructor "
            << node_def.name();
  } else {
    if (op_reg_data->type_ctor != nullptr) {
      VLOG(3) << "AddNode: found type constructor for " << node_def.name();
      Status s =
          full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
                                    *(node_def.mutable_experimental_type()));
      if (!s.ok()) {
        *status = errors::InvalidArgument("type error: " , s.ToString());
        VLOG(3) << "AddNode: type inference failed for " << node_def.name()
                << ": " << s;
        return nullptr;
      }
    } else {
      VLOG(3) << "AddNode: no type constructor for " << node_def.name();
    }
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, node_class);
  return node;
}
539 | |
// Creates a copy of `node` (which may belong to another graph) in this graph,
// sharing its NodeProperties where possible. Edges are NOT copied.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  // The copy shares `node`'s cost_id (second AllocateNode argument).
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());

  return copy;
}
559 | |
// Removes `node` and every edge touching it, recycling both the node and its
// edges onto the graph's free lists. Must not be called on _SOURCE or _SINK.
void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  for (const Edge* e : node->in_edges_) {
    // Detach from the peer's edge set; this node's own sets are cleared in
    // bulk below.
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}
582 | |
// Adds an edge from `source` output slot x to `dest` input slot y, reusing a
// recycled Edge object when available. Returns the new edge.
const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  // Note: a recycled edge gets a fresh id (its slot in edges_).
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;

  return e;
}

// Detaches `e` from both endpoints, nulls its slot in edges_, and recycles
// the Edge object. Does NOT update the dst node's NodeDef inputs (see
// RemoveControlEdge for that).
void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}

// Parks a detached edge on the free list for reuse by AddEdge.
void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}
631 | |
// Adds a control edge source -> dest. Unless allow_duplicates is true,
// returns nullptr if an equivalent control edge already exists; also records
// the "^source" input in dest's NodeDef (skipped for _SOURCE/_SINK endpoints
// and in allow_duplicates mode).
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const std::string new_input = strings::StrCat("^" , source->name());
    bool input_exists = false;
    for (const std::string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}

// Removes control edge `e`, also erasing the first matching "^src" entry from
// the dst node's NodeDef inputs (skipped for _SOURCE/_SINK endpoints).
void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    std::string e_src_name = strings::StrCat("^" , e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}
675 | |
676 | namespace { |
// Returns the in-edge of `dst` that targets input slot `index`, or nullptr if
// no such edge exists. Linear scan; see Node::input_edge for rationale.
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
683 | } // namespace |
684 | |
// Rewires dst's input `dst_index` to come from new_src:new_src_index,
// replacing the existing edge and updating dst's NodeDef input string.
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to " ,
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  // Keep the NodeDef in sync with the new edge.
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":" , new_src_index);
  return OkStatus();
}

// Appends new_src:new_src_index as the next data input of a While node,
// updating both the edge set and the NodeDef. Internal error if dst is not a
// While op.
Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (!dst->IsWhileNode()) {
    return errors::Internal(
        "dst argument to AddWhileEdgeHack should be a While op, got: " ,
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":" , new_src_index));
  return OkStatus();
}
723 | |
// Merges `fdef_lib` into this graph's function library, bumping the GraphDef
// min_consumer version to 12 if functions are being added.
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}
731 | |
732 | namespace { |
733 | |
734 | void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { |
735 | if (src_slot == Graph::kControlSlot) { |
736 | dst->add_input(strings::StrCat("^" , src_name)); |
737 | } else if (src_slot == 0) { |
738 | dst->add_input(src_name.data(), src_name.size()); |
739 | } else { |
740 | dst->add_input(strings::StrCat(src_name, ":" , src_slot)); |
741 | } |
742 | } |
743 | |
744 | } // namespace |
745 | |
// Serializes the full graph (all node ids from 0) into *graph_def.
void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

// Convenience value-returning variant of ToGraphDef, handy in debuggers.
GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}
755 | |
// Serializes versions, the function library, and every op node with id >=
// from_node_id into *graph_def. Each emitted NodeDef's inputs are rebuilt
// from the live edge set: data inputs first (by slot), then control inputs
// sorted by source name.
void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    // Skip recycled ids and the source/sink pseudo-nodes.
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node.  We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    // Slots [0, num_inputs) hold data edges by slot index; control edges are
    // appended past them via push_back.
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        // No live edge for this slot: fall back to the originally requested
        // input string, or an empty placeholder.
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("" );
        }
      } else {
        const Node* src = edge->src();
        // Edges from non-op nodes (source/sink) are not serialized.
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}
822 | |
// Returns "<prefix>/_<counter>", unique within this Graph instance (the
// counter increments on every call).
std::string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_" , name_counter_++);
}
826 | |
// Verifies that `node` is non-null, has an in-range id, and is the node this
// graph stores under that id (i.e. it belongs to this graph).
Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null" );
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id " , id, " is less than zero" );
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id " , id, " is >= than number of nodes in graph " , nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id " , id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?" );
  }
  return OkStatus();
}

// Validates the node and that `idx` names one of its output slots.
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '" , node->name(), "' (type: '" ,
                              node->op_def().name(),
                              "', num of outputs: " , node->num_outputs(),
                              ") does not have " , "output " , idx);
  }
  return OkStatus();
}

// Validates the node and that `idx` names one of its input slots.
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '" , node->name(), "' (type: '" ,
                              node->op_def().name(),
                              "', num of inputs: " , node->num_inputs(),
                              ") does not have " , "input " , idx);
  }
  return OkStatus();
}
868 | |
// Allocates (or recycles) a Node, initializes it with the next id, and
// registers it in nodes_. cost_id comes from `cost_node` if given, else the
// node's own id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

// Clears `node` and parks it on the free list; its id slot in nodes_ is
// nulled but never reused for a different id.
void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}
894 | |
// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const std::string& device_name) {
  // Special case, very common.  Also, this allows us to use a single map
  // lookup below, instead of two.  The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  // operator[] value-initializes a missing entry to 0, which (thanks to the
  // empty-name special case above) is never a valid interned index.
  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  // New name: its index is the map size (the empty name at index 0 is not in
  // the map, so map size == device_names_.size() before the push_back).
  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}
916 | |
// Registers a WhileContext under `frame_name` and returns a pointer to the
// stored context through *result. Fails with InvalidArgument (and a null
// *result) if a context with that frame name already exists.
Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
      std::string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '" , frame_name,
                                   "' already exists" );
  }
  *result = &pair.first->second;
  return OkStatus();
}
937 | |
938 | std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const { |
939 | std::unordered_map<std::string, Node*> result; |
940 | for (Node* n : nodes()) { |
941 | result[n->name()] = n; |
942 | } |
943 | return result; |
944 | } |
945 | |
// Renders this edge as "[id=N src:slot -> dst:slot]" for logs and CHECKs.
std::string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]" , id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}
950 | |
951 | } // namespace tensorflow |
952 | |