/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

namespace tensorflow {

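// Constructs the MLIR context and populates the pass manager with the DTensor
// pass pipeline; dtensor::MaybeEnableLogging may additionally attach IR
// printing to the pass manager when logging is requested.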
DTensorMlirPassRunner::DTensorMlirPassRunner()
    : pass_manager_(&context_), logging_enabled_(false) {
  logging_enabled_ = dtensor::MaybeEnableLogging(&pass_manager_);
  // IR printing requires a single-threaded MLIR context, so multithreading is
  // turned off while logging is enabled; it is restored after each run.
  if (logging_enabled_) pass_manager_.getContext()->disableMultithreading();

  // TODO(hinsu, hongjunchoi): Figure out a better place to explicitly enable
  // the MLIR bridge.
  // Explicitly enable the MLIR bridge, as DTensor introduces ops (e.g.
  // XlaAllReduce) that are only supported in MLIR.
  GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;

  // Create a pipeline that includes all DTensor-related passes.
  mlir::TF::StandardPipelineOptions pipeline_options;
  dtensor::CreateDTensorMLIRPass(pipeline_options, &pass_manager_);
}

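// Runs the DTensor MLIR pass pipeline over `graph` in place: the Graph is
// imported into a TF MLIR module, rewritten by the pipeline, and then
// exported back into `graph` and `flib_def`.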
Status DTensorMlirPassRunner::RunOnGraph(
    const DeviceSet& device_set, bool is_func,
    FunctionLibraryDefinition* flib_def, std::unique_ptr<Graph>* graph,
    absl::flat_hash_set<Node*>& control_ret_nodes, Fprint128 cache_key) {
  Graph* input_graph = graph->get();
  GraphDebugInfo debug_info;
  GraphImportConfig import_config;
  import_config.graph_as_function = true;
  // DTensor currently relies on importing with shape inference to work
  // properly. Set the flag explicitly so that a potential flip of its default
  // value does not affect us.
  import_config.enable_shape_inference = true;
  // Graph pruning removes an op (which may be side-effecting) if the op is
  // not reachable from a fetch/result or target/control ret. With how the
  // entry function/Graph is created, this can happen when the op has no data
  // results. To make sure such an op does not get pruned away, it is declared
  // as a target/control ret.
  import_config.control_outputs = {"eager_operation"};

  // Import the GraphDef into TF MLIR.
  stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
      module_ref = ConvertGraphToMlir(*input_graph, debug_info, *flib_def,
                                      import_config, &context_);
  if (!module_ref.ok())
    return errors::InvalidArgument(
        "Cannot convert the graph to MLIR; errors from the MLIR converter: ",
        module_ref.status().error_message());

  mlir::ModuleOp module = module_ref.value().get();

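  // Record the available devices on the module as attributes so that
  // subsequent passes can query the device set.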
  AddDevicesToOp(module, &device_set);

  // Tag the module with an attribute when its logging should be skipped,
  // depending on the flags.
  if (!is_func && !dtensor::LogOpByOp())
    module->setAttr(dtensor::kDoNotLog, mlir::UnitAttr::get(&context_));

  // Set the cache key for the module as an attribute. This attribute will be
  // used to rename all private functions in the module (by appending the
  // cache key) so that they have unique names.
  module->setAttr(
      dtensor::kCacheKey,
      mlir::StringAttr::get(&context_, absl::StrCat("_", cache_key.low64, "_",
                                                    cache_key.high64)));

  // Execute the passes and collect any diagnostics they emit.
  mlir::StatusScopedDiagnosticHandler diag_handler(&context_);

  // IR printing requires a single-threaded context while the passes run.
  if (logging_enabled_ && !module->hasAttr(dtensor::kDoNotLog))
    pass_manager_.getContext()->disableMultithreading();
  mlir::LogicalResult result = pass_manager_.run(module);
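  // The LogicalResult is intentionally ignored here; pass failures are
  // reported as diagnostics and converted into a Status by `diag_handler`
  // below.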
  (void)result;
  TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());

  if (logging_enabled_) pass_manager_.getContext()->enableMultithreading();

  // Convert the rewritten MLIR module back to a Graph for execution.
  GraphExportConfig export_config;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(module, export_config, graph, flib_def,
                         &control_ret_nodes),
      "Error converting MLIR module back to graph");
  Graph* output_graph = graph->get();
  VLOG(4) << DumpGraphToFile("dtensor_mlir_pass_after", *output_graph,
                             flib_def);
  return OkStatus();
}

}  // namespace tensorflow