1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
18
19#include "tensorflow/core/common_runtime/device.h"
20#include "tensorflow/core/framework/function.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/graph/graph.h"
23#include "tensorflow/core/platform/env.h"
24
25// TODO(skyewm): can this be combined with EvaluateConstantTensor?
26
27namespace tensorflow {
28
29// This generator type is used to generate a name for the newly folded node
30// based on the node's old name.
31using ConstantFoldNameGenerator =
32 std::function<string(Graph* graph, string old_name)>;
33
34// Options specific to constant folding optimizations.
35struct ConstantFoldingOptions {
36 // If "consider" is not a nullptr, then only constant fold a node "n" if
37 // consider(n) returns true.
38 std::function<bool(const Node*)> consider = nullptr;
39 // If shape_map is not a nullptr, it is a map from node n to a
40 // vector of the (potentially partially-known) shapes of its
41 // outputs.
42 const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map =
43 nullptr; // not owned
44 // The maximum size of each constant created during constant folding
45 // optimization.
46 int64_t max_constant_size_in_bytes = 10 * 1024 * 1024;
47
48 // A generator for the name suffix of constant folded nodes. A
49 // default id generator that monotonically increases is used if nullptr is
50 // passed.
51 ConstantFoldNameGenerator generate_new_name = nullptr;
52};
53
54// Perform constant folding optimization on "graph".
55// Looks for nodes in "graph" that can be completely evaluated statically, i.e.,
56// that are only dependent on constants. Evaluates those nodes on a CPU device
57// and replaces those nodes with the result of the evaluation.
58// "partition_device", if non-null, is the device where all the graph nodes are
59// assumed to execute.
60// Sets `was_mutated` to true if and only if "graph" has been mutated.
61// The status is only set to a non-OK state if an unexpected error is hit
62// running the graph.
63Status ConstantFold(const ConstantFoldingOptions& opts,
64 FunctionLibraryRuntime* function_library, Env* env,
65 const Device* partition_device, Graph* graph,
66 bool* was_mutated);
67
68} // namespace tensorflow
69
70#endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
71