1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // A Graph describes a set of computations that are to be |
17 | // performed, as well as the dependencies between those |
18 | // computations. The basic model is a DAG (directed acyclic graph) with |
19 | // * internal nodes representing computational operations to be performed; |
20 | // * edges represent dependencies, indicating the target may only be |
21 | // executed once the source has completed; and |
22 | // * predefined "source" (start) and "sink" (finish) nodes -- the source |
23 | // should be the only node that doesn't depend on anything, and the sink |
24 | // should be the only node that nothing depends on. |
25 | // |
26 | // Note: Node ids are intended to be relatively dense in the |
27 | // 0..max_id range, but there may be gaps since ids won't be reused. |
28 | // |
29 | // Note: Some dependencies between operations are due to one operation |
30 | // consuming the output of another. In fact operations can produce |
31 | // multiple outputs and consume multiple inputs, and some |
32 | // optimizations will care about which specific outputs are connected |
33 | // to which specific inputs. We therefore represent data dependency |
34 | // between output O of layer A and input I of layer B using |
35 | // "input index" and "output index" labels per edge. |
36 | |
37 | #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_ |
38 | #define TENSORFLOW_CORE_GRAPH_GRAPH_H_ |
39 | |
40 | #include <functional> |
41 | #include <memory> |
42 | #include <string> |
43 | #include <vector> |
44 | |
45 | #include "absl/types/optional.h" |
46 | #include "tensorflow/core/framework/full_type.pb.h" |
47 | #include "tensorflow/core/framework/function.h" |
48 | #include "tensorflow/core/framework/node_def.pb.h" |
49 | #include "tensorflow/core/framework/node_def_util.h" |
50 | #include "tensorflow/core/framework/op.h" |
51 | #include "tensorflow/core/framework/types.h" |
52 | #include "tensorflow/core/graph/edgeset.h" |
53 | #include "tensorflow/core/lib/core/arena.h" |
54 | #include "tensorflow/core/lib/core/refcount.h" |
55 | #include "tensorflow/core/lib/core/status.h" |
56 | #include "tensorflow/core/lib/gtl/iterator_range.h" |
57 | #include "tensorflow/core/platform/logging.h" |
58 | #include "tensorflow/core/platform/macros.h" |
59 | #include "tensorflow/core/platform/stringpiece.h" |
60 | #include "tensorflow/core/platform/types.h" |
61 | |
62 | namespace tensorflow { |
63 | |
64 | class Edge; |
65 | class EdgeSetTest; |
66 | class Graph; |
67 | class GraphDef; |
68 | class Node; |
69 | struct OutputTensor; |
70 | class VersionDef; |
71 | class WhileContext; |
72 | |
73 | class NeighborIter; // Declared below |
74 | class NodeIter; // Declared below |
75 | |
// Records which runtime a graph instance originated from.
enum class ConstructionContext {
  // Origin is not tracked.
  kNotTracked = 0,
  // Built by `tensorflow::DirectSession` (TF1 session API).
  kDirectSession = 1,
  // Registered from the TF2 eager runtime.
  kEagerRuntime = 2,
};
82 | |
83 | class Node { |
84 | public: |
85 | std::string DebugString() const; |
86 | int id() const { return id_; } |
87 | int cost_id() const { return cost_id_; } |
88 | const std::string& name() const; |
89 | void set_name(std::string name); |
90 | const std::string& type_string() const; |
91 | |
92 | // def() provides the NodeDef the user supplied, but the specifics |
93 | // of this Node may have changed due to placement, optimization, etc. |
94 | // In particular: |
95 | // * def().name() will match name(); |
96 | // * def().op() will match type_string() and op_def().name(); |
97 | // * def().input() is not reliable, use "in_edges()" below instead; |
98 | // * def().device() is the "user's requested device" and may not match |
99 | // the actual assigned device, see assigned_device_name() below; |
100 | // * def().attr() is authoritative. |
101 | // TODO(irving): Replace with NodeInfo. |
102 | const NodeDef& def() const; |
103 | const OpDef& op_def() const; |
104 | |
105 | // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove? |
106 | NodeDef* mutable_def(); |
107 | |
108 | // input and output types |
109 | int32 num_inputs() const; |
110 | DataType input_type(int32_t i) const; |
111 | const DataTypeVector& input_types() const; |
112 | |
113 | int32 num_outputs() const; |
114 | DataType output_type(int32_t o) const; |
115 | const DataTypeVector& output_types() const; |
116 | |
117 | // The device requested by the user. For the actual assigned device, |
118 | // use assigned_device_name() below. |
119 | const std::string& requested_device() const; |
120 | |
121 | // This changes the user requested device but not necessarily the device that |
122 | // on which the operation will run. |
123 | void set_requested_device(const std::string& device); |
124 | |
125 | // This gives the device the runtime has assigned this node to. If |
126 | // you want the device the user requested, use def().device() instead. |
127 | // TODO(josh11b): Validate that the assigned_device, if not empty: |
128 | // fully specifies a device, and satisfies def().device(). |
129 | // TODO(josh11b): Move assigned_device_name outside of Node into a |
130 | // NodeId->DeviceName map. |
131 | const std::string& assigned_device_name() const; |
132 | void set_assigned_device_name(const std::string& device_name); |
133 | bool has_assigned_device_name() const { |
134 | return assigned_device_name_index_ > 0; |
135 | } |
136 | int assigned_device_name_index() const { return assigned_device_name_index_; } |
137 | void set_assigned_device_name_index(int index); |
138 | |
139 | // Sets 'original_node_names' field of this node's DebugInfo proto to |
140 | // 'names'. |
141 | void set_original_node_names(const std::vector<string>& names); |
142 | void set_original_func_names(const std::vector<string>& names); |
143 | |
144 | // Read only access to attributes |
145 | AttrSlice attrs() const; |
146 | |
147 | // Inputs requested by the NodeDef. For the actual inputs, use in_edges. |
148 | const protobuf::RepeatedPtrField<string>& requested_inputs() const; |
149 | |
150 | // Get the neighboring nodes via edges either in or out of this node. This |
151 | // includes control edges. |
152 | gtl::iterator_range<NeighborIter> in_nodes() const; |
153 | gtl::iterator_range<NeighborIter> out_nodes() const; |
154 | const EdgeSet& in_edges() const { return in_edges_; } |
155 | const EdgeSet& out_edges() const { return out_edges_; } |
156 | |
157 | // Node type helpers. |
158 | bool IsSource() const { return id() == 0; } |
159 | bool IsSink() const { return id() == 1; } |
160 | // Anything other than the special Source & Sink nodes. |
161 | bool IsOp() const { return id() > 1; } |
162 | |
163 | // Node class helpers |
164 | bool IsSwitch() const { return class_ == NC_SWITCH; } |
165 | bool IsMerge() const { return class_ == NC_MERGE; } |
166 | bool IsEnter() const { return class_ == NC_ENTER; } |
167 | bool IsExit() const { return class_ == NC_EXIT; } |
168 | bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } |
169 | bool IsLoopCond() const { return class_ == NC_LOOP_COND; } |
170 | bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } |
171 | bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } |
172 | bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; } |
173 | bool IsConstant() const { return class_ == NC_CONSTANT; } |
174 | bool IsVariable() const { return class_ == NC_VARIABLE; } |
175 | bool IsIdentity() const { return class_ == NC_IDENTITY; } |
176 | bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; } |
177 | bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; } |
178 | bool IsDeleteSessionTensor() const { |
179 | return class_ == NC_DELETE_SESSION_TENSOR; |
180 | } |
181 | bool IsControlFlow() const { |
182 | return (class_ != NC_OTHER) && // Fast path |
183 | (IsSwitch() || IsMerge() || IsEnter() || IsExit() || |
184 | IsNextIteration()); |
185 | } |
186 | bool IsHostSend() const { return class_ == NC_HOST_SEND; } |
187 | bool IsHostRecv() const { return class_ == NC_HOST_RECV; } |
188 | bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; } |
189 | bool IsCollective() const { return class_ == NC_COLLECTIVE; } |
190 | |
191 | bool IsMetadata() const { return class_ == NC_METADATA; } |
192 | bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; } |
193 | bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; } |
194 | |
195 | // Returns true if this node is any kind of function call node. |
196 | // |
197 | // NOTE: "function call nodes" include partitioned call ops, symbolic gradient |
198 | // ops, and ops whose type_string is the name of a function ("function ops"). |
199 | bool IsFunctionCall() const { |
200 | return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP || |
201 | class_ == NC_SYMBOLIC_GRADIENT; |
202 | } |
203 | |
204 | bool IsIfNode() const { return class_ == NC_IF; } |
205 | bool IsWhileNode() const { return class_ == NC_WHILE; } |
206 | bool IsCaseNode() const { return class_ == NC_CASE; } |
207 | // Is this node a function input |
208 | bool IsArg() const { return class_ == NC_ARG; } |
209 | // Is this node a function output |
210 | bool IsRetval() const { return class_ == NC_RETVAL; } |
211 | |
212 | bool IsDistributedCommunication() const { |
213 | return op_def().is_distributed_communication(); |
214 | } |
215 | |
216 | template <typename T> |
217 | void AddAttr(const std::string& name, const T& val) { |
218 | SetAttrValue(val, AddAttrHelper(name)); |
219 | UpdateProperties(); |
220 | } |
221 | |
222 | void AddAttr(const std::string& name, std::vector<string>&& val) { |
223 | MoveAttrValue(std::move(val), AddAttrHelper(name)); |
224 | UpdateProperties(); |
225 | } |
226 | |
227 | void ClearAttr(const std::string& name); |
228 | |
229 | // Returns into '*e' the edge connecting to the 'idx' input of this Node. |
230 | Status input_edge(int idx, const Edge** e) const; |
231 | |
232 | // Returns into '*edges' the input data edges of this Node, indexed by input |
233 | // number. Does not return control edges. |
234 | Status input_edges(std::vector<const Edge*>* edges) const; |
235 | |
236 | // Returns into '*n' the node that has an output connected to the |
237 | // 'idx' input of this Node. |
238 | Status input_node(int idx, const Node** n) const; |
239 | Status input_node(int idx, Node** n) const; |
240 | |
241 | // Returns into '*t' the idx-th input tensor of this node, represented as the |
242 | // output tensor of input_node(idx). |
243 | Status input_tensor(int idx, OutputTensor* t) const; |
244 | |
245 | WhileContext* while_ctx() const { return while_ctx_; } |
246 | void set_while_ctx(WhileContext* while_ctx) { |
247 | DCHECK(IsExit()); |
248 | DCHECK(while_ctx_ == nullptr); |
249 | while_ctx_ = while_ctx; |
250 | } |
251 | |
252 | std::shared_ptr<NodeProperties> properties() const { return props_; } |
253 | |
254 | // Sets the stack trace for the node. Assumes that getting and setting the |
255 | // stack trace for a given node will not race. |
256 | void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) { |
257 | stack_trace_ = stack_trace; |
258 | } |
259 | |
260 | // Get the stack trace for when the node was instantiated. |
261 | const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const { |
262 | return stack_trace_; |
263 | } |
264 | |
265 | // Called after an attr has changed. Decides whether we need to update some |
266 | // property of the node (stored in props_). |
267 | void UpdateProperties(); |
268 | |
269 | // Erases type information from the node. |
270 | void ClearTypeInfo(); |
271 | |
272 | // Called after an incident non-control edge has changed. Does nothing if not |
273 | // all input edges are defined. |
274 | void RunForwardTypeInference(); |
275 | |
276 | private: |
277 | // TODO(mdan): Drop this. |
278 | friend class Graph; |
279 | Node(); |
280 | |
281 | // Stack trace for the user code for node instantiation. Can be shared across |
282 | // multiple nodes (e.g. when inlining). |
283 | std::shared_ptr<AbstractStackTrace> stack_trace_; |
284 | |
285 | // Releases memory from props_, in addition to restoring *this to its |
286 | // uninitialized state. |
287 | void Clear(); |
288 | |
289 | // Make a copy of the Node's props_ if props_ is shared with |
290 | // other nodes. This must be called before mutating properties, |
291 | // e.g. in AddAttr. |
292 | void MaybeCopyOnWrite(); |
293 | |
294 | AttrValue* AddAttrHelper(const std::string& name); |
295 | |
296 | // A set of mutually exclusive classes for different kinds of nodes, |
297 | // class_ is initialized in the Node::Initialize routine based on the |
298 | // node's type_string(). |
299 | enum NodeClass { |
300 | NC_UNINITIALIZED, |
301 | NC_SWITCH, |
302 | NC_MERGE, |
303 | NC_ENTER, |
304 | NC_EXIT, |
305 | NC_NEXT_ITERATION, |
306 | NC_LOOP_COND, |
307 | NC_CONTROL_TRIGGER, |
308 | NC_SEND, |
309 | NC_HOST_SEND, |
310 | NC_RECV, |
311 | NC_HOST_RECV, |
312 | NC_CONSTANT, |
313 | NC_VARIABLE, |
314 | NC_IDENTITY, |
315 | NC_GET_SESSION_HANDLE, |
316 | NC_GET_SESSION_TENSOR, |
317 | NC_DELETE_SESSION_TENSOR, |
318 | NC_METADATA, |
319 | NC_SCOPED_ALLOCATOR, |
320 | NC_COLLECTIVE, |
321 | NC_FAKE_PARAM, |
322 | NC_PARTITIONED_CALL, |
323 | NC_FUNCTION_OP, |
324 | NC_SYMBOLIC_GRADIENT, |
325 | NC_IF, |
326 | NC_WHILE, |
327 | NC_CASE, |
328 | NC_ARG, |
329 | NC_RETVAL, |
330 | NC_OTHER // Not a special kind of node |
331 | }; |
332 | |
333 | void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props, |
334 | NodeClass node_class); |
335 | |
336 | static NodeClass GetNodeClassForOp(const std::string& ts); |
337 | |
338 | int id_; // -1 until Initialize() is called |
339 | int cost_id_; // -1 if there is no corresponding cost accounting node |
340 | NodeClass class_; |
341 | |
342 | EdgeSet in_edges_; |
343 | EdgeSet out_edges_; |
344 | |
345 | // NOTE(skyewm): inheriting from core::RefCounted may have a slight |
346 | // performance benefit over using shared_ptr, at the cost of manual ref |
347 | // counting |
348 | std::shared_ptr<NodeProperties> props_; |
349 | |
350 | // Index within Graph::device_names_ of the name of device assigned |
351 | // to perform this computation. |
352 | int assigned_device_name_index_; |
353 | |
354 | // A back-pointer to the Graph that owns this node. Currently, this exists |
355 | // solely to allow Node::[set_]assigned_device_name() to work. However, if all |
356 | // callers of Node::[set_]assigned_device_name() are modified to use the |
357 | // equivalent methods defined directly on Graph, then we can remove this |
358 | // field and reclaim that memory. |
359 | Graph* graph_; |
360 | |
361 | // Set if this is an exit node of a while loop with an associated |
362 | // WhileContext. Otherwise null. (This is only set for exit nodes because |
363 | // they're the first nodes of a loop encountered while creating the gradient |
364 | // graph. Exit nodes that are part of while loop gradient graphs will not have |
365 | // this set.) |
366 | WhileContext* while_ctx_; |
367 | |
368 | TF_DISALLOW_COPY_AND_ASSIGN(Node); |
369 | }; |
370 | |
371 | // Stores debug information associated with the Node. |
372 | struct NodeDebugInfo { |
373 | const std::string name; |
374 | std::vector<string> original_node_names; |
375 | std::vector<string> original_func_names; |
376 | |
377 | NodeDebugInfo(const Node& n); |
378 | NodeDebugInfo(const NodeDef& ndef); |
379 | NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info, |
380 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info); |
381 | }; |
382 | |
383 | // Represents an input of a node, i.e., the `index`-th input to `node`. |
384 | struct InputTensor { |
385 | Node* node; |
386 | int index; |
387 | |
388 | InputTensor(Node* n, int i) : node(n), index(i) {} |
389 | InputTensor() : node(nullptr), index(0) {} |
390 | |
391 | // Returns true if this InputTensor is identical to 'other'. Nodes are |
392 | // compared using pointer equality. |
393 | bool operator==(const InputTensor& other) const; |
394 | |
395 | // A hash function for InputTensors. Nodes are hashed based on their pointer |
396 | // value. |
397 | struct Hash { |
398 | uint64 operator()(InputTensor const& s) const; |
399 | }; |
400 | }; |
401 | |
402 | // Represents an output of a node, i.e., the `index`-th output of `node`. Note |
403 | // that a single `OutputTensor` can correspond to multiple `Edge`s if the output |
404 | // is consumed by multiple destination nodes. |
405 | struct OutputTensor { |
406 | Node* node; |
407 | int index; |
408 | |
409 | OutputTensor(Node* n, int i) : node(n), index(i) {} |
410 | OutputTensor() : node(nullptr), index(0) {} |
411 | |
412 | // Returns true if this OutputTensor is identical to 'other'. Nodes are |
413 | // compared using pointer equality. |
414 | bool operator==(const OutputTensor& other) const; |
415 | |
416 | // A hash function for OutputTensors. Nodes are hashed based on their pointer |
417 | // value. |
418 | struct Hash { |
419 | uint64 operator()(OutputTensor const& s) const; |
420 | }; |
421 | }; |
422 | |
423 | class Edge { |
424 | public: |
425 | Node* src() const { return src_; } |
426 | Node* dst() const { return dst_; } |
427 | int id() const { return id_; } |
428 | |
429 | // Return the index of the source output that produces the data |
430 | // carried by this edge. The special value kControlSlot is used |
431 | // for control dependencies. |
432 | int src_output() const { return src_output_; } |
433 | |
434 | // Return the index of the destination input that consumes the data |
435 | // carried by this edge. The special value kControlSlot is used |
436 | // for control dependencies. |
437 | int dst_input() const { return dst_input_; } |
438 | |
439 | // Return true iff this is an edge that indicates a control-flow |
440 | // (as opposed to a data-flow) dependency. |
441 | bool IsControlEdge() const; |
442 | |
443 | std::string DebugString() const; |
444 | |
445 | private: |
446 | Edge() {} |
447 | |
448 | friend class EdgeSetTest; |
449 | friend class Graph; |
450 | Node* src_; |
451 | Node* dst_; |
452 | int id_; |
453 | int src_output_; |
454 | int dst_input_; |
455 | }; |
456 | |
457 | // Allows for iteration of the edges of a Graph, by iterating the underlying |
458 | // Graph.edges_ vector while skipping over null entries. |
459 | class GraphEdgesIterable { |
460 | private: |
461 | const std::vector<Edge*>& edges_; |
462 | |
463 | public: |
464 | explicit GraphEdgesIterable(const std::vector<Edge*>& edges) |
465 | : edges_(edges) {} |
466 | |
467 | typedef Edge* value_type; |
468 | |
469 | class const_iterator { |
470 | private: |
471 | // The underlying iterator. |
472 | std::vector<value_type>::const_iterator iter_; |
473 | |
474 | // The end of the underlying iterator. |
475 | std::vector<value_type>::const_iterator end_; |
476 | |
477 | // Advances iter_ until it reaches a non-null item, or reaches the end. |
478 | void apply_filter() { |
479 | while (iter_ != end_ && *iter_ == nullptr) { |
480 | ++iter_; |
481 | } |
482 | } |
483 | |
484 | public: |
485 | const_iterator(std::vector<value_type>::const_iterator iter, |
486 | std::vector<value_type>::const_iterator end) |
487 | : iter_(iter), end_(end) { |
488 | apply_filter(); |
489 | } |
490 | |
491 | bool operator==(const const_iterator& other) const { |
492 | return iter_ == other.iter_; |
493 | } |
494 | |
495 | bool operator!=(const const_iterator& other) const { |
496 | return iter_ != other.iter_; |
497 | } |
498 | |
499 | // This is the prefix increment operator (++x), which is the operator |
500 | // used by C++ range iteration (for (x : y) ...). We intentionally do not |
501 | // provide a postfix increment operator. |
502 | const_iterator& operator++() { |
503 | ++iter_; |
504 | apply_filter(); |
505 | return *this; |
506 | } |
507 | |
508 | value_type operator*() { return *iter_; } |
509 | }; |
510 | |
511 | const_iterator begin() { |
512 | return const_iterator(edges_.begin(), edges_.end()); |
513 | } |
514 | const_iterator end() { return const_iterator(edges_.end(), edges_.end()); } |
515 | }; |
516 | |
517 | // Thread compatible but not thread safe. |
518 | class Graph { |
519 | public: |
520 | // Constructs a graph with a single SOURCE (always id kSourceId) and a |
521 | // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. |
522 | // |
523 | // The graph can hold ops found in the registry. `ops`s lifetime must be at |
524 | // least that of the constructed graph's. |
525 | explicit Graph(const OpRegistryInterface* ops); |
526 | |
527 | // Constructs a graph with a single SOURCE (always id kSourceId) and a |
528 | // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. |
529 | // |
530 | // The graph can hold ops found in `flib_def`. Unlike the constructor taking |
531 | // an OpRegistryInterface, this constructor copies the function definitions in |
532 | // `flib_def` so its lifetime may be shorter than that of the graph's. The |
533 | // OpRegistryInterface backing `flib_def` must still have the lifetime of the |
534 | // graph though. |
535 | explicit Graph(const FunctionLibraryDefinition& flib_def); |
536 | |
537 | ~Graph(); |
538 | |
539 | // Clone the current graph into a new one. |
540 | std::unique_ptr<Graph> Clone(); |
541 | |
542 | static const int kControlSlot; |
543 | |
544 | // The GraphDef version range of this graph (see graph.proto). |
545 | const VersionDef& versions() const; |
546 | void set_versions(const VersionDef& versions); |
547 | |
548 | // Adds a new node to this graph, and returns it. Infers the Op and |
549 | // input/output types for the node. *this owns the returned instance. |
550 | // Returns nullptr and sets *status on error. |
551 | Node* AddNode(NodeDef node_def, Status* status); |
552 | |
553 | // Same as above, but using StatusOr. This method is always preferred. |
554 | StatusOr<Node*> AddNode(NodeDef node_def); |
555 | |
556 | // Copies *node, which may belong to another graph, to a new node, |
557 | // which is returned. Does not copy any edges. *this owns the |
558 | // returned instance. |
559 | Node* CopyNode(const Node* node); |
560 | |
561 | // Removes a node from this graph, including all edges from or to it. |
562 | // *node should not be accessed after calling this function. |
563 | // REQUIRES: node->IsOp() |
564 | void RemoveNode(Node* node); |
565 | |
566 | void Copy(const Graph& src); |
567 | |
568 | // Removes all nodes from this graph, including all edges from or to them. |
569 | // No Node* references to the Graph are valid post. |
570 | void Clear(); |
571 | |
572 | // Adds an edge that connects the xth output of `source` to the yth input of |
573 | // `dest` and returns it. Does not update dest's NodeDef. |
574 | const Edge* AddEdge(Node* source, int x, Node* dest, int y); |
575 | |
576 | // Adds a control edge (no data flows along this edge) that connects `source` |
577 | // to `dest`. If `dest`s NodeDef is missing the corresponding control input, |
578 | // adds the control input. |
579 | // |
580 | // If such a control edge already exists and `allow_duplicates` is false, no |
581 | // edge is added and the function returns nullptr. Otherwise the edge is |
582 | // unconditionally created and returned. The NodeDef is not updated if |
583 | // `allow_duplicates` is true. |
584 | // TODO(skyewm): allow_duplicates is needed only by |
585 | // graph_partition.cc. Figure out if we can do away with it. |
586 | const Edge* AddControlEdge(Node* source, Node* dest, |
587 | bool allow_duplicates = false); |
588 | |
589 | // Removes edge from the graph. Does not update the destination node's |
590 | // NodeDef. |
591 | // REQUIRES: The edge must exist. |
592 | void RemoveEdge(const Edge* edge); |
593 | |
594 | // Removes control edge `edge` from the graph. Note that this also updates |
595 | // the corresponding NodeDef to reflect the change. |
596 | // REQUIRES: The control edge must exist. |
597 | void RemoveControlEdge(const Edge* e); |
598 | |
599 | // Updates the input to a node. The existing edge to `dst` is removed and an |
600 | // edge from `new_src` to `dst` is created. The NodeDef associated with `dst` |
601 | // is also updated. |
602 | Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index); |
603 | |
604 | // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a |
605 | // "While" op during gradient construction, see AddInputWhileHack in |
606 | // python_api.h for more details. |
607 | Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst); |
608 | |
609 | // Adds the function and gradient definitions in `fdef_lib` to this graph's op |
610 | // registry. Ignores duplicate functions, and returns a bad status if an |
611 | // imported function differs from an existing function or op with the same |
612 | // name. |
613 | Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); |
614 | |
615 | // The number of live nodes in the graph. |
616 | // |
617 | // Because nodes can be removed from the graph, num_nodes() is often |
618 | // smaller than num_node_ids(). If one needs to create an array of |
619 | // nodes indexed by node ids, num_node_ids() should be used as the |
620 | // array's size. |
621 | int num_nodes() const { return num_nodes_; } |
622 | |
623 | // The number of live nodes in the graph, excluding the Source and Sink nodes. |
624 | int num_op_nodes() const { |
625 | DCHECK_GE(num_nodes_, 2); |
626 | return num_nodes_ - 2; |
627 | } |
628 | |
629 | // The number of live edges in the graph. |
630 | // |
631 | // Because edges can be removed from the graph, num_edges() is often |
632 | // smaller than num_edge_ids(). If one needs to create an array of |
633 | // edges indexed by edge ids, num_edge_ids() should be used as the |
634 | // array's size. |
635 | int num_edges() const { return num_edges_; } |
636 | |
637 | // Serialize the nodes starting at `from_node_id` to a GraphDef. |
638 | void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const; |
639 | |
640 | // Serialize to a GraphDef. |
641 | void ToGraphDef(GraphDef* graph_def) const; |
642 | |
643 | // This version can be called from debugger to inspect the graph content. |
644 | // Use the previous version outside debug context for efficiency reasons. |
645 | // |
646 | // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is |
647 | // not defined in some TensorFlow builds. |
648 | GraphDef ToGraphDefDebug() const; |
649 | |
650 | // Generate new node name with the specified prefix that is unique |
651 | // across this graph. |
652 | std::string NewName(StringPiece prefix); |
653 | |
654 | // Access to the list of all nodes. Example usage: |
655 | // for (Node* node : graph.nodes()) { ... } |
656 | gtl::iterator_range<NodeIter> nodes() const; |
657 | |
658 | // Access to the list of all nodes, excluding the Source and Sink nodes. |
659 | gtl::iterator_range<NodeIter> op_nodes() const; |
660 | |
661 | // Returns one more than the maximum id assigned to any node. |
662 | int num_node_ids() const { return nodes_.size(); } |
663 | |
664 | // Returns the node associated with an id, or nullptr if no node |
665 | // with that id (the node with that id was removed and the id has |
666 | // not yet been re-used). *this owns the returned instance. |
667 | // REQUIRES: 0 <= id < num_node_ids(). |
668 | Node* FindNodeId(int id) const { return nodes_[id]; } |
669 | |
670 | // Returns one more than the maximum id assigned to any edge. |
671 | int num_edge_ids() const { return edges_.size(); } |
672 | |
673 | // Returns the Edge associated with an id, or nullptr if no edge |
674 | // with that id (the edge with that id was removed and the id has |
675 | // not yet been re-used). *this owns the returned instance. |
676 | // REQUIRES: 0 <= id < num_edge_ids(). |
677 | const Edge* FindEdgeId(int id) const { return edges_[id]; } |
678 | |
679 | // Access to the set of all edges. Example usage: |
680 | // for (const Edge* e : graph.edges()) { ... } |
681 | GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); } |
682 | |
683 | // The pre-defined nodes. |
684 | enum { kSourceId = 0, kSinkId = 1 }; |
685 | Node* source_node() const { return FindNodeId(kSourceId); } |
686 | Node* sink_node() const { return FindNodeId(kSinkId); } |
687 | |
688 | const OpRegistryInterface* op_registry() const { return &ops_; } |
689 | const FunctionLibraryDefinition& flib_def() const { return ops_; } |
690 | |
691 | // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove? |
692 | FunctionLibraryDefinition* mutable_flib_def() { return &ops_; } |
693 | |
694 | void CheckDeviceNameIndex(int index) { |
695 | DCHECK_GE(index, 0); |
696 | DCHECK_LT(index, static_cast<int>(device_names_.size())); |
697 | } |
698 | |
699 | int InternDeviceName(const std::string& device_name); |
700 | |
701 | const std::string& get_assigned_device_name(const Node& node) const { |
702 | return device_names_[node.assigned_device_name_index()]; |
703 | } |
704 | |
705 | void set_assigned_device_name_index(Node* node, int device_name_index) { |
706 | CheckDeviceNameIndex(device_name_index); |
707 | node->assigned_device_name_index_ = device_name_index; |
708 | } |
709 | |
710 | void set_assigned_device_name(Node* node, const std::string& device_name) { |
711 | node->assigned_device_name_index_ = InternDeviceName(device_name); |
712 | } |
713 | |
714 | // Returns OK if `node` is non-null and belongs to this graph |
715 | Status IsValidNode(const Node* node) const; |
716 | |
717 | // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not |
718 | // accept control outputs. |
719 | Status IsValidOutputTensor(const Node* node, int idx) const; |
720 | |
721 | // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept |
722 | // control inputs. |
723 | Status IsValidInputTensor(const Node* node, int idx) const; |
724 | |
725 | // Create and return a new WhileContext owned by this graph. This is called |
726 | // when a new while loop is created. `frame_name` must be unique among |
727 | // WhileContexts in this graph. |
728 | Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes, |
729 | std::vector<Node*> exit_nodes, |
730 | OutputTensor cond_output, |
731 | std::vector<OutputTensor> body_inputs, |
732 | std::vector<OutputTensor> body_outputs, |
733 | WhileContext** result); |
734 | |
735 | // Builds a node name to node pointer index for all nodes in the graph. |
736 | std::unordered_map<string, Node*> BuildNodeNameIndex() const; |
737 | |
738 | absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const { |
739 | return const_arg_indices_cache_; |
740 | } |
741 | |
742 | // TODO(kkb): Add to the constructor when it becomes manageable. |
743 | // Sets the graph construction context. |
744 | void SetConstructionContext(ConstructionContext construction_context) { |
745 | construction_context_ = construction_context; |
746 | } |
747 | |
// TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
// making this stable and make it available widely.
// Returns the graph construction context. It's `kNotTracked` (the member's
// default initializer) if SetConstructionContext() was never called.
ConstructionContext GetConstructionContextInternal() const {
  return construction_context_;
}
754 | |
755 | // TODO(josh11b): uint64 hash() const; |
756 | |
757 | private: |
// Allocates a Node with the given properties and node class.
//
// If cost_node is non-null, then cost accounting (in CostModel)
// will be associated with that node rather than the new one being
// created.
//
// Ownership of the returned Node is not transferred to caller.
// NOTE(review): presumably the node is drawn from free_nodes_ or arena_ --
// confirm against the definition.
Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                   const Node* cost_node, Node::NodeClass node_class);
// NOTE(review): presumably the counterpart of AllocateNode() that makes `node`
// available for reuse (see free_nodes_) -- confirm against the definition.
void ReleaseNode(Node* node);
// Inserts `edge` into free_edges_ for possible reuse.
// NOTE(review): the parameter is a const Edge* while free_edges_ stores
// Edge*; presumably constness is dropped in the definition -- confirm.
void RecycleEdge(const Edge* edge);
// Registry of all known ops, including functions.
FunctionLibraryDefinition ops_;

// GraphDef versions. The pointer itself is const (it cannot be reseated);
// the VersionDef it points to is not.
const std::unique_ptr<VersionDef> versions_;

// Allocator which will give us good locality.
core::Arena arena_;

// Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
// the node with that id was removed from the graph.
std::vector<Node*> nodes_;

// Number of nodes alive.
int64_t num_nodes_ = 0;

// Map from edge ids to allocated edges.  edges_[id] may be nullptr if
// the edge with that id was removed from the graph.
std::vector<Edge*> edges_;

// The number of entries in edges_ that are not nullptr.
int num_edges_ = 0;

// Allocated but free nodes and edges.
std::vector<Node*> free_nodes_;
std::vector<Edge*> free_edges_;

// Counter used for generating unique names.
int name_counter_ = 0;
797 | |
// In most graphs, the number of unique values used for the
// Node::assigned_device_name() property is quite small.  If the graph is
// large, then storing a separate copy of the string per node can consume a
// significant amount of memory.  Instead, we represent the same information
// using an interning table, which consists of a vector of unique strings
// (device_names_), as well as a map (device_names_map_) from unique strings
// to indices within the unique string table.
//
// The InternDeviceName() method handles adding a new entry into the table,
// or locating the index of an existing entry.
//
// The fact that Node::assigned_device_name() is implemented using an
// interning table is intentionally public.  This allows algorithms that
// frequently access this field to do so efficiently, especially for the case
// where the assigned_device_name of one Node is copied directly from that
// of another Node.

// A table of the unique assigned device names.  Indices do NOT correspond
// to node IDs.  Index 0 is always the empty string.
std::vector<string> device_names_;

// Maps each unique device name to its index within device_names_.
std::unordered_map<string, int> device_names_map_;
821 | |
// All the while contexts owned by this graph, keyed by frame name,
// corresponding to all the while loops contained in this graph (including
// nested loops). The stored contexts are usually accessed via
// AddWhileContext() or Node::while_ctx(), but this map manages their
// lifetime.
std::map<string, WhileContext> while_ctxs_;

// Cache of the indices of the arguments which need to be constant for the XLA
// compilation. Declared `mutable` so that it can be updated through a const
// Graph (see GetConstArgIndicesCache()).
mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;

// Indicates the context in which this Graph instance was constructed.
// Defaults to kNotTracked until SetConstructionContext() is called.
ConstructionContext construction_context_ = ConstructionContext::kNotTracked;
834 | |
835 | TF_DISALLOW_COPY_AND_ASSIGN(Graph); |
836 | }; |
837 | |
838 | // TODO(josh11b): We may want to support keeping an index on various |
839 | // node/edge attributes in a graph, particularly node names. |
840 | |
841 | // Helper routines |
842 | |
// Free-function predicates that forward to the Node member function of the
// same name. The pointer is dereferenced unconditionally, so it must be
// non-null.
inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
856 | |
857 | // True for Nodes that mediate the transfer of values between processes. |
858 | inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); } |
859 | |
// Forwarders to the Node member predicates of the same name; `node` must be
// non-null.
inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
863 | |
// Returns true iff 'n' is a control flow node, as reported by
// Node::IsControlFlow(). 'n' must be non-null.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
866 | |
// Returns true if the node only depends on its input's metadata
// (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
// Forwards to Node::IsMetadata(); 'n' must be non-null.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
870 | |
// Forwards to Node::IsScopedAllocator(); 'n' must be non-null.
inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
872 | |
873 | inline bool IsHostMemoryPreserving(const Node* node) { |
874 | return IsIdentity(node) || IsControlFlow(node); |
875 | } |
876 | |
// Forwards to Node::IsDistributedCommunication(); 'n' must be non-null.
inline bool IsDistributedCommunication(const Node* n) {
  return n->IsDistributedCommunication();
}
880 | |
881 | // NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see |
882 | // https://en.cppreference.com/w/cpp/iterator/iterator). |
883 | |
884 | // Iterator for stepping through the nodes of a graph. |
885 | class NodeIter |
886 | : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t, |
887 | /*Pointer*/ Node*, /*Reference*/ Node*> { |
888 | public: |
889 | NodeIter(const Graph* graph, int id); |
890 | bool operator==(const NodeIter& rhs) const; |
891 | bool operator!=(const NodeIter& rhs) const; |
892 | void operator++(); |
893 | reference operator*() const; |
894 | pointer operator->() const; |
895 | |
896 | private: |
897 | // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr |
898 | const Graph* graph_; |
899 | int id_; |
900 | }; |
901 | |
902 | // Iterator for stepping through the neighbors of a node. |
903 | class NeighborIter |
904 | : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t, |
905 | /*Pointer*/ Node*, /*Reference*/ Node*> { |
906 | public: |
907 | NeighborIter(EdgeSet::const_iterator iter, bool incoming); |
908 | bool operator==(const NeighborIter& rhs) const; |
909 | bool operator!=(const NeighborIter& rhs) const; |
910 | void operator++(); |
911 | reference operator*() const; |
912 | pointer operator->() const; |
913 | |
914 | private: |
915 | EdgeSet::const_iterator iter_; |
916 | bool incoming_; |
917 | }; |
918 | |
919 | // IMPLEMENTATION DETAILS, PLEASE IGNORE |
920 | |
// Stores `graph` and `id` as given; no check is made here that `id` refers
// to an existing node.
inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}
923 | |
// Two iterators are equal when they are at the same node id. Only the ids
// are compared; the DCHECK flags comparisons between iterators of different
// graphs as a usage error.
inline bool NodeIter::operator==(const NodeIter& rhs) const {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}
928 | |
929 | inline bool NodeIter::operator!=(const NodeIter& rhs) const { |
930 | return !(*this == rhs); |
931 | } |
932 | |
933 | inline void NodeIter::operator++() { |
934 | while (1) { |
935 | DCHECK_LE(id_, graph_->num_node_ids()); |
936 | ++id_; |
937 | if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { |
938 | return; |
939 | } |
940 | } |
941 | } |
942 | |
// Returns the node stored under the current id, as looked up through
// Graph::FindNodeId().
inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }
944 | |
945 | inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); } |
946 | |
// `iter` is the starting position in an edge set; `incoming` selects which
// end of each edge is reported on dereference (true: source, false:
// destination).
inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}
949 | |
950 | inline bool NeighborIter::operator==(const NeighborIter& rhs) const { |
951 | return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; |
952 | } |
953 | |
954 | inline bool NeighborIter::operator!=(const NeighborIter& rhs) const { |
955 | return !(*this == rhs); |
956 | } |
957 | |
// Advances the underlying edge iterator by one edge.
inline void NeighborIter::operator++() { ++iter_; }
959 | |
960 | inline Node* NeighborIter::operator*() const { |
961 | const Edge* e = *iter_; |
962 | return incoming_ ? e->src() : e->dst(); |
963 | } |
964 | |
965 | inline Node* NeighborIter::operator->() const { |
966 | const Edge* e = *iter_; |
967 | return incoming_ ? e->src() : e->dst(); |
968 | } |
969 | |
// Returns true iff this edge is a control edge. Only the source slot is
// inspected.
inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}
975 | |
976 | inline gtl::iterator_range<NodeIter> Graph::nodes() const { |
977 | // Note that NodeId 0 is always valid since we don't let the source |
978 | // node be removed from the graph. |
979 | return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); |
980 | } |
981 | |
982 | inline gtl::iterator_range<NodeIter> Graph::op_nodes() const { |
983 | // Note that NodeId 0 is always valid since we don't let the source |
984 | // node be removed from the graph. |
985 | // |
986 | // The current implementation of Graph maintains the invariant that the |
987 | // first two nodes are the source and sink nodes, and all other nodes are op |
988 | // nodes. This method (op_nodes()) relies on this invariant. |
989 | NodeIter begin(this, 0); |
990 | NodeIter end(this, num_node_ids()); |
991 | if (begin != end) { |
992 | ++begin; |
993 | } |
994 | if (begin != end) { |
995 | ++begin; |
996 | } |
997 | return gtl::make_range(begin, end); |
998 | } |
999 | |
1000 | inline void Node::set_assigned_device_name_index(int index) { |
1001 | graph_->CheckDeviceNameIndex(index); |
1002 | assigned_device_name_index_ = index; |
1003 | } |
1004 | |
// Sets this node's assigned device by name, delegating to the owning graph's
// set_assigned_device_name(Node*, ...) overload.
inline void Node::set_assigned_device_name(const std::string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}
1008 | |
// Returns this node's assigned device name, as looked up by the owning graph.
// NOTE(review): the result is a reference obtained from the graph;
// presumably it refers to an entry of the graph's interned device-name table
// -- confirm its lifetime against Graph::get_assigned_device_name().
inline const std::string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}
1012 | |
1013 | } // namespace tensorflow |
1014 | |
1015 | #endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_ |
1016 | |