1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
17 | |
18 | #include <memory> |
19 | |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
22 | #include "mlir/Pass/Pass.h" // from @llvm-project |
23 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
24 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
25 | #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
26 | #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" |
27 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
28 | #include "tensorflow/dtensor/cc/constants.h" |
29 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
30 | #include "tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h" |
31 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
32 | #include "tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.h" |
33 | |
34 | namespace tensorflow { |
35 | namespace dtensor { |
36 | namespace { |
37 | class ConditionalPrinter : public BridgeLoggerConfig { |
38 | private: |
39 | bool do_not_print_; |
40 | |
41 | public: |
42 | explicit ConditionalPrinter(bool print_module_scope = false, |
43 | bool print_after_only_on_change = true) |
44 | : BridgeLoggerConfig(print_module_scope, print_after_only_on_change) { |
45 | do_not_print_ = !(LogOnAllTasks() || (ClientId() == 0)); |
46 | } |
47 | |
48 | void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation, |
49 | PrintCallbackFn print_callback) override {} |
50 | |
51 | void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation, |
52 | PrintCallbackFn print_callback) override { |
53 | mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(operation); |
54 | if (!module) module = operation->getParentOfType<mlir::ModuleOp>(); |
55 | if (module && !module->hasAttr(dtensor::kDoNotLog) && !do_not_print_) |
56 | BridgeLoggerConfig::printAfterIfEnabled(pass, operation, print_callback); |
57 | } |
58 | }; |
59 | } // namespace |
60 | |
61 | // Adds logger to DTensor transformation passmanager. |
62 | bool MaybeEnableLogging(mlir::PassManager *pm) { |
63 | if (VLOG_IS_ON(1)) { |
64 | // Print the whole module after each pass, which requires disabling |
65 | // multi-threading as well. |
66 | pm->getContext()->disableMultithreading(); |
67 | pm->enableIRPrinting(std::make_unique<ConditionalPrinter>( |
68 | /*print_module_scope=*/true)); |
69 | return true; |
70 | } |
71 | return false; |
72 | } |
73 | |
// Builds the DTensor MLIR pass pipeline. Lowers a TF graph-executor module
// through mesh propagation, layout propagation, SPMD expansion, collective
// lowering, and TF2XLA/TPU integration, and finally converts the result back
// to the executor dialect so the transformed graph can be exported as a
// GraphDef. Passes are appended to `pm` in dependency order; the ordering
// below is load-bearing (several passes require earlier ones to have run).
void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options,
                           mlir::OpPassManager *pm) {
  // Remove ops that cannot be reached from the sink node.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  // Remove graph-def executor dialect and represent IR as a flattened list of
  // TF ops in functions.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());

  // This does not guarantee that shapes are inferred for all ops. For ops
  // with dynamic shapes, shape information may still be missing.
  pm->addPass(mlir::TF::CreateTFShapeInferencePass());

  // Under the V2 layout propagation algorithm, layouts are expressed as
  // DTensorLayout ops, so the Canonicalizer and Inliner passes below will not
  // lose layout information.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorPropagateDefaultLayout());
  pm->addPass(mlir::createSCCPPass());
  pm->addPass(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm->addPass(mlir::createInlinerPass());

  // Ensure that all functions have `device_id` as 0th argument.
  pm->addPass(CreateDTensorPropagateDeviceIdToFunctionArgs());

  // Ensure that every function with a SparseTensor input is converted to its
  // three component tensors and that SparseToDenseOps are emitted for every
  // usage of a SparseTensor.
  pm->addPass(CreateDTensorSparseTensorToDenseTensor());

  AddDTensorEmbeddingPass(pm);

  // After shape inference, there may be unused constant ops added when
  // propagating caller-callee constants. As the DTensor mesh/layout
  // propagation passes assume that there are no unreachable ops, remove
  // trivially unused ops. Note that the `Canonicalizer` pass in TF includes a
  // similar optimization; however, the canonicalizer pass also rewrites some
  // ops and may remove `_layout` or `_mesh` attributes in the re-written TF
  // ops.
  // TODO(hongjunchoi): Remove this pass once shape inference pass no longer
  // creates unnecessary constants ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // Canonicalization will merge tf.ConstOp from different DTensorLayout
  // annotations, causing problems during mesh propagation. Undo the merge
  // before creating clusters.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorUndoMergeConstAcrossMesh());

  // Propagate mesh cluster config and cluster ops by mesh cluster so that
  // SPMD expansion can be isolated to a single device mesh.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorOpToDeviceClusterPass());
  pm->addPass(CreateDTensorMeshPropagationPass());

  {
    mlir::OpPassManager &func_pm = pm->nest<mlir::func::FuncOp>();
    func_pm.addPass(CreateDTensorDeviceMeshClusterCoarsening());
    // Set empty layout to the cluster wrapping `tf.VarHandleOp`. VarHandle op
    // always runs on the default device where the client program executes.
    func_pm.addPass(CreateDTensorDesignateResourceHandleMesh());
  }

  // Validates that all cross-mesh data transfers are expressed via
  // DTensorLayout operations and lowers them to send/recvs.
  pm->addPass(CreateDTensorHandleCrossClusterDependencies());

  // Mark all ops and functions with a global shape attribute to preserve
  // global shape information, as it is needed during Layout Propagation and
  // SPMD expansion.
  pm->addPass(CreateDTensorAnnotateGlobalShape());

  // Propagate layout to all ops in the graph.
  pm->addPass(CreateDTensorMergeClustersPass());

  AddDTensorEmbeddingPassV2(pm);

  // For DTensor Checkpoint V2, the outputs of tf.RestoreV2 ops
  // do not have shape information. We can infer the shapes of these
  // outputs from the tf.AssignVariableOps that consume these outputs.
  // This pass fills in all missing shapes caused by tf.RestoreV2 ops.
  pm->addPass(CreateDTensorInferShapesForRestoreV2Op());

  pm->addPass(CreateDTensorLayoutPropagationPassV2());

  // Expand graph to SPMD form given layouts are annotated to all ops.
  // Remove all DTensorLayout ops after the expansion is done.
  pm->addPass(CreateDTensorSPMDExpansion());

  // Insert functions to save or load embeddings when using the TPU device.
  AddDTensorEmbeddingCheckpointPass(pm);

  // Expand all ops that consume SparseTensors to possibly new ops.
  // Remove any unused SparseToDense, Layout, and Const Ops after
  // the expansion is done.
  //
  // Note that this pass assumes that SparseTensor operands are represented
  // as an operand from the output of a SparseToDenseOp. Thus, this pass
  // must happen after the SparseTensorToDenseTensor pass and after
  // the SPMD Expansion pass.
  pm->addPass(CreateDTensorSparseExpansion());

  // Do a round of CSE: this helps reduce the number of consts in the graph now
  // that SPMD expansion is done. We had replicated all Consts (so that each
  // const only had one usage) as part of layout propagation.
  pm->addPass(mlir::createCSEPass());

  // Lower the AllGather collectives. This has to happen before the all-reduce
  // optimizations, as AllGather may emit an AllReduce.
  pm->addPass(CreateDTensorAllGatherLoweringPass());

  // Fuses AllReduce and AllScatter into ReduceScatter.
  if (!DoNotFuseReduceScatter()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorAllReduceScatterOptimization());
  }

  // Changes the order of DTensorAllReduce + Add to Add + DTensorAllReduce to
  // minimize the number of all-reduce operations.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorAllReduceSumOptimization());

  AddDTensorAllReduceCombineOptimization(pm);

  // DTensorReduceScatter lowering should come before the DTensorAllReduce
  // and DTensorAllScatter lowerings since for some devices DTensorReduceScatter
  // will be decomposed into a DTensorAllReduce+DTensorScatter.
  pm->addPass(CreateDTensorReduceScatterLoweringPass());

  // For large enough reduction groups in reduction ops, upcast the input
  // tensors to a higher precision type (e.g. bfloat16 -> float32).
  if (EnableMixedPrecisionReduce()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorMixedPrecisionReducePass());
  }

  // Lower device-agnostic logical AllReduce ops into device-specific physical
  // AllReduce ops.
  //
  // First, find DTensor collective ops such as DTensorAllReduce, which are
  // generated by SPMD expansion. Lower them into device-specific forms. For
  // most devices, there is a one-to-one mapping: DTensorAllReduce becomes
  // CollectiveReduce on CPUs/GPUs and XlaAllReduce on TPU pods.
  // Optionally, for special topologies, DTensorAllReduce
  // could become a chain of collectives running on different devices:
  // XlaAllReduce on each donut followed by CollectiveReduce on the hosts. Those
  // collective ops running on hosts will have their _mesh attribute set to
  // empty by this pass. The other ops continue to have no _mesh attributes,
  // which means they run on the cluster mesh.
  pm->addPass(CreateDTensorAllReduceLoweringPass());

  pm->addPass(CreateDTensorAllScatterLoweringPass());

  // Group together multiple device clusters assigned to the same mesh. Repeat
  // this for every mesh to support multi-mesh. Collective lowering may have
  // created multiple CPU mesh clusters for executing collective operations on
  // CPUs.
  // As such, we merge newly created CPU clusters after collective lowering,
  // especially for special topologies.
  pm->addPass(CreateDTensorMergeClustersPass());
  pm->addPass(CreateDTensorLowerSendRecv());

  // Convert tf_device.cluster into a function call op.
  pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());
  pm->addPass(CreateDTensorClusterFunctionConversion());

  // During layout propagation, we clone all constants with multiple consumers
  // for easier analysis.
  // This may create multiple identical constant ops. Apply constant folding on
  // duplicated constant operations to reduce graph size.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorConstantFolding());
  // DTensor SPMD lowering passes may have created auxiliary operations that are
  // no longer used. Add an additional DCE pass to remove unused
  // non-side-effecting ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // DTensor SPMD Expansion may have caused multiple control flows and
  // duplicate ops to calculate device ordinal. Re-run SCCP and merge
  // control flows if possible.
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createSCCPPass());
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TFDevice::CreateMergeControlFlowPass());

  // TF2XLA Integration
  {
    // Make sure clusters that run on TPUs have the correct metadata ops and
    // attributes attached to be compatible with later TPU-specific
    // optimization passes.
    pm->addPass(CreateDTensorTPUIntegration());

    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateDecomposeResourceOpsPass());
    // Sink constant ops into cluster region as DecomposeResourceOpsPass() could
    // lift constants out due to folding.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateClusterConstantSinkingPass());

    // Run another shape inference pass (and following DCE pass) because
    // resource decomposition might have created new partial types.
    pm->addPass(mlir::TF::CreateTFShapeInferencePass());
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());
    pm->addPass(mlir::TFDevice::CreateResourceOpLiftingPass());
    pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());

    // Rename functions with unique names, to avoid collisions in the function
    // library.
    pm->addPass(CreateFunctionRenamingPass());

    // As DTensor SPMD expansion handles sharded inputs for model
    // parallelism, we set input/output sharding to maximal sharding
    // for inputs/outputs of the TPU computation.
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorSetDefaultSharding());

    // Creates a pass that marks TPU cluster input-output pairs reading and
    // writing to the same resource variable as aliases.
    pm->addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass());

    // Convert compilation and replication attributes to unified attributes
    // expected by TPURewritePass.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
    // Create TPU Compile and TPU Execute ops for each TPU device.
    pm->addPass(mlir::TFTPU::CreateTPURewritePass());
    // Convert unified compilation and replication attributes back to legacy
    // attributes for subsequent passes.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateConvertToLegacyCompileAndReplicateAttributesPass());

    // Add placeholder device attributes to resource arguments of the TPU
    // computation. This ensures the following
    // CreateTPUMergeVariablesWithExecutePass correctly merges resource
    // operations with the TPUExecute op.
    pm->addPass(CreateDTensorTpuAddResourceDeviceAttribute());
    // Translate TPUExecute op to TPUExecuteAndUpdateVariable op to enable
    // buffer aliasing.
    pm->addPass(mlir::TFTPU::CreateTPUMergeVariablesWithExecutePass());

    pm->addPass(CreateDTensorUpdateTPUMetadata());
    // If send/recv exists between TPU and CPU, then the TPU compilation
    // program key is used as input for the recv op in the host computation as
    // well as the TPUExecute op in the device computation. As such, move the
    // TPUCompile logic to the host computation and transfer the program key
    // using send/recv operations.
    pm->addPass(CreateDTensorMoveCompilationToHost());
    pm->addPass(mlir::createSymbolDCEPass());
  }

  pm->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  // Convert the graph into the graph-executor dialect so that the transformed
  // graph can be exported back to GraphDef.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateFunctionalToExecutorDialectConversionPass());
  pm->addPass(mlir::CreateBreakUpIslandsPass());
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateLaunchToDeviceAttributePass());
  // Add an additional BreakUpIslandsPass, as the LaunchToDeviceAttribute pass
  // may have created additional islands.
  pm->addPass(mlir::CreateBreakUpIslandsPass());
}
331 | |
332 | } // namespace dtensor |
333 | } // namespace tensorflow |
334 | |