1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
17
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
27
28namespace tensorflow {
29namespace grappler {
30class GraphProperties;
31}
32
33// This class stores extra inference information in addition to
34// InferenceContext, such as node input and output types.
35class ExtendedInferenceContext {
36 public:
37 ExtendedInferenceContext(
38 std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node)
39 : inference_context_(std::move(ic)), op_(node->name()) {
40 input_types_.reserve(node->num_inputs());
41 for (int i = 0; i < node->num_inputs(); i++) {
42 input_types_.push_back(node->input_type(i));
43 }
44 output_types_.reserve(node->num_outputs());
45 for (int i = 0; i < node->num_outputs(); i++) {
46 output_types_.push_back(node->output_type(i));
47 }
48 }
49
50 DataType input_type(int64_t idx) const { return input_types_[idx]; }
51 DataType output_type(int64_t idx) const { return output_types_[idx]; }
52
53 shape_inference::InferenceContext* get_context() {
54 return inference_context_.get();
55 }
56
57 std::string op() const { return op_; }
58
59 private:
60 std::unique_ptr<shape_inference::InferenceContext> inference_context_;
61 std::string op_;
62 std::vector<DataType> input_types_;
63 std::vector<DataType> output_types_;
64
65 TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext);
66};
67
68// ShapeRefiner performs shape inference for TensorFlow Graphs. It is
69// responsible for instantiating InferenceContext objects for each
70// Node in the Graph, and providing/storing the 'input_tensor' Tensors
71// used by Shape Inference functions, when available at graph
72// construction time.
73class ShapeRefiner {
74 public:
75 ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
76
77 // Same as ShapeRefiner(versions.producer(), ops)
78 ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
79
80 ~ShapeRefiner();
81
82 // Performs validation of 'node' and runs 'node's shape function,
83 // storing its shape outputs.
84 //
85 // All inputs of 'node' must be added to ShapeRefiner prior to
86 // adding 'node'.
87 //
88 // Returns an error if:
89 // - the shape function for 'node' was not registered.
90 // - 'node' was added before its inputs.
91 // - The shape inference function returns an error.
92 Status AddNode(const Node* node);
93
94 // Sets 'node's 'output_port' output to have shape 'shape'.
95 //
96 // Returns an error if 'node' was not previously added to this
97 // object, if 'output_port' is invalid, or if 'shape' is
98 // not compatible with the existing shape of the output.
99 Status SetShape(const Node* node, int output_port,
100 shape_inference::ShapeHandle shape);
101
102 // Update the input shapes of node in case the shapes of the fan-ins of 'node'
103 // have themselves been modified (For example, in case of incremental shape
104 // refinement). If 'relax' is true, a new shape with the broadest set of
105 // information will be set as the new input (see InferenceContext::RelaxInput
106 // for full details and examples). Sets refined to true if any shapes have
107 // changed (in their string representations). Note that shapes may have been
108 // updated to newer versions (but with identical string representations) even
109 // if <*refined> is set to false.
110 Status UpdateNode(const Node* node, bool relax, bool* refined);
111
112 // Returns the InferenceContext for 'node', if present.
113 shape_inference::InferenceContext* GetContext(const Node* node) const {
114 auto it = node_to_context_.find(node);
115 if (it == node_to_context_.end()) {
116 return nullptr;
117 }
118 return it->second->get_context();
119 }
120
121 // Returns the ExtendedInferenceContext for 'node', if present.
122 ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
123 auto it = node_to_context_.find(node);
124 if (it == node_to_context_.end()) {
125 return nullptr;
126 }
127 return it->second.get();
128 }
129
130 // Getters and setters for graph_def_version_.
131 int32 graph_def_version() const { return graph_def_version_; }
132 void set_graph_def_version(int32_t version) { graph_def_version_ = version; }
133
134 void set_require_shape_inference_fns(bool require_shape_inference_fns) {
135 require_shape_inference_fns_ = require_shape_inference_fns;
136 }
137 void set_disable_constant_propagation(bool disable) {
138 disable_constant_propagation_ = disable;
139 }
140
141 // Set function library to enable function shape inference.
142 // Without function library, function inference always yields unknown shapes.
143 // With this enabled, shape inference can take more time since it descends
144 // into all function calls. It doesn't do inference once for each function
145 // definition, but once for each function call.
146 // The function library must outlive the shape refiner.
147 void set_function_library_for_shape_inference(
148 const tensorflow::FunctionLibraryDefinition* lib) {
149 function_library_ = lib;
150 }
151
152 bool function_shape_inference_supported() const {
153 return function_library_ != nullptr;
154 }
155
156 private:
157 friend class ShapeRefinerTest;
158 friend class ::tensorflow::grappler::GraphProperties;
159
160 // Returns true if the ranks and all dimensions of <s0> and <s1> are either
161 // equal in value or both unknown.
162 static bool SameDefinedShape(shape_inference::InferenceContext* c,
163 shape_inference::ShapeHandle s0,
164 shape_inference::ShapeHandle s1);
165
166 // Returns true if the shapes and types stored in <*existing> are identical in
167 // value to the shapes and types in <*updated>.
168 static bool IsUpdatedShapesOrTypes(
169 shape_inference::InferenceContext* c,
170 const std::vector<shape_inference::ShapeAndType>& existing,
171 const std::vector<shape_inference::ShapeAndType>& updated);
172
173 // Performs shape inference for the given function_def within the
174 // given outer_context. Internally it instantiates the function as a graph
175 // and runs shape inference recursively on it with the input shapes provided
176 // by the outer_context.
177 //
178 // Returns an error if:
179 // - number of inputs/outputs on outer_context doesn't match the function_def
180 //
181 // On success:
182 // - outer_context will contain output shapes inferred from input shapes
183 Status InferShapesForFunction(
184 const FunctionDef* function_def, AttrSlice attributes,
185 shape_inference::InferenceContext* outer_context);
186
187 // Performs shape inference for a node inside a function.
188 //
189 // 'outer_context' is the 'InferenceContext' for the function's call op.
190 Status InferShapesForFunctionSubNode(
191 const Node* node, shape_inference::InferenceContext* outer_context);
192
193 // Performs validation of 'node' and runs 'node's shape function,
194 // storing its shape outputs.
195 //
196 // All inputs of 'node' must be added to ShapeRefiner prior to
197 // adding 'node'.
198 //
199 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
200 // the call op of the function can be passed as 'outer_context' (pass nullptr
201 // otherwise). This gets used to perform constant propagation across Arg nodes
202 // by requesting the constant of value of the incoming tensor from the
203 // 'outer_context'.
204 //
205 // Returns an error if:
206 // - the shape function for 'node' was not registered.
207 // - 'node' was added before its inputs.
208 // - The shape inference function returns an error.
209 Status AddNodeInternal(const Node* node,
210 shape_inference::InferenceContext* outer_context);
211
212 // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
213 // value can be evaluated, 'evaluated' is set to true and the value returned
214 // in 'result'. Otherwise 'evaluated' is set to false.
215 //
216 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
217 // the call op of the function can be passed as 'outer_context' (pass nullptr
218 // otherwise). This gets used to perform constant propagation across Arg nodes
219 // by requesting the constant of value of the incoming tensor from the
220 // 'outer_context'.
221 Status EvaluateConstantTensorForEdge(
222 const Node* node, int dst_idx, bool* evaluated, Tensor* result,
223 shape_inference::InferenceContext* outer_context);
224
225 // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
226 // tensors. The caller is responsible for checking that the specified edge is
227 // scalar and int32 or int64.
228 //
229 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
230 // the call op of the function can be passed as 'outer_context' (pass nullptr
231 // otherwise). This gets used to perform constant propagation across Arg nodes
232 // by requesting the constant of value of the incoming tensor from the
233 // 'outer_context'.
234 Status EvaluateConstantIntScalarEdge(
235 const Node* node, int dst_idx, bool* evaluated, int64_t* result,
236 shape_inference::InferenceContext* outer_context);
237
238 // This function tries to materialize as much information about the 'node''s
239 // dst_idx input as a statically computable shape, and the result may be
240 // partially known, depending on what is statically inferable.
241 //
242 // This is called when node.input[dst_idx] is a tensor that is used to define
243 // the shape of some other tensor (e.g., the second argument to Reshape is a
244 // <shape> tensor, where each element of the shape tensor is a dimension of
245 // the target tensor). It returns in <result> a shape for that input.
246 //
247 // Unlike simply resolving node.input[dst_idx] to a constant and then
248 // converting that to a shape, this function can return a partial shape. This
249 // is useful for cases where the shape tensor is only partially defined, such
250 // as with calls for: reshape(x, shape(y)) where shape(y) is partially
251 // defined.
252 //
253 // The implementation has op implementations for ops commonly called on shape
254 // tensors, and the implementations are specialized to shape tensors (namely,
255 // the output is a vector).
256 //
257 // <target_context> is used when creating new DimensionHandle and ShapeHandle
258 // objects.
259 //
260 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
261 // the call op of the function can be passed as 'outer_context' (pass nullptr
262 // otherwise). This gets used to perform constant propagation across Arg nodes
263 // by requesting the constant of value of the incoming tensor from the
264 // 'outer_context'.
265 Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
266 const Node* node, int dst_idx,
267 shape_inference::ShapeHandle* result,
268 shape_inference::InferenceContext* outer_context);
269
270 // Implementation of ConstantPartialShape for StridedSlice nodes.
271 //
272 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
273 // the call op of the function can be passed as 'outer_context' (pass nullptr
274 // otherwise). This gets used to perform constant propagation across Arg nodes
275 // by requesting the constant of value of the incoming tensor from the
276 // 'outer_context'.
277 Status PartialStridedSliceShape(
278 Node* slice_node, shape_inference::InferenceContext* ctx,
279 shape_inference::ShapeHandle* result,
280 shape_inference::InferenceContext* outer_context);
281
282 // Runs the shape function registered for the node's op type.
283 //
284 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
285 // the call op of the function can be passed as 'outer_context' (pass nullptr
286 // otherwise). This gets used to perform constant propagation across Arg nodes
287 // by requesting the constant of value of the incoming tensor from the
288 // 'outer_context'.
289 Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
290 ExtendedInferenceContext* ec,
291 shape_inference::InferenceContext* outer_context = nullptr);
292
293 int32 graph_def_version_;
294 const OpRegistryInterface* const ops_registry_;
295
296 // The lifetime of the tensors are bound to the runner, so it should be the
297 // deleted after the tensors.
298 GraphRunner graph_runner_;
299
300 // Stores a map from a node to its ExtendedInferenceContext.
301 absl::flat_hash_map<const Node*, std::unique_ptr<ExtendedInferenceContext>,
302 hash<const Node*>>
303 node_to_context_;
304
305 // Holds a cache from 'tensor name' to the tensor that is
306 // evaluatable as a constant expression. This reduces repeated
307 // execution of the entire constant subgraph as a graph is being
308 // built up. This could be changed to some kind of size-based LRU
309 // cache to avoid consuming too much memory, if that eventually
310 // becomes a concern.
311 //
312 // Only tensors less than 1KiB are currently stored in the cache.
313 static constexpr int64_t kMaxTensorSize = 1024;
314 std::unordered_map<string, Tensor> const_tensor_map_;
315
316 bool require_shape_inference_fns_ = true;
317 bool disable_constant_propagation_ = false;
318
319 // Function library is optional, but has to be set to enable function
320 // shape inference.
321 const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;
322
323 // Cache the graph corresponding to each function definition for which shapes
324 // are refined.
325 absl::flat_hash_map<const FunctionDef*, std::unique_ptr<const Graph>,
326 hash<const FunctionDef*>>
327 functions_;
328
329 TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
330};
331
332} // namespace tensorflow
333
334#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
335