/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"

#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.h"

namespace tensorflow {
namespace dtensor {
namespace {
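// BridgeLoggerConfig that dumps IR only after passes (never before), only on
// client 0 unless logging on all tasks is requested, and never for modules
// tagged with the dtensor::kDoNotLog attribute.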
class ConditionalPrinter : public BridgeLoggerConfig {
 private:
  bool do_not_print_;

 public:
  explicit ConditionalPrinter(bool print_module_scope = false,
                              bool print_after_only_on_change = true)
      : BridgeLoggerConfig(print_module_scope, print_after_only_on_change) {
    do_not_print_ = !(LogOnAllTasks() || (ClientId() == 0));
  }

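  // Intentionally never print the IR before a pass runs; only the post-pass
  // IR is of interest for debugging the DTensor pipeline.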
  void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                            PrintCallbackFn print_callback) override {}

  void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                           PrintCallbackFn print_callback) override {
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(operation);
    if (!module) module = operation->getParentOfType<mlir::ModuleOp>();
    if (module && !module->hasAttr(dtensor::kDoNotLog) && !do_not_print_)
      BridgeLoggerConfig::printAfterIfEnabled(pass, operation, print_callback);
  }
};
}  // namespace

// Adds a logger to the DTensor transformation pass manager.
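// IR dumps are gated on VLOG level 1 for this file; in typical TensorFlow
// builds this can be turned on with e.g. TF_CPP_VMODULE=dtensor_mlir_passes=1
// (exact logging controls may vary by build configuration).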
bool MaybeEnableLogging(mlir::PassManager *pm) {
  if (VLOG_IS_ON(1)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    pm->getContext()->disableMultithreading();
    pm->enableIRPrinting(std::make_unique<ConditionalPrinter>(
        /*print_module_scope=*/true));
    return true;
  }
  return false;
}

void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options,
                           mlir::OpPassManager *pm) {
  // Remove ops that cannot be reached from the sink node.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  // Remove the graph-def executor dialect and represent the IR as a flattened
  // list of TF ops in functions.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());

  // This does not guarantee that shapes are inferred for all ops. For ops with
  // dynamic shapes, shape information may still be missing.
  pm->addPass(mlir::TF::CreateTFShapeInferencePass());

  // With the V2 layout propagation algorithm, layouts are expressed as
  // DTensorLayout ops, so the Canonicalizer and Inliner passes will not lose
  // layout information.
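  // A DTensorLayout op simply forwards its operand while carrying the layout
  // as an attribute, roughly (illustrative IR, exact attribute syntax may
  // differ):
  //   %1 = "tf.DTensorLayout"(%0) {layout = #dtensor.layout<...>} : ...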
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorPropagateDefaultLayout());
  pm->addPass(mlir::createSCCPPass());
  pm->addPass(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm->addPass(mlir::createInlinerPass());

  // Ensure that all functions have `device_id` as the 0th argument.
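  // Roughly (illustrative only): func @f(%x: tensor<...>) becomes
  // func @f(%device_id: tensor<i32>, %x: tensor<...>), with the device id
  // threaded through all call sites.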
  pm->addPass(CreateDTensorPropagateDeviceIdToFunctionArgs());

  // Ensure that every SparseTensor input to a function is converted to its
  // three component tensors (indices, values, dense_shape) and that a
  // SparseToDenseOp is emitted for every usage of a SparseTensor.
  pm->addPass(CreateDTensorSparseTensorToDenseTensor());

  AddDTensorEmbeddingPass(pm);

  // After shape inference, there may be unused constant ops added when
  // propagating caller-callee constants. As the DTensor mesh/layout
  // propagation passes assume that there are no unreachable ops, remove
  // trivially unused ops here. Note that the `Canonicalizer` pass in TF
  // includes a similar optimization; however, the canonicalizer pass also
  // rewrites some ops and may remove the `_layout` or `_mesh` attributes from
  // the rewritten TF ops.
  // TODO(hongjunchoi): Remove this pass once the shape inference pass no
  // longer creates unnecessary constant ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // Canonicalization will merge tf.ConstOp operations from different
  // DTensorLayout annotations, causing problems during mesh propagation. Undo
  // the merge before creating clusters.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorUndoMergeConstAcrossMesh());

  // Propagate mesh cluster config and cluster ops by mesh cluster so that SPMD
  // expansion can be isolated to a single device mesh.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorOpToDeviceClusterPass());
  pm->addPass(CreateDTensorMeshPropagationPass());

  {
    mlir::OpPassManager &func_pm = pm->nest<mlir::func::FuncOp>();
    func_pm.addPass(CreateDTensorDeviceMeshClusterCoarsening());
    // Set an empty layout on the cluster wrapping `tf.VarHandleOp`. VarHandle
    // ops always run on the default device where the client program executes.
    func_pm.addPass(CreateDTensorDesignateResourceHandleMesh());
  }

  // Validate that all cross-mesh data transfers are expressed via
  // DTensorLayout operations and lower them to send/recvs.
  pm->addPass(CreateDTensorHandleCrossClusterDependencies());

  // Mark all ops and functions with a global shape attribute to preserve
  // global shape information, as it is needed during layout propagation and
  // SPMD expansion.
  pm->addPass(CreateDTensorAnnotateGlobalShape());

  // Propagate layouts to all ops in the graph.
  pm->addPass(CreateDTensorMergeClustersPass());

  AddDTensorEmbeddingPassV2(pm);

  // For DTensor Checkpoint V2, the outputs of tf.RestoreV2 ops do not have
  // shape information. We can infer the shapes of these outputs from the
  // tf.AssignVariableOps that consume them. This pass fills in all missing
  // shapes caused by tf.RestoreV2 ops.
  pm->addPass(CreateDTensorInferShapesForRestoreV2Op());

  pm->addPass(CreateDTensorLayoutPropagationPassV2());

  // Expand the graph to SPMD form, given that layouts are annotated on all
  // ops. Remove all DTensorLayout ops after the expansion is done.
  pm->addPass(CreateDTensorSPMDExpansion());

  // Insert functions to save or load embeddings when using the TPU device.
  AddDTensorEmbeddingCheckpointPass(pm);

  // Expand all ops that consume SparseTensors, possibly into new ops.
  // Remove any unused SparseToDense, Layout, and Const ops after the
  // expansion is done.
  //
  // Note that this pass assumes that SparseTensor operands are represented as
  // operands coming from the output of a SparseToDenseOp. Thus, this pass must
  // run after the SparseTensorToDenseTensor pass and after the SPMD expansion
  // pass.
  pm->addPass(CreateDTensorSparseExpansion());

  // Do a round of CSE: this helps reduce the number of consts in the graph now
  // that SPMD expansion is done. We had replicated all Consts (so that each
  // const only had one usage) as part of layout propagation.
  pm->addPass(mlir::createCSEPass());

  // Lower the AllGather collectives. This has to happen before the all-reduce
  // optimizations, as AllGather lowering may emit an AllReduce.
  pm->addPass(CreateDTensorAllGatherLoweringPass());

  // Fuse AllReduce and AllScatter into ReduceScatter.
  if (!DoNotFuseReduceScatter()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorAllReduceScatterOptimization());
  }

  // Change the order of DTensorAllReduce + Add to Add + DTensorAllReduce to
  // minimize the number of all-reduce operations.
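  // Roughly (illustrative): AllReduce(a) + AllReduce(b) can be rewritten as
  // AllReduce(a + b), cutting the number of collectives issued.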
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorAllReduceSumOptimization());

  AddDTensorAllReduceCombineOptimization(pm);

  // DTensorReduceScatter lowering should come before the DTensorAllReduce and
  // DTensorAllScatter lowerings, since on some devices DTensorReduceScatter is
  // decomposed into a DTensorAllReduce + DTensorAllScatter.
  pm->addPass(CreateDTensorReduceScatterLoweringPass());

  // For large enough reduction groups in reduction ops, upcast the input
  // tensors to a higher-precision type (e.g. bfloat16 -> float32).
  if (EnableMixedPrecisionReduce()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorMixedPrecisionReducePass());
  }

  // Lower device-agnostic logical AllReduce ops into device-specific physical
  // AllReduce ops.
  //
  // First, find DTensor collective ops such as DTensorAllReduce, which are
  // generated by SPMD expansion. Lower them into device-specific forms. For
  // most devices, there is a one-to-one mapping: DTensorAllReduce becomes
  // CollectiveReduce on CPUs/GPUs and XlaAllReduce on TPU pods.
  // Optionally, for special topologies, DTensorAllReduce could become a chain
  // of collectives running on different devices: XlaAllReduce on each donut
  // followed by CollectiveReduce on the hosts. Those collective ops running on
  // hosts will have their _mesh attribute set to empty by this pass. The other
  // ops continue to have no _mesh attributes, which means they run on the
  // cluster mesh.
  pm->addPass(CreateDTensorAllReduceLoweringPass());

  pm->addPass(CreateDTensorAllScatterLoweringPass());

  // Group together multiple device clusters assigned to the same mesh. Repeat
  // this for every mesh to support multi-mesh. Collective lowering may have
  // created multiple CPU mesh clusters for executing collective operations on
  // CPUs. As such, we merge the newly created CPU clusters after collective
  // lowering, especially for special topologies.
  pm->addPass(CreateDTensorMergeClustersPass());
  pm->addPass(CreateDTensorLowerSendRecv());

  // Convert tf_device.cluster ops into function call ops.
  pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());
  pm->addPass(CreateDTensorClusterFunctionConversion());

  // During layout propagation, we clone all constants with multiple consumers
  // for easier analysis. This may create multiple identical constant ops.
  // Apply constant folding to the duplicated constant operations to reduce the
  // graph size.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorConstantFolding());
  // DTensor SPMD lowering passes may have created auxiliary operations that
  // are no longer used. Add an additional DCE pass to remove unused
  // non-side-effecting ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // DTensor SPMD expansion may have created multiple control flows and
  // duplicate ops to calculate the device ordinal. Re-run SCCP and merge
  // control flows where possible.
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createSCCPPass());
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TFDevice::CreateMergeControlFlowPass());

  // TF2XLA Integration
  {
    // Make sure clusters that run on TPUs have the correct metadata ops and
    // attributes attached so that they are compatible with later TPU-specific
    // optimization passes.
    pm->addPass(CreateDTensorTPUIntegration());

    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateDecomposeResourceOpsPass());
    // Sink constant ops into the cluster region, as DecomposeResourceOpsPass()
    // could have lifted constants out due to folding.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateClusterConstantSinkingPass());

    // Run another shape inference pass (and a following DCE pass) because
    // resource decomposition might have created new partial types.
    pm->addPass(mlir::TF::CreateTFShapeInferencePass());
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());
    pm->addPass(mlir::TFDevice::CreateResourceOpLiftingPass());
    pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());

    // Rename functions with unique names to avoid collisions in the function
    // library.
    pm->addPass(CreateFunctionRenamingPass());

    // As DTensor SPMD expansion handles sharded inputs for model parallelism,
    // we set the input/output sharding to maximal sharding for inputs/outputs
    // of the TPU computation.
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorSetDefaultSharding());

    // Mark TPU cluster input-output pairs that read and write the same
    // resource variable as aliases.
    pm->addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass());

    // Convert compilation and replication attributes to the unified attributes
    // expected by TPURewritePass.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
    // Create TPU Compile and TPU Execute ops for each TPU device.
    pm->addPass(mlir::TFTPU::CreateTPURewritePass());
    // Convert the unified compilation and replication attributes back to
    // legacy attributes for subsequent passes.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateConvertToLegacyCompileAndReplicateAttributesPass());

    // Add placeholder device attributes to resource arguments of the TPU
    // computation. This ensures that the following
    // CreateTPUMergeVariablesWithExecutePass correctly merges resource
    // operations with the TPUExecute op.
    pm->addPass(CreateDTensorTpuAddResourceDeviceAttribute());
    // Translate the TPUExecute op to a TPUExecuteAndUpdateVariable op to
    // enable buffer aliasing.
    pm->addPass(mlir::TFTPU::CreateTPUMergeVariablesWithExecutePass());

    pm->addPass(CreateDTensorUpdateTPUMetadata());
    // If send/recv ops exist between TPU and CPU, then the TPU compilation
    // program key is used as input for the recv op in the host computation as
    // well as for the TPUExecute op in the device computation. As such, move
    // the TPUCompile logic to the host computation and transfer the program
    // key using send/recv operations.
    pm->addPass(CreateDTensorMoveCompilationToHost());
    pm->addPass(mlir::createSymbolDCEPass());
  }

  pm->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  // Convert the graph into the graph executor dialect so that the transformed
  // graph can be exported back to a GraphDef.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateFunctionalToExecutorDialectConversionPass());
  pm->addPass(mlir::CreateBreakUpIslandsPass());
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateLaunchToDeviceAttributePass());
  // Add an additional BreakUpIslands pass, as the LaunchToDeviceAttribute pass
  // may have created additional islands.
  pm->addPass(mlir::CreateBreakUpIslandsPass());
}

}  // namespace dtensor
}  // namespace tensorflow