1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h" |
17 | |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
19 | #include "mlir/IR/Builders.h" // from @llvm-project |
20 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
22 | #include "mlir/IR/Dialect.h" // from @llvm-project |
23 | #include "mlir/IR/SymbolTable.h" // from @llvm-project |
24 | #include "mlir/IR/Types.h" // from @llvm-project |
25 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
26 | #include "tensorflow/compiler/jit/flags.h" |
27 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
28 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
29 | #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" |
30 | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" |
31 | #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" |
32 | #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
33 | #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" |
34 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
35 | #include "tensorflow/compiler/xla/status_macros.h" |
36 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
37 | #include "tensorflow/core/framework/function.h" |
38 | #include "tensorflow/core/framework/tensor.pb.h" |
39 | #include "tensorflow/core/graph/algorithm.h" |
40 | #include "tensorflow/core/platform/errors.h" |
41 | #include "tensorflow/core/util/dump_graph.h" |
42 | #include "tensorflow/dtensor/cc/constants.h" |
43 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
44 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
45 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
46 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
47 | |
48 | namespace tensorflow { |
49 | |
50 | DTensorMlirPassRunner::DTensorMlirPassRunner() |
51 | : pass_manager_(&context_), logging_enabled_(false) { |
52 | logging_enabled_ = dtensor::MaybeEnableLogging(&pass_manager_); |
53 | if (logging_enabled_) pass_manager_.getContext()->enableMultithreading(); |
54 | |
55 | // TODO(hinsu, hongjunchoi): Figure out a better place to explicitly enable |
56 | // the MLIR bridge. |
57 | // Explicitly enable MLIR bridge as DTensor introduces some ops like |
58 | // XlaAllReduce are only supported in MLIR. |
59 | GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = |
60 | ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; |
61 | |
62 | // Creates a pipeline that include each DTensor related passes. |
63 | mlir::TF::StandardPipelineOptions pipeline_options; |
64 | dtensor::CreateDTensorMLIRPass(pipeline_options, &pass_manager_); |
65 | } |
66 | |
67 | Status DTensorMlirPassRunner::RunOnGraph( |
68 | const DeviceSet& device_set, bool is_func, |
69 | FunctionLibraryDefinition* flib_def, std::unique_ptr<Graph>* graph, |
70 | absl::flat_hash_set<Node*>& control_ret_nodes, Fprint128 cache_key) { |
71 | Graph* input_graph = graph->get(); |
72 | GraphDebugInfo debug_info; |
73 | GraphImportConfig import_config; |
74 | import_config.graph_as_function = true; |
75 | // DTensor relies on importing with shape_inference to work properly ATM. |
76 | // Make it explicit so that we're not affected by potential flipping of the |
77 | // flag. |
78 | import_config.enable_shape_inference = true; |
79 | // Graph pruning will prune away an op (may be side effecting) if the op is |
80 | // not reachable from a fetch/result or target/control ret. With how the entry |
81 | // function/Graph is created, it is possible if the op has no data results. To |
82 | // make sure this op does not get pruned away, the op is defined as a |
83 | // target/control ret. |
84 | import_config.control_outputs = {"eager_operation" }; |
85 | |
86 | // Import GraphDef to TF MLIR. |
87 | stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> |
88 | module_ref = ConvertGraphToMlir(*input_graph, debug_info, *flib_def, |
89 | import_config, &context_); |
90 | if (!module_ref.ok()) |
91 | return errors::InvalidArgument( |
92 | absl::StrCat( |
93 | "Can not convert the graph to MLIR, errors from MLIR converter : " , |
94 | module_ref.status().error_message()) |
95 | .c_str()); |
96 | |
97 | mlir::ModuleOp module = module_ref.value().get(); |
98 | |
99 | AddDevicesToOp(module, &device_set); |
100 | |
101 | // Tag the module for logging or not depending on flag. |
102 | if (!is_func && !dtensor::LogOpByOp()) |
103 | module->setAttr(dtensor::kDoNotLog, mlir::UnitAttr::get(&context_)); |
104 | |
105 | // Set the cache key for the module as an attribute. This attribute will be |
106 | // used to rename all private functions in the module (by appending the |
107 | // cache key) so they have unique names. |
108 | module->setAttr( |
109 | dtensor::kCacheKey, |
110 | mlir::StringAttr::get(&context_, absl::StrCat("_" , cache_key.low64, "_" , |
111 | cache_key.high64))); |
112 | |
113 | // Executes and collects results from the passes. |
114 | mlir::StatusScopedDiagnosticHandler diag_handler(&context_); |
115 | |
116 | if (logging_enabled_ && !module->hasAttr(dtensor::kDoNotLog)) |
117 | pass_manager_.getContext()->disableMultithreading(); |
118 | mlir::LogicalResult result = pass_manager_.run(module); |
119 | (void)result; |
120 | TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus()); |
121 | |
122 | if (logging_enabled_) pass_manager_.getContext()->enableMultithreading(); |
123 | |
124 | // Convert MLIR to graphdef for execution. |
125 | GraphExportConfig export_config; |
126 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
127 | ConvertMlirToGraph(module, export_config, graph, flib_def, |
128 | &control_ret_nodes), |
129 | "Error converting MLIR module back to graph" ); |
130 | Graph* output_graph = graph->get(); |
131 | VLOG(4) << DumpGraphToFile("dtensor_mlir_pass_after" , *output_graph, |
132 | flib_def); |
133 | return OkStatus(); |
134 | } |
135 | |
136 | } // namespace tensorflow |
137 | |