1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/data/rewrite_utils.h"
16
17#include "tensorflow/core/platform/refcount.h"
18
19// On mobile we do not provide this functionality because not all of its
20// dependencies are available there.
21#if !defined(IS_MOBILE_PLATFORM)
22
23#include <algorithm>
24#include <functional>
25#include <map>
26#include <memory>
27#include <string>
28#include <unordered_map>
29#include <utility>
30#include <vector>
31
32#include "absl/container/flat_hash_set.h"
33#include "absl/strings/str_cat.h"
34#include "absl/strings/substitute.h"
35#include "tensorflow/core/common_runtime/graph_constructor.h"
36#include "tensorflow/core/common_runtime/graph_runner.h"
37#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
38#include "tensorflow/core/data/dataset_utils.h"
39#include "tensorflow/core/data/hash_utils.h"
40#include "tensorflow/core/data/serialization_utils.h"
41#include "tensorflow/core/framework/dataset.h"
42#include "tensorflow/core/framework/function.h"
43#include "tensorflow/core/framework/function.pb.h"
44#include "tensorflow/core/framework/graph.pb.h"
45#include "tensorflow/core/framework/metrics.h"
46#include "tensorflow/core/framework/node_def.pb.h"
47#include "tensorflow/core/framework/op.h"
48#include "tensorflow/core/framework/op_def_util.h"
49#include "tensorflow/core/framework/op_kernel.h"
50#include "tensorflow/core/framework/tensor.h"
51#include "tensorflow/core/graph/graph.h"
52#include "tensorflow/core/graph/graph_def_builder.h"
53#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
54#include "tensorflow/core/grappler/graph_view.h"
55#include "tensorflow/core/grappler/grappler_item.h"
56#include "tensorflow/core/grappler/grappler_item_builder.h"
57#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
58#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
59#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
60#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
61#include "tensorflow/core/lib/hash/hash.h"
62#include "tensorflow/core/lib/strings/proto_serialization.h"
63#include "tensorflow/core/platform/errors.h"
64#include "tensorflow/core/platform/status.h"
65#include "tensorflow/core/platform/statusor.h"
66#include "tensorflow/core/platform/tstring.h"
67#include "tensorflow/core/protobuf/config.pb.h"
68#include "tensorflow/core/protobuf/device_properties.pb.h"
69#include "tensorflow/core/protobuf/meta_graph.pb.h"
70#include "tensorflow/core/protobuf/rewriter_config.pb.h"
71
72namespace tensorflow {
73namespace data {
74namespace {
75
// Name of the tf.data meta optimizer. CreateRewriterConfig uses it both as the
// optimizer entry of the RewriterConfig and as the name of the custom
// optimizer it adds.
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
// Key in the custom optimizer's parameter map holding the list of
// optimizations to apply.
constexpr char kOptimizers[] = "optimizers";
// Key in the custom optimizer's parameter map holding the list of optimization
// configuration strings.
constexpr char kOptimizerConfigs[] = "optimizer_configs";
79
80void AddFakeSinks(FunctionDef* function_def) {
81 int counter = 0;
82 for (const auto& output : function_def->signature().output_arg()) {
83 NodeDef* node = function_def->add_node_def();
84 tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
85 strings::StrCat("FakeSink", counter++), function_def, node);
86 node->set_op("Identity");
87 node->add_input(function_def->ret().at(output.name()));
88 (*node->mutable_attr())["T"].set_type(output.type());
89
90 (*function_def->mutable_ret())[output.name()] =
91 strings::StrCat(node->name(), ":output:0");
92 }
93}
94
95void RemoveFakeSinks(FunctionDef* function_def) {
96 // Map from identity node names to their input tensor strings
97 std::map<std::string, std::string> identity_map;
98 for (const auto& node : function_def->node_def()) {
99 if (node.op() == "Identity" && node.input_size() == 1) {
100 identity_map[node.name()] = node.input(0);
101 }
102 }
103 for (const auto& output_arg : function_def->signature().output_arg()) {
104 const std::string& tensor = function_def->ret().at(output_arg.name());
105 const std::string& output_node = tensor.substr(0, tensor.find(':'));
106 if (identity_map.find(output_node) != identity_map.end()) {
107 (*function_def->mutable_ret())[output_arg.name()] =
108 identity_map.at(output_node);
109 }
110 }
111}
112
113Status ApplyRewrites(OpKernelContext* ctx,
114 const std::function<RewriterConfig(void)> config_factory,
115 GraphDef* graph_def, string* dataset_node) {
116 std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
117 GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true);
118 std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
119 tensorflow::grappler::VirtualCluster cluster(device_map);
120
121 // Run data optimizer using grappler's meta optimizer.
122 tensorflow::ConfigProto config;
123 *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
124 TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
125 std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));
126
127 // Remove fake sinks after optimizations are done.
128 //
129 // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
130 // to be optimizable, we will no longer need this.
131 for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
132 RemoveFakeSinks(&function_def);
133 }
134
135 return OkStatus();
136}
137} // anonymous namespace
138
139RewriterConfig CreateRewriterConfig(
140 const absl::flat_hash_set<tstring>& optimizations,
141 const absl::flat_hash_set<tstring>& optimizations_configs) {
142 RewriterConfig rewriter_config;
143 rewriter_config.add_optimizers(kOptimizerName);
144 rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
145 rewriter_config.set_fail_on_optimizer_errors(true);
146 auto custom_optimizer = rewriter_config.add_custom_optimizers();
147 custom_optimizer->set_name(kOptimizerName);
148 auto* custom_optimizations_list =
149 (*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
150 const auto& registered_optimizers =
151 grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
152 for (const auto& optimization : optimizations) {
153 if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
154 optimization) != registered_optimizers.end()) {
155 custom_optimizations_list->add_s(optimization.data(),
156 optimization.size());
157 } else {
158 VLOG(1) << "Optimization " << optimization << " is not registered.";
159 }
160 }
161 auto* config_list =
162 (*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
163 .mutable_list();
164 for (const auto& config : optimizations_configs) {
165 config_list->add_s(config.data(), config.size());
166 }
167 return rewriter_config;
168}
169
170Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
171 std::function<RewriterConfig(void)> config_factory,
172 bool record_fingerprint,
173 core::RefCountPtr<DatasetBase>* rewritten_input) {
174 std::vector<std::pair<string, Tensor>> input_list;
175 GraphDef graph_def;
176 string output_node;
177 TF_RETURN_IF_ERROR(
178 AsGraphDefForRewrite(ctx, input, &input_list, &graph_def, &output_node));
179
180 VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
181 TF_RETURN_IF_ERROR(
182 ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
183 VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
184
185 // Instantiate the optimized input pipeline by running the optimized graph
186 // using the optimized function library.
187 FunctionLibraryRuntime* flr = nullptr;
188 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
189 std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
190 TF_RETURN_IF_ERROR(
191 ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));
192
193 // Some functions may have been modified without having their names changed
194 // (for example, nested dataset graphs from FlatMap or Interleave).
195 TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
196
197 Graph graph(OpRegistry::Global());
198 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
199 std::vector<Tensor> outputs;
200 GraphRunner graph_runner(flr->device());
201
202 TF_RETURN_IF_ERROR(
203 graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
204 DatasetBase* rewritten_dataset;
205 TF_RETURN_IF_ERROR(
206 GetDatasetFromVariantTensor(outputs[0], &rewritten_dataset));
207 rewritten_dataset->Ref();
208 rewritten_input->reset(rewritten_dataset);
209
210 if (record_fingerprint) {
211 (*ctx->runner())([graph_def = std::move(graph_def),
212 lib_def = lib_def.release(),
213 input_list = std::move(input_list),
214 output_node = std::move(output_node)]() {
215 std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
216 const NodeDef* node_def = nullptr;
217 for (const auto& node : graph_def.node()) {
218 if (node.name() == output_node) {
219 node_def = &node;
220 break;
221 }
222 }
223 if (node_def == nullptr) {
224 VLOG(3) << "Failed to find node: " << output_node;
225 return;
226 }
227 uint64 hash = 0;
228 Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
229 if (!s.ok()) {
230 VLOG(3) << "Failed to hash graph: " << s.ToString();
231 return;
232 }
233 for (const auto& pair : input_list) {
234 hash = Hash64CombineUnordered(hash, Hash64(pair.first));
235 uint64 tensor_hash = 0;
236 Status s = HashTensor(pair.second, &tensor_hash);
237 if (s.ok()) {
238 hash = Hash64CombineUnordered(hash, tensor_hash);
239 } else {
240 VLOG(3) << "Failed to hash tensor: " << s.ToString();
241 }
242 }
243 string graph_hash =
244 strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
245 metrics::RecordTFDataFingerprint(graph_hash);
246 });
247 }
248
249 return OkStatus();
250}
251
252std::unique_ptr<tensorflow::grappler::GrapplerItem> GetGrapplerItem(
253 GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) {
254 // Add an identity node as the fetch node, otherwise we might get 'placeholder
255 // is both fed and fetched' errors in some cases when using input list with
256 // placeholder dataset nodes.
257 NodeDef* node = graph_def->mutable_node()->Add();
258 tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
259 node);
260 node->set_op("Identity");
261 node->add_input(*dataset_node);
262 (*node->mutable_attr())["T"].set_type(DT_VARIANT);
263 *dataset_node = node->name();
264
265 if (add_fake_sinks) {
266 // Add fake sink node to graph and functions to allow rewriting the actual
267 // sink nodes.
268 //
269 // TODO(b/118820916): When MetaOptimizer adds provisions for function
270 // retvals to be optimizable, we will no longer need this.
271 for (auto& function_def :
272 *graph_def->mutable_library()->mutable_function()) {
273 AddFakeSinks(&function_def);
274 }
275 }
276
277 // Create metagraph.
278 MetaGraphDef meta_graph_def;
279 (*meta_graph_def.mutable_graph_def()) = *graph_def;
280
281 // Grappler determines fetch ops from collection 'train_op'.
282 CollectionDef collection_def;
283 auto node_list = collection_def.mutable_node_list();
284 node_list->add_value(*dataset_node);
285 (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
286
287 // Create Grappler item.
288 tensorflow::grappler::ItemConfig item_config;
289 item_config.apply_optimizations = true;
290 std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
291 tensorflow::grappler::GrapplerItemFromMetaGraphDef(
292 "graph", meta_graph_def, item_config);
293 // Grappler should not optimize function library of tf.data graphs. The
294 // tf.data meta optimizer takes care of optimizing tf.data functions.
295 grappler_item->optimization_options().optimize_function_library = false;
296 return grappler_item;
297}
298
299absl::flat_hash_set<tstring> SelectOptimizations(
300 const absl::flat_hash_set<string>& experiments,
301 const absl::flat_hash_set<tstring>& optimizations_enabled,
302 const absl::flat_hash_set<tstring>& optimizations_disabled,
303 const absl::flat_hash_set<tstring>& optimizations_default) {
304 absl::flat_hash_set<tstring> optimizations;
305
306 // Add the enabled optimizations.
307 optimizations.insert(optimizations_enabled.begin(),
308 optimizations_enabled.end());
309
310 // Add all default optimization that are not disabled.
311 for (const auto& optimization : optimizations_default) {
312 if (!optimizations_disabled.contains(optimization)) {
313 optimizations.insert(optimization);
314 }
315 }
316
317 // Add experiments that correspond to an optimization unless the optimization
318 // is disabled.
319 const auto& registered_optimizers =
320 grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
321 for (const auto& experiment : experiments) {
322 if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
323 experiment) != registered_optimizers.end() &&
324 !optimizations_disabled.contains(experiment)) {
325 optimizations.insert(experiment);
326 }
327 }
328
329 return optimizations;
330}
331
332StatusOr<std::string> GetDatasetNode(const GraphDef& graph_def) {
333 // Symbolic `_Retval` node indicates which node corresponds to the dataset.
334 for (const auto& node : graph_def.node()) {
335 if (node.op() == "_Retval") {
336 return node.input(0);
337 }
338 }
339 return errors::NotFound(
340 absl::Substitute("Dataset node for graph is not found:\n$0",
341 graph_def.ShortDebugString()));
342}
343
344StatusOr<NodeDef> GetDatasetNodeDef(const GraphDef& graph_def) {
345 TF_ASSIGN_OR_RETURN(std::string dataset_node_name, GetDatasetNode(graph_def));
346 for (const auto& node : graph_def.node()) {
347 if (node.name() == dataset_node_name) {
348 return node;
349 }
350 }
351 return errors::NotFound(
352 absl::Substitute("Dataset node for graph is not found:\n$0",
353 graph_def.ShortDebugString()));
354}
355
356} // namespace data
357} // namespace tensorflow
358#endif // !IS_MOBILE_PLATFORM
359