1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ |
17 | |
#include <cstdint>
#include <unordered_map>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
20 | |
21 | // TODO(skyewm): can this be combined with ConstantFold? |
22 | |
23 | namespace tensorflow { |
24 | |
25 | class GraphRunner; |
26 | class OpRegistryInterface; |
27 | class ShapeRefiner; |
28 | class Tensor; |
29 | |
30 | // Attempts to evaluate `tensor`. This will only be possible if `tensor` doesn't |
31 | // depend on any graph inputs (this function is safe to call if this isn't the |
32 | // case though). |
33 | // |
34 | // If the evaluation is successful, `evaluated` will be set to true and |
35 | // `tensor`s value returned in `result`. Otherwise `evaluated` will be set to |
36 | // false. An error status is returned if something is wrong with the graph or |
37 | // input. Note that `evaluated` may set to false if OkStatus() is returned. |
38 | // |
39 | // Params: |
40 | // tensor - the tensor to be evaluated. |
41 | // refiner - used to fetch the InferenceContexts for nodes in the graph. |
42 | // ops - the OpRegistryInterface for the graph. |
43 | // graph_def_version - the producer version of the graph. |
44 | // evaluated - output param indicating whether evaluation was successful. |
45 | // result - output param containing the result if evaluated is true. |
46 | // graph_runner - optional. If not set, a GraphRunner will be created for |
47 | // evaluating tensor. This can be set to avoid creating a new GraphRunner |
48 | // for every call. |
49 | // cached_values - optional. This can be used to cache evaluated results |
50 | // across calls, to avoid evaluating the same parts of the graph multiple |
51 | // times. |
52 | // max_cached_value_size - optional. If `cached_values` is set, the maximum |
53 | // result size to cache. |
54 | // disable_constant_propagation - if true, only Const node values will be |
55 | // returned. |
56 | // outer_context - optional. The InferenceContext for the call node if inside |
57 | // a nested function. This is useful for doing constant propagation across |
58 | // Arg nodes. |
59 | Status EvaluateConstantTensor( |
60 | OutputTensor tensor, const ShapeRefiner& refiner, |
61 | const OpRegistryInterface& ops, int32_t graph_def_version, bool* evaluated, |
62 | Tensor* result, GraphRunner* graph_runner = nullptr, |
63 | std::unordered_map<string, Tensor>* cached_values = nullptr, |
64 | int64_t max_cached_value_size = 1024, |
65 | bool disable_constant_propagation = false, |
66 | shape_inference::InferenceContext* outer_context = nullptr); |
67 | |
68 | } // namespace tensorflow |
69 | |
70 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ |
71 | |