1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/data/rewrite_utils.h" |
16 | |
17 | #include "tensorflow/core/platform/refcount.h" |
18 | |
19 | // On mobile we do not provide this functionality because not all of its |
20 | // dependencies are available there. |
21 | #if !defined(IS_MOBILE_PLATFORM) |
22 | |
23 | #include <algorithm> |
24 | #include <functional> |
25 | #include <map> |
26 | #include <memory> |
27 | #include <string> |
28 | #include <unordered_map> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | #include "absl/container/flat_hash_set.h" |
33 | #include "absl/strings/str_cat.h" |
34 | #include "absl/strings/substitute.h" |
35 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
36 | #include "tensorflow/core/common_runtime/graph_runner.h" |
37 | #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
38 | #include "tensorflow/core/data/dataset_utils.h" |
39 | #include "tensorflow/core/data/hash_utils.h" |
40 | #include "tensorflow/core/data/serialization_utils.h" |
41 | #include "tensorflow/core/framework/dataset.h" |
42 | #include "tensorflow/core/framework/function.h" |
43 | #include "tensorflow/core/framework/function.pb.h" |
44 | #include "tensorflow/core/framework/graph.pb.h" |
45 | #include "tensorflow/core/framework/metrics.h" |
46 | #include "tensorflow/core/framework/node_def.pb.h" |
47 | #include "tensorflow/core/framework/op.h" |
48 | #include "tensorflow/core/framework/op_def_util.h" |
49 | #include "tensorflow/core/framework/op_kernel.h" |
50 | #include "tensorflow/core/framework/tensor.h" |
51 | #include "tensorflow/core/graph/graph.h" |
52 | #include "tensorflow/core/graph/graph_def_builder.h" |
53 | #include "tensorflow/core/grappler/clusters/virtual_cluster.h" |
54 | #include "tensorflow/core/grappler/graph_view.h" |
55 | #include "tensorflow/core/grappler/grappler_item.h" |
56 | #include "tensorflow/core/grappler/grappler_item_builder.h" |
57 | #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" |
58 | #include "tensorflow/core/grappler/optimizers/data/function_utils.h" |
59 | #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" |
60 | #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" |
61 | #include "tensorflow/core/lib/hash/hash.h" |
62 | #include "tensorflow/core/lib/strings/proto_serialization.h" |
63 | #include "tensorflow/core/platform/errors.h" |
64 | #include "tensorflow/core/platform/status.h" |
65 | #include "tensorflow/core/platform/statusor.h" |
66 | #include "tensorflow/core/platform/tstring.h" |
67 | #include "tensorflow/core/protobuf/config.pb.h" |
68 | #include "tensorflow/core/protobuf/device_properties.pb.h" |
69 | #include "tensorflow/core/protobuf/meta_graph.pb.h" |
70 | #include "tensorflow/core/protobuf/rewriter_config.pb.h" |
71 | |
72 | namespace tensorflow { |
73 | namespace data { |
74 | namespace { |
75 | |
// Name of the tf.data meta optimizer registered with Grappler; used both as
// the optimizer name and as the custom-optimizer entry in RewriterConfig.
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
// Parameter-map key listing which tf.data optimizations to run.
constexpr char kOptimizers[] = "optimizers";
// Parameter-map key listing per-optimization configuration strings.
constexpr char kOptimizerConfigs[] = "optimizer_configs";
79 | |
80 | void AddFakeSinks(FunctionDef* function_def) { |
81 | int counter = 0; |
82 | for (const auto& output : function_def->signature().output_arg()) { |
83 | NodeDef* node = function_def->add_node_def(); |
84 | tensorflow::grappler::function_utils::SetUniqueFunctionNodeName( |
85 | strings::StrCat("FakeSink" , counter++), function_def, node); |
86 | node->set_op("Identity" ); |
87 | node->add_input(function_def->ret().at(output.name())); |
88 | (*node->mutable_attr())["T" ].set_type(output.type()); |
89 | |
90 | (*function_def->mutable_ret())[output.name()] = |
91 | strings::StrCat(node->name(), ":output:0" ); |
92 | } |
93 | } |
94 | |
95 | void RemoveFakeSinks(FunctionDef* function_def) { |
96 | // Map from identity node names to their input tensor strings |
97 | std::map<std::string, std::string> identity_map; |
98 | for (const auto& node : function_def->node_def()) { |
99 | if (node.op() == "Identity" && node.input_size() == 1) { |
100 | identity_map[node.name()] = node.input(0); |
101 | } |
102 | } |
103 | for (const auto& output_arg : function_def->signature().output_arg()) { |
104 | const std::string& tensor = function_def->ret().at(output_arg.name()); |
105 | const std::string& output_node = tensor.substr(0, tensor.find(':')); |
106 | if (identity_map.find(output_node) != identity_map.end()) { |
107 | (*function_def->mutable_ret())[output_arg.name()] = |
108 | identity_map.at(output_node); |
109 | } |
110 | } |
111 | } |
112 | |
113 | Status ApplyRewrites(OpKernelContext* ctx, |
114 | const std::function<RewriterConfig(void)> config_factory, |
115 | GraphDef* graph_def, string* dataset_node) { |
116 | std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = |
117 | GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true); |
118 | std::unordered_map<std::string, tensorflow::DeviceProperties> device_map; |
119 | tensorflow::grappler::VirtualCluster cluster(device_map); |
120 | |
121 | // Run data optimizer using grappler's meta optimizer. |
122 | tensorflow::ConfigProto config; |
123 | *config.mutable_graph_options()->mutable_rewrite_options() = config_factory(); |
124 | TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( |
125 | std::move(*grappler_item), config, ctx->device(), &cluster, graph_def)); |
126 | |
127 | // Remove fake sinks after optimizations are done. |
128 | // |
129 | // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals |
130 | // to be optimizable, we will no longer need this. |
131 | for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { |
132 | RemoveFakeSinks(&function_def); |
133 | } |
134 | |
135 | return OkStatus(); |
136 | } |
137 | } // anonymous namespace |
138 | |
139 | RewriterConfig CreateRewriterConfig( |
140 | const absl::flat_hash_set<tstring>& optimizations, |
141 | const absl::flat_hash_set<tstring>& optimizations_configs) { |
142 | RewriterConfig rewriter_config; |
143 | rewriter_config.add_optimizers(kOptimizerName); |
144 | rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE); |
145 | rewriter_config.set_fail_on_optimizer_errors(true); |
146 | auto custom_optimizer = rewriter_config.add_custom_optimizers(); |
147 | custom_optimizer->set_name(kOptimizerName); |
148 | auto* custom_optimizations_list = |
149 | (*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list(); |
150 | const auto& registered_optimizers = |
151 | grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); |
152 | for (const auto& optimization : optimizations) { |
153 | if (std::find(registered_optimizers.begin(), registered_optimizers.end(), |
154 | optimization) != registered_optimizers.end()) { |
155 | custom_optimizations_list->add_s(optimization.data(), |
156 | optimization.size()); |
157 | } else { |
158 | VLOG(1) << "Optimization " << optimization << " is not registered." ; |
159 | } |
160 | } |
161 | auto* config_list = |
162 | (*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs] |
163 | .mutable_list(); |
164 | for (const auto& config : optimizations_configs) { |
165 | config_list->add_s(config.data(), config.size()); |
166 | } |
167 | return rewriter_config; |
168 | } |
169 | |
// Rewrites `input` by (1) serializing it to a GraphDef, (2) running Grappler
// with the RewriterConfig produced by `config_factory`, and (3) re-executing
// the optimized graph to materialize the rewritten dataset into
// `rewritten_input` (a reference is taken on the caller's behalf).
//
// If `record_fingerprint` is true, a fingerprint of the optimized graph is
// computed asynchronously on ctx->runner() and exported as a metric.
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool record_fingerprint,
                      core::RefCountPtr<DatasetBase>* rewritten_input) {
  std::vector<std::pair<string, Tensor>> input_list;
  GraphDef graph_def;
  string output_node;
  // Serialize the input pipeline; `output_node` names the dataset node and
  // `input_list` holds the placeholder feeds needed to re-run the graph.
  TF_RETURN_IF_ERROR(
      AsGraphDefForRewrite(ctx, input, &input_list, &graph_def, &output_node));

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(
      ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names changed
  // (for example, nested dataset graphs from FlatMap or Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(flr->device());

  // Running the graph produces a variant tensor wrapping the rewritten
  // dataset as its single output.
  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
  DatasetBase* rewritten_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(outputs[0], &rewritten_dataset));
  // Take a reference before `outputs` (which owns the variant) is destroyed.
  rewritten_dataset->Ref();
  rewritten_input->reset(rewritten_dataset);

  if (record_fingerprint) {
    // Fingerprinting hashes the entire graph, so run it off the critical
    // path. The lambda takes ownership of the released `lib_def` pointer and
    // re-wraps it in `lib_def_owner` below.
    (*ctx->runner())([graph_def = std::move(graph_def),
                      lib_def = lib_def.release(),
                      input_list = std::move(input_list),
                      output_node = std::move(output_node)]() {
      std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
      // Locate the dataset node in the optimized graph.
      const NodeDef* node_def = nullptr;
      for (const auto& node : graph_def.node()) {
        if (node.name() == output_node) {
          node_def = &node;
          break;
        }
      }
      if (node_def == nullptr) {
        VLOG(3) << "Failed to find node: " << output_node;
        return;
      }
      uint64 hash = 0;
      Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
      if (!s.ok()) {
        VLOG(3) << "Failed to hash graph: " << s.ToString();
        return;
      }
      // Fold the placeholder feeds into the hash; combining is unordered so
      // the result does not depend on input_list ordering.
      for (const auto& pair : input_list) {
        hash = Hash64CombineUnordered(hash, Hash64(pair.first));
        uint64 tensor_hash = 0;
        Status s = HashTensor(pair.second, &tensor_hash);
        if (s.ok()) {
          hash = Hash64CombineUnordered(hash, tensor_hash);
        } else {
          VLOG(3) << "Failed to hash tensor: " << s.ToString();
        }
      }
      string graph_hash =
          strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
      metrics::RecordTFDataFingerprint(graph_hash);
    });
  }

  return OkStatus();
}
251 | |
252 | std::unique_ptr<tensorflow::grappler::GrapplerItem> GetGrapplerItem( |
253 | GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) { |
254 | // Add an identity node as the fetch node, otherwise we might get 'placeholder |
255 | // is both fed and fetched' errors in some cases when using input list with |
256 | // placeholder dataset nodes. |
257 | NodeDef* node = graph_def->mutable_node()->Add(); |
258 | tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink" , graph_def, |
259 | node); |
260 | node->set_op("Identity" ); |
261 | node->add_input(*dataset_node); |
262 | (*node->mutable_attr())["T" ].set_type(DT_VARIANT); |
263 | *dataset_node = node->name(); |
264 | |
265 | if (add_fake_sinks) { |
266 | // Add fake sink node to graph and functions to allow rewriting the actual |
267 | // sink nodes. |
268 | // |
269 | // TODO(b/118820916): When MetaOptimizer adds provisions for function |
270 | // retvals to be optimizable, we will no longer need this. |
271 | for (auto& function_def : |
272 | *graph_def->mutable_library()->mutable_function()) { |
273 | AddFakeSinks(&function_def); |
274 | } |
275 | } |
276 | |
277 | // Create metagraph. |
278 | MetaGraphDef meta_graph_def; |
279 | (*meta_graph_def.mutable_graph_def()) = *graph_def; |
280 | |
281 | // Grappler determines fetch ops from collection 'train_op'. |
282 | CollectionDef collection_def; |
283 | auto node_list = collection_def.mutable_node_list(); |
284 | node_list->add_value(*dataset_node); |
285 | (*meta_graph_def.mutable_collection_def())["train_op" ] = collection_def; |
286 | |
287 | // Create Grappler item. |
288 | tensorflow::grappler::ItemConfig item_config; |
289 | item_config.apply_optimizations = true; |
290 | std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = |
291 | tensorflow::grappler::GrapplerItemFromMetaGraphDef( |
292 | "graph" , meta_graph_def, item_config); |
293 | // Grappler should not optimize function library of tf.data graphs. The |
294 | // tf.data meta optimizer takes care of optimizing tf.data functions. |
295 | grappler_item->optimization_options().optimize_function_library = false; |
296 | return grappler_item; |
297 | } |
298 | |
299 | absl::flat_hash_set<tstring> SelectOptimizations( |
300 | const absl::flat_hash_set<string>& experiments, |
301 | const absl::flat_hash_set<tstring>& optimizations_enabled, |
302 | const absl::flat_hash_set<tstring>& optimizations_disabled, |
303 | const absl::flat_hash_set<tstring>& optimizations_default) { |
304 | absl::flat_hash_set<tstring> optimizations; |
305 | |
306 | // Add the enabled optimizations. |
307 | optimizations.insert(optimizations_enabled.begin(), |
308 | optimizations_enabled.end()); |
309 | |
310 | // Add all default optimization that are not disabled. |
311 | for (const auto& optimization : optimizations_default) { |
312 | if (!optimizations_disabled.contains(optimization)) { |
313 | optimizations.insert(optimization); |
314 | } |
315 | } |
316 | |
317 | // Add experiments that correspond to an optimization unless the optimization |
318 | // is disabled. |
319 | const auto& registered_optimizers = |
320 | grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); |
321 | for (const auto& experiment : experiments) { |
322 | if (std::find(registered_optimizers.begin(), registered_optimizers.end(), |
323 | experiment) != registered_optimizers.end() && |
324 | !optimizations_disabled.contains(experiment)) { |
325 | optimizations.insert(experiment); |
326 | } |
327 | } |
328 | |
329 | return optimizations; |
330 | } |
331 | |
332 | StatusOr<std::string> GetDatasetNode(const GraphDef& graph_def) { |
333 | // Symbolic `_Retval` node indicates which node corresponds to the dataset. |
334 | for (const auto& node : graph_def.node()) { |
335 | if (node.op() == "_Retval" ) { |
336 | return node.input(0); |
337 | } |
338 | } |
339 | return errors::NotFound( |
340 | absl::Substitute("Dataset node for graph is not found:\n$0" , |
341 | graph_def.ShortDebugString())); |
342 | } |
343 | |
344 | StatusOr<NodeDef> GetDatasetNodeDef(const GraphDef& graph_def) { |
345 | TF_ASSIGN_OR_RETURN(std::string dataset_node_name, GetDatasetNode(graph_def)); |
346 | for (const auto& node : graph_def.node()) { |
347 | if (node.name() == dataset_node_name) { |
348 | return node; |
349 | } |
350 | } |
351 | return errors::NotFound( |
352 | absl::Substitute("Dataset node for graph is not found:\n$0" , |
353 | graph_def.ShortDebugString())); |
354 | } |
355 | |
356 | } // namespace data |
357 | } // namespace tensorflow |
358 | #endif // !IS_MOBILE_PLATFORM |
359 | |