1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ |
18 | |
19 | // Operations calling functions are becoming ubiquitous in TF 2.0. |
20 | // Examples include PartitionedCallOp, functional If/While, and Dataset ops. |
21 | // Such operations might require deep inspection - looking at the body of the |
22 | // called function - to place them and surrounding ops correctly. |
23 | |
// This file contains utilities that help Placer place such ops correctly,
// including:
//  - PlacerInspectionRequiredOpChecker: a simple class with a single
//    IsPlacerInspectionRequired method.
//  - IsolatePlacerInspectionRequiredOps: a function that adds Identity ops
//    for each input/output of ops requiring placer inspection. This greatly
//    simplifies the implementation of placing such ops.
//  - GetFunctionDefAndAttrs: looks up the function referenced by the "f"
//    attribute of a node.
//  - FunctionStack: the "call" stack of functions, used for error messages
//    and for detecting recursion.
31 | |
32 | #include <vector> |
33 | |
34 | #include "absl/types/optional.h" |
35 | #include "tensorflow/core/framework/function.h" |
36 | #include "tensorflow/core/graph/graph.h" |
37 | #include "tensorflow/core/lib/core/status.h" |
38 | |
39 | namespace tensorflow { |
40 | |
41 | // PlacerInspectionRequiredOpChecker allows one to check if Placer needs to |
42 | // look deeply into the op to place ops consuming the outputs correctly. |
43 | // |
44 | // It is a class instead of a standalone method because checking whether |
45 | // a function returns a resource takes non-trivial time and we cache the |
46 | // results. |
47 | class PlacerInspectionRequiredOpChecker { |
48 | public: |
49 | // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`. |
50 | // The functions referenced by nodes in `graph` will be looked up in |
51 | // `flib_def` |
52 | PlacerInspectionRequiredOpChecker(const Graph* graph, |
53 | const FunctionLibraryDefinition* flib_def); |
54 | |
55 | // If `node` is considered a deep op, sets `*is_deep` to true and returns |
56 | // OkStatus(). If an error occurs, returns that error, and the value of |
57 | // `*is_deep` is undefined. |
58 | // Currently, an op is considered deep, if it is a calling a function |
59 | // returning a resource. This definition is driven by Placer's need to |
60 | // look inside the op. |
61 | // REQUIRES: `node` is part of `graph` passed into constructor. |
62 | Status IsPlacerInspectionRequired(const Node& node, bool* is_deep); |
63 | |
64 | private: |
65 | const Graph& graph_; |
66 | const FunctionLibraryDefinition& flib_def_; |
67 | // Indexed by the node id. |
68 | // If cache_[node_id] is empty, the deepness of the node with id `node_id` has |
69 | // not been computed yet. Else, it contains the value already computed. |
70 | std::vector<absl::optional<bool>> cache_; |
71 | }; |
72 | |
73 | // Extracts `fdef` and `func` from `flib_def` for the function identified |
74 | // in "f" attribute of `node`. |
75 | Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def, |
76 | const Node& node, const FunctionDef** fdef, |
77 | NameAttrList* func); |
78 | |
79 | // The "call" stack of functions. |
80 | // Useful for better error messages as well as for detecting recursion. |
81 | // Stores references to graph nodes. These references must outlive this. |
82 | class FunctionStack { |
83 | public: |
84 | explicit FunctionStack(const string& function_name); |
85 | |
86 | // `node_in_current_function` must outlive this. |
87 | FunctionStack Push(const Node* node_in_current_function, |
88 | const string& new_current_function) const; |
89 | |
90 | // Returns true iff this stack already includes `function_name`. |
91 | bool HasFunction(const string& function_name) const; |
92 | |
93 | const string& current_function_name() const { return current_function_name_; } |
94 | |
95 | // Format's this suitable for error interpolation that retrieves |
96 | // Python files and line numbers. |
97 | string FormatForError() const; |
98 | |
99 | private: |
100 | struct Frame { |
101 | Frame(const string& function, const Node* node) |
102 | : function_name(function), node(node) {} |
103 | |
104 | string function_name; |
105 | const Node* node; |
106 | }; |
107 | |
108 | // The function at the top of the stack. In other words, the function |
109 | // that is currently being inspected for placement. |
110 | string current_function_name_; |
111 | |
112 | // The stack of frames that got the placement to the current_function_name_. |
113 | // frames_[0].function_name is the top function that Placer was constructed |
114 | // with. frames_[0].function_name can be empty if placer was constructed with |
115 | // a nameless graph, not a function. frames_[0].node_name is a name of a node |
116 | // in frames_[0].function_name that required deep inspection (e.g. a |
117 | // PartitionedCallOp). The function that this node invoked is |
118 | // frames_[1].function_name, if frames_.size() > 1. Else, the function that |
119 | // this node invoked is current_function_name_. |
120 | std::vector<Frame> frames_; |
121 | }; |
122 | |
123 | // Adds Identities for each input and output of function-calling ops in `graph` |
124 | // |
125 | // For example, the following graph calling a function on inputs `a` and `b` |
126 | // and producing output `y` will be rewritten to include identities on all |
127 | // edges: |
128 | // |
129 | // a b |
130 | // | | |
131 | // v v |
132 | // f (PartitionedCallOp) |
133 | // | |
134 | // v |
135 | // y |
136 | // |
137 | // is transformed to |
138 | // |
139 | // a b |
140 | // | | |
141 | // a_f (Identity) b_f (Identity) |
142 | // | | |
143 | // v v |
144 | // f (PartitionedCallOp) |
145 | // | |
146 | // f_y (Identity) |
147 | // | |
148 | // v |
149 | // y |
150 | // |
151 | Status IsolatePlacerInspectionRequiredOps( |
152 | const FunctionLibraryDefinition& flib_def, Graph* graph); |
153 | |
154 | } // namespace tensorflow |
155 | |
156 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ |
157 | |