/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_
17#define TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_
18
19#include <memory>
20#include <string>
21#include <unordered_map>
22#include <unordered_set>
23#include <utility>
24#include <vector>
25
26#include "tensorflow/core/framework/graph.pb.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/framework/variable.pb.h"
29#include "tensorflow/core/protobuf/queue_runner.pb.h"
30
31namespace tensorflow {
32namespace grappler {
33
34// A TensorFlow model to optimize.
35// Models are represented by the combination of a graph, one of more fetch
36// nodes, and potentially a set of nodes to feed.
37struct GrapplerItem {
38 GrapplerItem() = default;
39 GrapplerItem(const GrapplerItem& other) = default;
40 GrapplerItem(GrapplerItem&& other) = default;
41 GrapplerItem& operator=(const GrapplerItem& other) = default;
42 GrapplerItem& operator=(GrapplerItem&& other) = default;
43 virtual ~GrapplerItem() = default;
44
45 // Create a copy of this GrapplerItem with graph swapped with the argument.
46 GrapplerItem WithGraph(GraphDef&& graph) const;
47
48 string id; // A unique id for this item
49
50 // Inputs
51 GraphDef graph;
52 std::vector<std::pair<string, Tensor>> feed;
53 std::vector<string> fetch;
54
55 // Initialization op(s).
56 std::vector<string> init_ops;
57 // Expected initialization time in seconds, or 0 if unknown
58 int64_t expected_init_time = 0;
59
60 // Save/restore ops (if any)
61 string save_op;
62 string restore_op;
63 string save_restore_loc_tensor;
64
65 // Queue runner(s) required to run the queue(s) of this model.
66 std::vector<QueueRunnerDef> queue_runners;
67
68 // List of op names to keep in the graph. This includes nodes that are
69 // referenced in various collections, and therefore must be preserved to
70 // ensure that the optimized metagraph can still be loaded.
71 std::vector<string> keep_ops;
72
73 // Return the set of node evaluated during a regular train/inference step.
74 std::vector<const NodeDef*> MainOpsFanin() const;
75 // Return the set of node run to populate the queues (if any).
76 std::vector<const NodeDef*> EnqueueOpsFanin() const;
77 // Return the set nodes used by TensorFlow to initialize the graph.
78 std::vector<const NodeDef*> InitOpsFanin() const;
79 // Return the set of variables accessed during a regular train/inference step.
80 std::vector<const NodeDef*> MainVariables() const;
81 // Return a set of node names that must be preserved. This includes feed and
82 // fetch nodes, keep_ops, init_ops.
83 std::unordered_set<string> NodesToPreserve() const;
84
85 struct OptimizationOptions {
86 // Is it allowed to add nodes to the graph that do not have registered
87 // gradient function.
88 bool allow_non_differentiable_rewrites = true;
89
90 // Tensorflow function execution semantics is slightly different from the
91 // main Tensorflow graph, and we need to make sure that we do not change it
92 // by running Grappler optimizer passes. One main difference is that
93 // functions do not prune ops with side-effects and dataset-output ops (see
94 // PruneFunctionBody in common_runtime/function.cc).
95 bool allow_pruning_stateful_and_dataset_ops = true;
96
97 // If true Grappler will optimize the main graph, and also all functions in
98 // the graph function library (function can't be polymorphic, it can't have
99 // undefined type parameters in the function signature, or placeholder
100 // attributes in the function body).
101 bool optimize_function_library = true;
102
103 // Mark the grapper optimization run in eager mode or not.
104 bool is_eager_mode = false;
105 };
106
107 const std::unordered_set<string>& devices() const;
108 // Adds a device to a set of available devices, only if it's a valid fully
109 // defined device name. Returns `OkStatus()` if successfully added a device,
110 // and an error otherwise.
111 Status AddDevice(const string& device);
112 // Adds all valid devices from the other Grappler item to the device set.
113 Status AddDevices(const GrapplerItem& other);
114 // Adds all valid devices from the nodes of the graph to the device set.
115 // Returns `OkStatus()` if all device annotations found in a graph are valid
116 // fully defined device names, and an error otherwise.
117 Status InferDevicesFromGraph();
118 // Clears a set of available devices.
119 void ClearDevices();
120
121 const OptimizationOptions& optimization_options() const;
122 OptimizationOptions& optimization_options();
123
124 private:
125 // TODO(ezhulenev) Make GrapplerItem a class and hide all public data members.
126 // TODO(ezhulenev): Migrate all unordered collections to absl.
127
128 // A set of fully defined device names that can be used to place the nodes of
129 // the `graph`.
130 // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
131 std::unordered_set<string> devices_;
132
133 OptimizationOptions optimization_options_;
134};
135
136GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
137
138} // end namespace grappler
139} // end namespace tensorflow
140
141#endif // TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_
142