1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ |
18 | |
19 | #include <functional> |
20 | #include <memory> |
21 | |
22 | #include "tensorflow/core/framework/function.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | class AttrSlice; |
28 | class Graph; |
29 | class GraphDef; |
30 | class NameAttrList; |
31 | class Node; |
32 | class NodeDef; |
33 | class OpDef; |
34 | |
35 | // Debugging facility. Returns a debug string for the graph `g`, which |
36 | // represents an instantiated function. |
37 | string DebugString(const Graph* g); |
38 | |
39 | // Dumps the contents of graph `g` to log files if the logging level is |
40 | // sufficiently high. NOTE(review): `label` presumably tags the dump; confirm. |
41 | void DumpGraph(StringPiece label, const Graph* g); |
42 | |
43 | // Converts the Graph `g` of a function to a GraphDef, written to `gdef`. |
44 | // Handles renaming of nodes to avoid duplicate names, which may be present |
45 | // after various rewriting operations. |
46 | // NOTE(review): `pretty` (default false) presumably selects more readable output; confirm. |
47 | void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); |
48 | |
49 | // Extracts the function name and attributes from `call_def` into `function`. |
50 | // `call_def` can be a native function call (where the op type is the function |
51 | // name) or a call through PartitionedCall/StatefulPartitionedCall. |
52 | Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, |
53 | NameAttrList* function); |
54 | |
55 | // A few hand-crafted optimizations on the instantiated function body |
56 | // (a Graph*). |
57 | |
58 | // Removes nodes that are |
59 | // 1. not stateful; and |
60 | // 2. not _Arg; and |
61 | // 3. not reachable from any _Retval node. |
62 | // |
63 | // This function is used by function inlining. Unlike `PruneFunctionBody`, |
64 | // it doesn't preserve nodes that are reachable from control returns. Function |
65 | // inlining is responsible for connecting control return nodes with the nodes |
66 | // that have input control edges from the inlined function call node. |
67 | // |
68 | // Assuming that automatic control dependency tracking is correct, the absence |
69 | // of an outgoing control edge from the function call node means that no one |
70 | // needs to observe side effects that might have been generated by the function |
71 | // (see documentation in common_runtime/function.cc for details). |
72 | // |
73 | // Returns true iff any node is removed from `g`. |
74 | bool RemoveDeadNodes(Graph* g); |
75 | |
76 | // Finds the pattern: |
77 | // src -(in)-> node -(out)-> dst, where |
78 | // 1) node is an identity node; |
79 | // 2) in is the only incoming data edge; |
80 | // 3) out is the only outgoing data edge. |
81 | // |
82 | // and rewrites it as src -> dst with the relevant data dependencies updated. |
83 | // The process repeats until no such pattern is left. NOTE(review): the bool |
84 | // result presumably reports whether `g` was modified; confirm. |
85 | bool RemoveIdentityNodes(Graph* g); |
86 | |
87 | // Rewrites _ListToArray and _ArrayToList nodes as sets of Identity nodes. NOTE(review): the bool result presumably reports whether `g` was modified; confirm. |
88 | bool RemoveListArrayConverter(Graph* g); |
89 | |
90 | // Extracts the function name and attributes from `call_def` and invokes |
91 | // flr->Instantiate(name, attrs, handle) to obtain `*handle`. |
92 | // `call_def` can be a native function call (where the op type is the function |
93 | // name) or a call through PartitionedCall/StatefulPartitionedCall. |
94 | Status InstantiateFunctionCall(const NodeDef& call_def, |
95 | FunctionLibraryRuntime* flr, |
96 | FunctionLibraryRuntime::Handle* handle); |
97 | |
98 | // Returns true iff `n` represents a function call. `n` can be a native |
99 | // function call (n.type_string() is the function name), |
100 | // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which |
101 | // has been deprecated for a while). NOTE(review): `lib_def` presumably resolves native function names; confirm. |
102 | bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); |
103 | } // end namespace tensorflow |
104 | |
105 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ |
106 | |