/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_

#include <unordered_map>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

// TODO(skyewm): can this be combined with ConstantFold?

namespace tensorflow {

namespace shape_inference {
class InferenceContext;
}  // namespace shape_inference

class GraphRunner;
class OpRegistryInterface;
class ShapeRefiner;
class Tensor;

// Attempts to evaluate `tensor`. Evaluation succeeds only if `tensor` does not
// depend on any graph inputs (though this function is safe to call even when
// it does).
//
// If the evaluation is successful, `evaluated` will be set to true and
// `tensor`'s value will be returned in `result`. Otherwise `evaluated` will be
// set to false. An error status is returned if something is wrong with the
// graph or input. Note that `evaluated` may be set to false even when
// OkStatus() is returned.
//
// Params:
//   tensor - the tensor to be evaluated.
//   refiner - used to fetch the InferenceContexts for nodes in the graph.
//   ops - the OpRegistryInterface for the graph.
//   graph_def_version - the producer version of the graph.
//   evaluated - output param indicating whether evaluation was successful.
//   result - output param containing the result if `evaluated` is true.
//   graph_runner - optional. If not set, a GraphRunner will be created for
//     evaluating `tensor`. This can be set to avoid creating a new GraphRunner
//     for every call.
//   cached_values - optional. This can be used to cache evaluated results
//     across calls, to avoid evaluating the same parts of the graph multiple
//     times.
//   max_cached_value_size - optional. If `cached_values` is set, the maximum
//     result size to cache.
//   disable_constant_propagation - if true, only Const node values will be
//     returned.
//   outer_context - optional. The InferenceContext for the call node if inside
//     a nested function. This is useful for doing constant propagation across
//     Arg nodes.
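//
// Example usage (a minimal sketch; `graph`, `refiner`, and `node` are assumed
// to exist in the caller's scope, and the call is made from a function that
// returns Status):
//
//   bool evaluated = false;
//   Tensor constant_value;
//   TF_RETURN_IF_ERROR(EvaluateConstantTensor(
//       OutputTensor(node, /*index=*/0), refiner, *graph.op_registry(),
//       graph.versions().producer(), &evaluated, &constant_value));
//   if (evaluated) {
//     // `constant_value` now holds the value of `node`'s first output.
//   }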
Status EvaluateConstantTensor(
    OutputTensor tensor, const ShapeRefiner& refiner,
    const OpRegistryInterface& ops, int32_t graph_def_version, bool* evaluated,
    Tensor* result, GraphRunner* graph_runner = nullptr,
    std::unordered_map<string, Tensor>* cached_values = nullptr,
    int64_t max_cached_value_size = 1024,
    bool disable_constant_propagation = false,
    shape_inference::InferenceContext* outer_context = nullptr);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_