1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ |
17 | |
18 | #include <vector> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "tensorflow/core/common_runtime/graph_runner.h" |
22 | #include "tensorflow/core/framework/function.pb.h" |
23 | #include "tensorflow/core/framework/shape_inference.h" |
24 | #include "tensorflow/core/graph/graph.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/platform/macros.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace grappler { |
30 | class GraphProperties; |
31 | } |
32 | |
33 | // This class stores extra inference information in addition to |
34 | // InferenceContext, such as node input and output types. |
35 | class ExtendedInferenceContext { |
36 | public: |
37 | ExtendedInferenceContext( |
38 | std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node) |
39 | : inference_context_(std::move(ic)), op_(node->name()) { |
40 | input_types_.reserve(node->num_inputs()); |
41 | for (int i = 0; i < node->num_inputs(); i++) { |
42 | input_types_.push_back(node->input_type(i)); |
43 | } |
44 | output_types_.reserve(node->num_outputs()); |
45 | for (int i = 0; i < node->num_outputs(); i++) { |
46 | output_types_.push_back(node->output_type(i)); |
47 | } |
48 | } |
49 | |
50 | DataType input_type(int64_t idx) const { return input_types_[idx]; } |
51 | DataType output_type(int64_t idx) const { return output_types_[idx]; } |
52 | |
53 | shape_inference::InferenceContext* get_context() { |
54 | return inference_context_.get(); |
55 | } |
56 | |
57 | std::string op() const { return op_; } |
58 | |
59 | private: |
60 | std::unique_ptr<shape_inference::InferenceContext> inference_context_; |
61 | std::string op_; |
62 | std::vector<DataType> input_types_; |
63 | std::vector<DataType> output_types_; |
64 | |
65 | TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext); |
66 | }; |
67 | |
// ShapeRefiner performs shape inference for TensorFlow Graphs. It is
// responsible for instantiating InferenceContext objects for each
// Node in the Graph, and providing/storing the 'input_tensor' Tensors
// used by Shape Inference functions, when available at graph
// construction time.
class ShapeRefiner {
 public:
  ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);

  // Same as ShapeRefiner(versions.producer(), ops)
  ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);

  ~ShapeRefiner();

  // Performs validation of 'node' and runs 'node's shape function,
  // storing its shape outputs.
  //
  // All inputs of 'node' must be added to ShapeRefiner prior to
  // adding 'node'.
  //
  // Returns an error if:
  //  - the shape function for 'node' was not registered.
  //  - 'node' was added before its inputs.
  //  - The shape inference function returns an error.
  Status AddNode(const Node* node);

  // Sets 'node's 'output_port' output to have shape 'shape'.
  //
  // Returns an error if 'node' was not previously added to this
  // object, if 'output_port' is invalid, or if 'shape' is
  // not compatible with the existing shape of the output.
  Status SetShape(const Node* node, int output_port,
                  shape_inference::ShapeHandle shape);

  // Update the input shapes of node in case the shapes of the fan-ins of
  // 'node' have themselves been modified (for example, in case of
  // incremental shape refinement). If 'relax' is true, a new shape with the
  // broadest set of information will be set as the new input (see
  // InferenceContext::RelaxInput for full details and examples). Sets
  // <*refined> to true if any shapes have changed (in their string
  // representations). Note that shapes may have been updated to newer
  // versions (but with identical string representations) even if <*refined>
  // is set to false.
  Status UpdateNode(const Node* node, bool relax, bool* refined);

  // Returns the InferenceContext for 'node', or nullptr if 'node' was
  // never added. The pointer is owned by this ShapeRefiner.
  shape_inference::InferenceContext* GetContext(const Node* node) const {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second->get_context();
  }

  // Returns the ExtendedInferenceContext for 'node', or nullptr if 'node'
  // was never added. The pointer is owned by this ShapeRefiner.
  ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second.get();
  }

  // Getters and setters for graph_def_version_.
  int32 graph_def_version() const { return graph_def_version_; }
  void set_graph_def_version(int32_t version) { graph_def_version_ = version; }

  // If set (the default), AddNode fails for ops with no registered shape
  // inference function; otherwise such ops get unknown output shapes.
  void set_require_shape_inference_fns(bool require_shape_inference_fns) {
    require_shape_inference_fns_ = require_shape_inference_fns;
  }
  // Disables evaluation of constant subgraphs when requesting input
  // tensors during shape inference.
  void set_disable_constant_propagation(bool disable) {
    disable_constant_propagation_ = disable;
  }

  // Set function library to enable function shape inference.
  // Without function library, function inference always yields unknown
  // shapes. With this enabled, shape inference can take more time since it
  // descends into all function calls. It doesn't do inference once for each
  // function definition, but once for each function call.
  // The function library must outlive the shape refiner.
  void set_function_library_for_shape_inference(
      const tensorflow::FunctionLibraryDefinition* lib) {
    function_library_ = lib;
  }

  bool function_shape_inference_supported() const {
    return function_library_ != nullptr;
  }

 private:
  friend class ShapeRefinerTest;
  friend class ::tensorflow::grappler::GraphProperties;

  // Returns true if the ranks and all dimensions of <s0> and <s1> are either
  // equal in value or both unknown.
  static bool SameDefinedShape(shape_inference::InferenceContext* c,
                               shape_inference::ShapeHandle s0,
                               shape_inference::ShapeHandle s1);

  // Returns true if the shapes and types stored in <*existing> are identical
  // in value to the shapes and types in <*updated>.
  static bool IsUpdatedShapesOrTypes(
      shape_inference::InferenceContext* c,
      const std::vector<shape_inference::ShapeAndType>& existing,
      const std::vector<shape_inference::ShapeAndType>& updated);

  // Performs shape inference for the given function_def within the
  // given outer_context. Internally it instantiates the function as a graph
  // and runs shape inference recursively on it with the input shapes provided
  // by the outer_context.
  //
  // Returns an error if:
  // - number of inputs/outputs on outer_context doesn't match the
  //   function_def
  //
  // On success:
  // - outer_context will contain output shapes inferred from input shapes
  Status InferShapesForFunction(
      const FunctionDef* function_def, AttrSlice attributes,
      shape_inference::InferenceContext* outer_context);

  // Performs shape inference for a node inside a function.
  //
  // 'outer_context' is the 'InferenceContext' for the function's call op.
  Status InferShapesForFunctionSubNode(
      const Node* node, shape_inference::InferenceContext* outer_context);

  // Performs validation of 'node' and runs 'node's shape function,
  // storing its shape outputs.
  //
  // All inputs of 'node' must be added to ShapeRefiner prior to
  // adding 'node'.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  //
  // Returns an error if:
  //  - the shape function for 'node' was not registered.
  //  - 'node' was added before its inputs.
  //  - The shape inference function returns an error.
  Status AddNodeInternal(const Node* node,
                         shape_inference::InferenceContext* outer_context);

  // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
  // value can be evaluated, 'evaluated' is set to true and the value
  // returned in 'result'. Otherwise 'evaluated' is set to false.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  Status EvaluateConstantTensorForEdge(
      const Node* node, int dst_idx, bool* evaluated, Tensor* result,
      shape_inference::InferenceContext* outer_context);

  // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64
  // input tensors. The caller is responsible for checking that the specified
  // edge is scalar and int32 or int64.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  Status EvaluateConstantIntScalarEdge(
      const Node* node, int dst_idx, bool* evaluated, int64_t* result,
      shape_inference::InferenceContext* outer_context);

  // This function tries to materialize as much information about the
  // 'node''s dst_idx input as a statically computable shape, and the result
  // may be partially known, depending on what is statically inferable.
  //
  // This is called when node.input[dst_idx] is a tensor that is used to
  // define the shape of some other tensor (e.g., the second argument to
  // Reshape is a <shape> tensor, where each element of the shape tensor is a
  // dimension of the target tensor). It returns in <result> a shape for that
  // input.
  //
  // Unlike simply resolving node.input[dst_idx] to a constant and then
  // converting that to a shape, this function can return a partial shape.
  // This is useful for cases where the shape tensor is only partially
  // defined, such as with calls for: reshape(x, shape(y)) where shape(y) is
  // partially defined.
  //
  // The implementation has op implementations for ops commonly called on
  // shape tensors, and the implementations are specialized to shape tensors
  // (namely, the output is a vector).
  //
  // <target_context> is used when creating new DimensionHandle and
  // ShapeHandle objects.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
                              const Node* node, int dst_idx,
                              shape_inference::ShapeHandle* result,
                              shape_inference::InferenceContext* outer_context);

  // Implementation of ConstantPartialShape for StridedSlice nodes.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  Status PartialStridedSliceShape(
      Node* slice_node, shape_inference::InferenceContext* ctx,
      shape_inference::ShapeHandle* result,
      shape_inference::InferenceContext* outer_context);

  // Runs the shape function registered for the node's op type.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext'
  // for the call op of the function can be passed as 'outer_context' (pass
  // nullptr otherwise). This gets used to perform constant propagation
  // across Arg nodes by requesting the constant value of the incoming tensor
  // from the 'outer_context'.
  Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
                    ExtendedInferenceContext* ec,
                    shape_inference::InferenceContext* outer_context = nullptr);

  // GraphDef producer version used for shape inference; mutable via
  // set_graph_def_version().
  int32 graph_def_version_;
  // Not owned; used to look up op registrations. Must outlive this object.
  const OpRegistryInterface* const ops_registry_;

  // The lifetime of the tensors is bound to the runner, so the runner
  // should be deleted after the tensors (declaration order here ensures
  // graph_runner_ is destroyed after const_tensor_map_ below).
  GraphRunner graph_runner_;

  // Stores a map from a node to its ExtendedInferenceContext. Keys are
  // not owned; values are owned by this map.
  absl::flat_hash_map<const Node*, std::unique_ptr<ExtendedInferenceContext>,
                      hash<const Node*>>
      node_to_context_;

  // Holds a cache from 'tensor name' to the tensor that is
  // evaluatable as a constant expression. This reduces repeated
  // execution of the entire constant subgraph as a graph is being
  // built up. This could be changed to some kind of size-based LRU
  // cache to avoid consuming too much memory, if that eventually
  // becomes a concern.
  //
  // Only tensors less than 1KiB are currently stored in the cache.
  static constexpr int64_t kMaxTensorSize = 1024;
  std::unordered_map<string, Tensor> const_tensor_map_;

  bool require_shape_inference_fns_ = true;
  bool disable_constant_propagation_ = false;

  // Function library is optional, but has to be set to enable function
  // shape inference. Not owned; must outlive this object.
  const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;

  // Cache the graph corresponding to each function definition for which
  // shapes are refined.
  absl::flat_hash_map<const FunctionDef*, std::unique_ptr<const Graph>,
                      hash<const FunctionDef*>>
      functions_;

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
};
331 | |
332 | } // namespace tensorflow |
333 | |
334 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ |
335 | |