1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ |
18 | |
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
24 | |
25 | // TODO(skyewm): can this be combined with EvaluateConstantTensor? |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // This generator type is used to generate a name for the newly folded node |
30 | // based on the node's old name. |
31 | using ConstantFoldNameGenerator = |
32 | std::function<string(Graph* graph, string old_name)>; |
33 | |
34 | // Options specific to constant folding optimizations. |
35 | struct ConstantFoldingOptions { |
36 | // If "consider" is not a nullptr, then only constant fold a node "n" if |
37 | // consider(n) returns true. |
38 | std::function<bool(const Node*)> consider = nullptr; |
39 | // If shape_map is not a nullptr, it is a map from node n to a |
40 | // vector of the (potentially partially-known) shapes of its |
41 | // outputs. |
42 | const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map = |
43 | nullptr; // not owned |
44 | // The maximum size of each constant created during constant folding |
45 | // optimization. |
46 | int64_t max_constant_size_in_bytes = 10 * 1024 * 1024; |
47 | |
48 | // A generator for the name suffix of constant folded nodes. A |
49 | // default id generator that monotonically increases is used if nullptr is |
50 | // passed. |
51 | ConstantFoldNameGenerator generate_new_name = nullptr; |
52 | }; |
53 | |
54 | // Perform constant folding optimization on "graph". |
55 | // Looks for nodes in "graph" that can be completely evaluated statically, i.e., |
56 | // that are only dependent on constants. Evaluates those nodes on a CPU device |
57 | // and replaces those nodes with the result of the evaluation. |
58 | // "partition_device", if non-null, is the device where all the graph nodes are |
59 | // assumed to execute. |
60 | // Sets `was_mutated` to true if and only if "graph" has been mutated. |
61 | // The status is only set to a non-OK state if an unexpected error is hit |
62 | // running the graph. |
63 | Status ConstantFold(const ConstantFoldingOptions& opts, |
64 | FunctionLibraryRuntime* function_library, Env* env, |
65 | const Device* partition_device, Graph* graph, |
66 | bool* was_mutated); |
67 | |
68 | } // namespace tensorflow |
69 | |
70 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ |
71 | |