/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>
#include <limits>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace {

using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

// Returns a Tensor containing the underlying constant value of a Node if the
// node contains a constant value.
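// For example, for a node produced by tf.constant([1, 2, 3]), this parses the
// node's "value" attr into `output` and sets *success to true; for any
// non-Const node, *success is left false.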
Status EvaluateConstantNode(const Node& node, Tensor* output, bool* success) {
  *success = false;
  if (node.IsConstant()) {
    if (output->FromProto(node.def().attr().at("value").tensor())) {
      *success = true;
    }
  }
  return OkStatus();
}

// Returns the integer value of the constant tensor feeding `node`'s input
// edge at index `input_idx`, if that input is a scalar integer constant.
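// For example, if input 1 of `node` is fed by tf.constant(5), this sets
// *output to 5; non-scalar or non-integer constants leave *success false.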
Status EvaluateConstantIntFromScalarEdge(const Node& node, int input_idx,
                                         int64_t* output, bool* success) {
  *success = false;
  Tensor scalar;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(input_idx, &edge));
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*edge->src(), &scalar, success));
  if (*success && scalar.NumElements() == 1) {
    if (scalar.dtype() == DT_INT32) {
      *output = scalar.scalar<int32>()();
    } else if (scalar.dtype() == DT_INT64) {
      *output = scalar.scalar<int64_t>()();
    } else {
      *success = false;
    }
  }
  return OkStatus();
}

// Tries to infer the tensor output based on the input dims of a Shape node.
// [allow_partial = false]
//   Can infer the Shape op's output tensor only when the input shape to the
//   Shape op is fully defined.
// [allow_partial = true]
//   Can infer the Shape op's output tensor as long as the rank of the input
//   shape to the Shape op is known. Uses kUnknownDim for unknown dims.
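// For example, given an input with shape (5, ?, 3):
//   allow_partial = false: *success is left false (shape not fully defined).
//   allow_partial = true:  the output tensor is [5, kUnknownDim, 3], i.e.
//                          [5, -1, 3].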
Status TryToInferTensorOutputFromShapeNode(const Node& shape_node,
                                           InferenceContext* shape_c,
                                           Tensor* output, bool* success,
                                           bool allow_partial = false) {
  *success = false;
  if (shape_node.type_string() != "Shape") return OkStatus();
  if (shape_c == nullptr) return OkStatus();
  if (!shape_c->FullyDefined(shape_c->input(0)) && !allow_partial)
    return OkStatus();
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  int src_rank = shape_c->Rank(shape_c->input(0));
  Tensor t(shape_node.output_type(0), TensorShape({src_rank}));
  if (shape_node.output_type(0) == DT_INT32) {
    auto flat = t.flat<int>();
    for (int i = 0; i < src_rank; i++) {
      int64_t dimension;
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        dimension = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
        if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Shape has output type int32, but dimension exceeds maximum "
              "int32 value");
        }
      } else {
        dimension = shape_c->kUnknownDim;
      }
      flat(i) = static_cast<int32>(dimension);
    }
  } else if (shape_node.output_type(0) == DT_INT64) {
    auto flat = t.flat<int64_t>();
    for (int i = 0; i < src_rank; i++) {
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        flat(i) = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
      } else {
        flat(i) = shape_c->kUnknownDim;
      }
    }
  } else {
    return errors::FailedPrecondition(
        "Shape has output type that is not int32 or int64");
  }
  *output = t;
  *success = true;
  return OkStatus();
}

// Tries to infer the tensor output of a StridedSlice node. This can be done
// when taking a slice of a fully defined Shape node or when taking a slice
// of a partial Shape node along a known dimension.
// Examples:
//   tf.shape(x)[0]; x.shape = (5, 10) - slicing fully defined shape
//   tf.shape(x)[0]; x.shape = (5, ?)  - slicing partial shape along known dim
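//   tf.shape(x)[-1]; x.shape = (?, 10) - a negative index is resolved against
//                                        the known rank before the lookup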
Status TryToInferTensorOutputFromStridedSliceNode(const Node& node,
                                                  const ShapeRefiner& refiner,
                                                  Tensor* output,
                                                  bool* success) {
  *success = false;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(0, &edge));
  const Node* shape_node = edge->src();
  const Node* stride_node = edge->dst();
  if (shape_node == nullptr || stride_node == nullptr) return OkStatus();
  if (stride_node->type_string() != "StridedSlice") return OkStatus();
  if (shape_node->type_string() != "Shape") return OkStatus();

  InferenceContext* shape_c = refiner.GetContext(shape_node);
  InferenceContext* stride_c = refiner.GetContext(stride_node);
  if (shape_c == nullptr || stride_c == nullptr) return OkStatus();

  // Only attempt to evaluate if the rank of the input to the Shape node is
  // known.
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  // Only attempt to evaluate if the begin/end/strides inputs of the
  // StridedSlice node are all vectors with exactly one element.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = stride_c->input(i);
    if (stride_c->Value(stride_c->Dim(input_shape, 0)) != 1) {
      return OkStatus();
    }
  }

  // Only attempt to evaluate cases with non-complex masks.
  int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(stride_c->GetAttr("begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("end_mask", &end_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("shrink_axis_mask", &shrink_axis_mask));

  // Case where the user has sliced a single element of a collection, e.g.
  // collection[i].
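  // For `x[i]`, TensorFlow's Python slicing typically lowers to a
  // StridedSlice with begin = [i], end = [i + 1], strides = [1] and
  // shrink_axis_mask = 1, which is exactly the pattern matched here.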
  bool accesses_single_element = begin_mask == 0 && end_mask == 0 &&
                                 ellipsis_mask == 0 && new_axis_mask == 0 &&
                                 shrink_axis_mask == 1;

  if (!accesses_single_element) return OkStatus();

  // Calculate the output tensor from the Shape node.
  Tensor shape_output;
  TF_RETURN_IF_ERROR(TryToInferTensorOutputFromShapeNode(
      *shape_node, shape_c, &shape_output, success, /*allow_partial=*/true));
  if (!*success) return OkStatus();

  // Discard the output tensor computed above if the StridedSlice points to an
  // unknown dimension.
  int64_t begin_value = 0;
  bool evaluated = false;
  *success = false;
  TF_RETURN_IF_ERROR(EvaluateConstantIntFromScalarEdge(
      *stride_node, 1, &begin_value, &evaluated));

  if (evaluated && node.output_type(0) == shape_output.dtype()) {
    begin_value = begin_value < 0
                      ? begin_value + shape_c->Rank(shape_c->input(0))
                      : begin_value;
    Tensor t(node.output_type(0), TensorShape({}));
    if (shape_output.dtype() == DT_INT32 &&
        shape_output.flat<int>()(begin_value) != -1) {
      t.flat<int32>()(0) = shape_output.flat<int>()(begin_value);
      *output = t;
      *success = true;
    } else if (shape_output.dtype() == DT_INT64 &&
               shape_output.flat<int64_t>()(begin_value) != -1) {
      t.flat<int64_t>()(0) = shape_output.flat<int64_t>()(begin_value);
      *output = t;
      *success = true;
    }
  }

  return OkStatus();
}

// Tries to infer the tensor output based on the input shapes of the node. In
// some cases, the shapes of the inputs are sufficient for inferring the
// contents of the output tensor. For example, a Shape op with fully defined
// input shapes can have its output tensor inferred.
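// Similarly, a Rank op only needs the rank of its input to be known, and a
// Size op whose input shape is fully defined (e.g. (2, 3)) yields its element
// count (6) without running the graph.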
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    // An input without context is a soft failure; we sometimes need to break
    // control flow loops by running shape inference on a node without first
    // adding its input.
    return OkStatus();
  }

  if (node->type_string() == "StridedSlice") {
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *node, refiner, output, success));
  } else if (node->type_string() == "Shape") {
    // If the input shapes to the Shape op are fully defined, we can infer the
    // Shape op's output tensor.
    TF_RETURN_IF_ERROR(
        TryToInferTensorOutputFromShapeNode(*node, c, output, success));
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64_t size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64_t>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return OkStatus();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
  DCHECK(node->IsArg());
  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
  if (*index < 0 || num_function_inputs <= *index) {
    return errors::Internal(
        "Function instantiation included invalid input index: ", *index,
        " not in [0, ", num_function_inputs, ").");
  }
  return OkStatus();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If statically computable,
// 'is_constant_graph' will be set to true.
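// For example, for target_node `c` in `c = tf.add(tf.constant(1),
// tf.constant(2))`, copies of all three nodes are inserted into 'out_graph'
// and *is_constant_graph ends up true; if either input were a Placeholder,
// *is_constant_graph would end up false.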
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs,
    InferenceContext* outer_context) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;
  if (target_node.op_def().is_stateful()) {
    return OkStatus();
  }

  if (IsMerge(&target_node)) {
    return OkStatus();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return OkStatus();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return OkStatus();
  }

  // TODO(skyewm): should more of the filtering applied to input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by iterating backwards through the
  // inputs to 'target_node' until we either 1) find an already existing input
  // to our subgraph 'const_inputs', 2) discover our graph is not constant, or
  // 3) hit a root node.

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the iteration.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant unless it is
    // an Arg node, which is handled later on.
    if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that plain "Placeholder" nodes have no inputs and
    // so are handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return OkStatus();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // If there is nothing more to recurse down, see if the generator node is a
    // constant or an Arg node whose value is available in the `outer_context`.
    if (current_node->num_inputs() == 0) {
      if (outer_context && current_node->IsArg()) {
        const string& tensor_name =
            strings::StrCat(current_node->name(), ":", 0);
        // If we do not already have a constant Tensor for this Arg, try to
        // fetch it from the outer context.
        if (const_inputs_added.count(tensor_name) == 0) {
          int index;
          TF_RETURN_IF_ERROR(GetArgNodeIndex(
              current_node, outer_context->num_inputs(), &index));
          const Tensor* const_tensor = outer_context->input_tensor(index);
          if (const_tensor) {
            const_inputs->emplace_back(tensor_name, *const_tensor);
            const_inputs_added.insert(tensor_name);
          } else {
            // Request a constant value for this Arg. If it is statically
            // computable, the shape refiner will re-run shape inference for
            // this function with this tensor's value.
            outer_context->request_input_tensor(index);
            *is_constant_graph = false;
            return OkStatus();
          }
        }
      } else if (!current_node->IsConstant()) {
        // The generator node is not a constant, so the subgraph is not
        // constant.
        *is_constant_graph = false;
        return OkStatus();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed =
        &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination "
            "node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a Shape op with fully
    // defined input shapes can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we already have a materialized copy of this tensor, add it to the
    // list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }
  return OkStatus();
}

}  // namespace

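// Attempts to evaluate `tensor` without running the full graph: first by
// reading a Const node's value directly, then by inferring a StridedSlice of
// a Shape from input shapes, then by reading an Arg's value from
// `outer_context`, and finally by extracting the constant subgraph that feeds
// `tensor` and running it with a GraphRunner. Sets *evaluated to true on
// success; failing to evaluate is not an error.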
Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation,
                              InferenceContext* outer_context) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant.
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*src, result, evaluated));
  if (*evaluated) return OkStatus();

  // Shape slice: the source node is slicing a single value out of a shape.
  // This is needed to handle the case where the StridedSlice is itself the
  // target node, as in a simple expression such as tf.shape([-1, 10])[-1]:
  // the ExtractConstantSubgraph call below only looks at the input srcs of
  // the various edges, so a StridedSlice that is the target node is never an
  // input src and would otherwise never be evaluated.
  if (src->type_string() == "StridedSlice") {
    Tensor slice_output;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *src, refiner, &slice_output, evaluated));
    if (*evaluated) {
      *result = slice_output;
      return OkStatus();
    }
  }

  // If the source node is an Arg, return its value, if available in the outer
  // context.
  if (src->IsArg() && outer_context) {
    int index;
    TF_RETURN_IF_ERROR(
        GetArgNodeIndex(src, outer_context->num_inputs(), &index));
    const Tensor* const_tensor = outer_context->input_tensor(index);
    if (const_tensor) {
      *evaluated = true;
      *result = *const_tensor;
    } else {
      outer_context->request_input_tensor(index);
    }
    return OkStatus();
  }

  if (disable_constant_propagation) {
    return OkStatus();
  }

  bool is_constant_graph = false;
  Graph subgraph(&ops);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs, outer_context));
  if (!is_constant_graph) {
    return OkStatus();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want to support
  // constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If some kernels in the constant graph are not registered in the binary,
  // GraphRunner::Run may fail, in which case we cannot propagate constants;
  // evaluation is therefore best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so that
    // ExtractConstantSubgraph can avoid extracting the full subgraph. As we
    // build up large graphs, this avoids repeated computation of the early
    // parts of a constant graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return OkStatus();
}

}  // namespace tensorflow