1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ |
17 | #define TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ |
18 | |
19 | #include <memory> |
20 | #include <string> |
21 | #include <unordered_map> |
22 | #include <unordered_set> |
23 | #include <utility> |
24 | #include <vector> |
25 | |
26 | #include "tensorflow/core/framework/graph.pb.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/variable.pb.h" |
29 | #include "tensorflow/core/protobuf/queue_runner.pb.h" |
30 | |
31 | namespace tensorflow { |
32 | namespace grappler { |
33 | |
34 | // A TensorFlow model to optimize. |
35 | // Models are represented by the combination of a graph, one of more fetch |
36 | // nodes, and potentially a set of nodes to feed. |
37 | struct GrapplerItem { |
38 | GrapplerItem() = default; |
39 | GrapplerItem(const GrapplerItem& other) = default; |
40 | GrapplerItem(GrapplerItem&& other) = default; |
41 | GrapplerItem& operator=(const GrapplerItem& other) = default; |
42 | GrapplerItem& operator=(GrapplerItem&& other) = default; |
43 | virtual ~GrapplerItem() = default; |
44 | |
45 | // Create a copy of this GrapplerItem with graph swapped with the argument. |
46 | GrapplerItem WithGraph(GraphDef&& graph) const; |
47 | |
48 | string id; // A unique id for this item |
49 | |
50 | // Inputs |
51 | GraphDef graph; |
52 | std::vector<std::pair<string, Tensor>> feed; |
53 | std::vector<string> fetch; |
54 | |
55 | // Initialization op(s). |
56 | std::vector<string> init_ops; |
57 | // Expected initialization time in seconds, or 0 if unknown |
58 | int64_t expected_init_time = 0; |
59 | |
60 | // Save/restore ops (if any) |
61 | string save_op; |
62 | string restore_op; |
63 | string save_restore_loc_tensor; |
64 | |
65 | // Queue runner(s) required to run the queue(s) of this model. |
66 | std::vector<QueueRunnerDef> queue_runners; |
67 | |
68 | // List of op names to keep in the graph. This includes nodes that are |
69 | // referenced in various collections, and therefore must be preserved to |
70 | // ensure that the optimized metagraph can still be loaded. |
71 | std::vector<string> keep_ops; |
72 | |
73 | // Return the set of node evaluated during a regular train/inference step. |
74 | std::vector<const NodeDef*> MainOpsFanin() const; |
75 | // Return the set of node run to populate the queues (if any). |
76 | std::vector<const NodeDef*> EnqueueOpsFanin() const; |
77 | // Return the set nodes used by TensorFlow to initialize the graph. |
78 | std::vector<const NodeDef*> InitOpsFanin() const; |
79 | // Return the set of variables accessed during a regular train/inference step. |
80 | std::vector<const NodeDef*> MainVariables() const; |
81 | // Return a set of node names that must be preserved. This includes feed and |
82 | // fetch nodes, keep_ops, init_ops. |
83 | std::unordered_set<string> NodesToPreserve() const; |
84 | |
85 | struct OptimizationOptions { |
86 | // Is it allowed to add nodes to the graph that do not have registered |
87 | // gradient function. |
88 | bool allow_non_differentiable_rewrites = true; |
89 | |
90 | // Tensorflow function execution semantics is slightly different from the |
91 | // main Tensorflow graph, and we need to make sure that we do not change it |
92 | // by running Grappler optimizer passes. One main difference is that |
93 | // functions do not prune ops with side-effects and dataset-output ops (see |
94 | // PruneFunctionBody in common_runtime/function.cc). |
95 | bool allow_pruning_stateful_and_dataset_ops = true; |
96 | |
97 | // If true Grappler will optimize the main graph, and also all functions in |
98 | // the graph function library (function can't be polymorphic, it can't have |
99 | // undefined type parameters in the function signature, or placeholder |
100 | // attributes in the function body). |
101 | bool optimize_function_library = true; |
102 | |
103 | // Mark the grapper optimization run in eager mode or not. |
104 | bool is_eager_mode = false; |
105 | }; |
106 | |
107 | const std::unordered_set<string>& devices() const; |
108 | // Adds a device to a set of available devices, only if it's a valid fully |
109 | // defined device name. Returns `OkStatus()` if successfully added a device, |
110 | // and an error otherwise. |
111 | Status AddDevice(const string& device); |
112 | // Adds all valid devices from the other Grappler item to the device set. |
113 | Status AddDevices(const GrapplerItem& other); |
114 | // Adds all valid devices from the nodes of the graph to the device set. |
115 | // Returns `OkStatus()` if all device annotations found in a graph are valid |
116 | // fully defined device names, and an error otherwise. |
117 | Status InferDevicesFromGraph(); |
118 | // Clears a set of available devices. |
119 | void ClearDevices(); |
120 | |
121 | const OptimizationOptions& optimization_options() const; |
122 | OptimizationOptions& optimization_options(); |
123 | |
124 | private: |
125 | // TODO(ezhulenev) Make GrapplerItem a class and hide all public data members. |
126 | // TODO(ezhulenev): Migrate all unordered collections to absl. |
127 | |
128 | // A set of fully defined device names that can be used to place the nodes of |
129 | // the `graph`. |
130 | // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0" |
131 | std::unordered_set<string> devices_; |
132 | |
133 | OptimizationOptions optimization_options_; |
134 | }; |
135 | |
136 | GrapplerItem::OptimizationOptions CreateOptOptionsForEager(); |
137 | |
138 | } // end namespace grappler |
139 | } // end namespace tensorflow |
140 | |
141 | #endif // TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ |
142 | |