1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <string> |
18 | #include <utility> |
19 | |
20 | #include "absl/strings/str_cat.h" |
21 | #include "llvm/ADT/APInt.h" |
22 | #include "llvm/ADT/DenseMap.h" |
23 | #include "llvm/ADT/DenseSet.h" |
24 | #include "llvm/ADT/STLExtras.h" |
25 | #include "llvm/ADT/SmallVector.h" |
26 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
27 | #include "mlir/IR/Attributes.h" // from @llvm-project |
28 | #include "mlir/IR/Builders.h" // from @llvm-project |
29 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
30 | #include "mlir/IR/Diagnostics.h" // from @llvm-project |
31 | #include "mlir/IR/Operation.h" // from @llvm-project |
32 | #include "mlir/IR/Types.h" // from @llvm-project |
33 | #include "mlir/IR/Value.h" // from @llvm-project |
34 | #include "mlir/IR/Visitors.h" // from @llvm-project |
35 | #include "mlir/Pass/Pass.h" // from @llvm-project |
36 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
37 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
38 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
39 | #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
40 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
41 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
42 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" |
43 | #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" |
44 | #include "tensorflow/dtensor/cc/constants.h" |
45 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
46 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
47 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
48 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
49 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
50 | |
51 | namespace tensorflow { |
52 | namespace dtensor { |
53 | |
54 | namespace { |
55 | #define GEN_PASS_DEF_DTENSORMERGECLUSTERS |
56 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
57 | |
// Error emitted when a tf_device.Cluster op does not carry an extractable
// mesh attribute.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorMergeCluster pass. "
    "All clusters must have specified mesh.";

// Prefix used to build unique DTensorSend/DTensorRecv keys when replicating
// control flow predicates across meshes.
constexpr char kSendRecvKeyPrefix[] = "SendRecvKeyForControlflow_";
63 | |
64 | // Extracts mesh from `cluster`. |
65 | mlir::LogicalResult (mlir::tf_device::ClusterOp cluster, |
66 | Mesh* mesh_output) { |
67 | auto mesh_or_status = ExtractDeviceMeshFromOp(cluster); |
68 | if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg); |
69 | |
70 | const absl::optional<Mesh>& mesh_or_null = *mesh_or_status; |
71 | if (!mesh_or_null.has_value()) |
72 | return cluster.emitOpError(kMissingMeshErrorMsg); |
73 | |
74 | *mesh_output = mesh_or_null.value(); |
75 | return mlir::success(); |
76 | } |
77 | |
78 | // Returns all tf_device.ClusterOps nested inside `op`. |
79 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> FindAllDeviceClusters( |
80 | mlir::Operation* op) { |
81 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> nested_clusters; |
82 | op->walk([&](mlir::tf_device::ClusterOp nested_cluster) { |
83 | nested_clusters.emplace_back(nested_cluster); |
84 | }); |
85 | return nested_clusters; |
86 | } |
87 | |
88 | mlir::LogicalResult MergeAttributes( |
89 | mlir::Operation* op, mlir::DenseIntElementsAttr indices_attr, |
90 | mlir::ArrayAttr layout_attr, mlir::DenseIntElementsAttr indices_attr2, |
91 | mlir::ArrayAttr layout_attr2, llvm::SmallVector<int, 4>* merged_indices, |
92 | llvm::SmallVector<mlir::Attribute, 4>* merged_layout) { |
93 | llvm::SmallDenseMap<llvm::APInt, mlir::Attribute> attr_map; |
94 | attr_map.reserve(indices_attr.size() + indices_attr2.size()); |
95 | for (const auto& data : llvm::zip(indices_attr, layout_attr)) |
96 | attr_map.try_emplace(std::get<0>(data), std::get<1>(data)); |
97 | |
98 | for (const auto& data : llvm::zip(indices_attr2, layout_attr2)) { |
99 | const auto& index = std::get<0>(data); |
100 | const auto& layout = std::get<1>(data); |
101 | auto result = attr_map.try_emplace(index, layout); |
102 | |
103 | if (!result.second && layout != result.first->getSecond()) { |
104 | return op->emitOpError( |
105 | "Found conflicting metadata attributes while merging clusters" ); |
106 | } |
107 | } |
108 | |
109 | merged_indices->reserve(attr_map.size()); |
110 | merged_layout->reserve(attr_map.size()); |
111 | for (const auto& it : attr_map) { |
112 | merged_indices->emplace_back(it.first.getSExtValue()); |
113 | merged_layout->emplace_back(it.second); |
114 | } |
115 | return mlir::success(); |
116 | } |
117 | |
// Merges metadata attribute from `src_cluster` to `target_cluster`. If metadata
// attribute exists for both clusters, merge the attributes and verify that
// there are no conflicting attributes.
//
// Two independent kinds of metadata are handled:
//   1) Resource-argument layouts (kNewResourceLayoutIndices /
//      kNewResourceArgLayouts).
//   2) Shape-op input layouts (kShapeOpInputLayoutIndices /
//      kShapeOpInputLayout).
// For each kind: if only `src_cluster` carries the attributes they are copied
// to `target_cluster` verbatim; if both carry them they are merged via
// `MergeAttributes`, which fails on conflicting entries for the same index.
mlir::LogicalResult MergeClusterMetadata(
    mlir::tf_device::ClusterOp src_cluster,
    mlir::tf_device::ClusterOp target_cluster) {
  // Both clusters must have well-formed metadata before attempting a merge.
  if (mlir::failed(ValidateMetadataAttributes(src_cluster)) ||
      mlir::failed(ValidateMetadataAttributes(target_cluster)))
    return mlir::failure();

  mlir::OpBuilder builder(target_cluster);

  // Extract resource metadata from src/target clusters.
  auto src_resource_handle_indices_metadata =
      src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kNewResourceLayoutIndices);
  auto src_inferred_resource_handle_layouts_metadata =
      src_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);

  auto target_resource_handle_indices_metadata =
      target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kNewResourceLayoutIndices);
  auto target_inferred_resource_handle_layouts_metadata =
      target_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
  // A merge is only meaningful when both clusters carry the complete
  // (indices, layouts) attribute pair.
  const bool should_merge_resource_metadata =
      (src_inferred_resource_handle_layouts_metadata &&
       src_resource_handle_indices_metadata &&
       target_inferred_resource_handle_layouts_metadata &&
       target_resource_handle_indices_metadata);
  // If only source cluster has metadata, then simply copy the metadata to
  // target cluster.
  if (src_inferred_resource_handle_layouts_metadata &&
      !target_inferred_resource_handle_layouts_metadata) {
    target_cluster->setAttr(kNewResourceLayoutIndices,
                            src_resource_handle_indices_metadata);
    target_cluster->setAttr(kNewResourceArgLayouts,
                            src_inferred_resource_handle_layouts_metadata);
  } else if (should_merge_resource_metadata) {
    // If both src cluster and target cluster have metadata, merge the metadata
    // and check that there are no conflicts.
    llvm::SmallVector<int, 4> merged_resource_indices;
    llvm::SmallVector<mlir::Attribute, 4> merged_resource_layouts;
    if (mlir::failed(MergeAttributes(
            src_cluster, src_resource_handle_indices_metadata,
            src_inferred_resource_handle_layouts_metadata,
            target_resource_handle_indices_metadata,
            target_inferred_resource_handle_layouts_metadata,
            &merged_resource_indices, &merged_resource_layouts)))
      return mlir::failure();

    target_cluster->setAttr(
        kNewResourceArgLayouts,
        builder.getArrayAttr(
            llvm::ArrayRef<mlir::Attribute>(merged_resource_layouts)));

    target_cluster->setAttr(
        kNewResourceLayoutIndices,
        builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_resource_indices)));
  }

  // Extract shape metadata from src/target clusters.
  auto src_shape_layouts =
      src_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto src_shape_op_indices =
      src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kShapeOpInputLayoutIndices);
  auto target_shape_layouts =
      target_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto target_shape_op_indices =
      target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kShapeOpInputLayoutIndices);

  const bool should_merge_shape_metadata =
      (src_shape_layouts && src_shape_op_indices && target_shape_layouts &&
       target_shape_op_indices);

  // If only src cluster has shape metadata, copy shape metadata to target
  // cluster.
  if (src_shape_layouts && !target_shape_layouts) {
    target_cluster->setAttr(kShapeOpInputLayoutIndices, src_shape_op_indices);
    target_cluster->setAttr(kShapeOpInputLayout, src_shape_layouts);
  } else if (should_merge_shape_metadata) {
    // If both src/target clusters have shape metadata, merge the shape metadata
    // and set the merged metadata to target cluster.
    llvm::SmallVector<int, 4> merged_shape_indices;
    llvm::SmallVector<mlir::Attribute, 4> merged_shape_layouts;
    if (mlir::failed(MergeAttributes(
            src_cluster, src_shape_op_indices, src_shape_layouts,
            target_shape_op_indices, target_shape_layouts,
            &merged_shape_indices, &merged_shape_layouts)))
      return mlir::failure();

    target_cluster->setAttr(
        kShapeOpInputLayout,
        builder.getArrayAttr(
            llvm::ArrayRef<mlir::Attribute>(merged_shape_layouts)));

    target_cluster->setAttr(
        kShapeOpInputLayoutIndices,
        builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_shape_indices)));
  }

  return mlir::success();
}
222 | |
223 | // Removes tf_device.Cluster ops if tf_device.Cluster is nested inside another |
224 | // cluster and it has same mesh specification as parent cluster. |
225 | mlir::LogicalResult InlineNestedDeviceClusters(mlir::ModuleOp module) { |
226 | auto clusters = FindAllDeviceClusters(module); |
227 | for (mlir::tf_device::ClusterOp cluster : clusters) { |
228 | auto parent_cluster = |
229 | cluster->getParentOfType<mlir::tf_device::ClusterOp>(); |
230 | if (!parent_cluster) continue; |
231 | |
232 | Mesh cluster_mesh; |
233 | if (mlir::failed(ExtractMeshFromCluster(cluster, &cluster_mesh))) |
234 | return mlir::failure(); |
235 | |
236 | Mesh parent_cluster_mesh; |
237 | if (mlir::failed( |
238 | ExtractMeshFromCluster(parent_cluster, &parent_cluster_mesh))) |
239 | return mlir::failure(); |
240 | |
241 | if (parent_cluster_mesh != cluster_mesh) continue; |
242 | |
243 | // Found a tf_device.cluster that has same mesh specification as parent |
244 | // enclosing cluster. Remove the child cluster and move all ops to parent |
245 | // cluster instead. |
246 | for (auto it : llvm::zip(cluster.GetBody().getTerminator()->getOperands(), |
247 | cluster.getResults())) { |
248 | mlir::Value new_value = std::get<0>(it); |
249 | mlir::Value value_to_replace = std::get<1>(it); |
250 | value_to_replace.replaceAllUsesWith(new_value); |
251 | } |
252 | for (mlir::Operation& op : |
253 | llvm::make_early_inc_range(cluster.GetBody().without_terminator())) { |
254 | op.moveBefore(cluster); |
255 | } |
256 | |
257 | if (mlir::failed(MergeClusterMetadata(cluster, parent_cluster))) |
258 | return mlir::failure(); |
259 | |
260 | cluster.erase(); |
261 | } |
262 | return mlir::success(); |
263 | } |
264 | |
// Clones an IfRegionOp 'if_region' and attributes and creates then/else regions
// with yield op and an empty block.
//
// Emits, in order:
//   1) A DTensorSend just before `if_region` that ships the predicate tensor
//      to `mesh` with a replicated layout.
//   2) A new tf_device.ClusterOp (annotated with `mesh`) containing a
//      DTensorRecv for the predicate.
//   3) A cloned tf.IfRegion inside the new cluster whose condition is the
//      received predicate and whose then/else branches hold only a tf.Yield.
// The clone is returned via `cloned_if_region_op`; `num_send_recvs` is used
// and incremented to keep send/recv keys unique.
void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh,
                               mlir::OpBuilder& builder, int* num_send_recvs,
                               mlir::MLIRContext* context,
                               mlir::TF::IfRegionOp* cloned_if_region_op) {
  // Create DTensorSend just before tf.If op before creating new cluster. The
  // DTensorSend op sends the predicate to `mesh` cluster with replicated
  // layout.
  mlir::TensorType predicate_tensor_type =
      if_region.cond().getType().cast<mlir::TensorType>();
  // Derive a unique send/recv key from the running counter.
  const std::string send_recv_key =
      absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs);
  *num_send_recvs += 1;

  const Layout target_layout = Layout::ReplicatedOnMesh(mesh, 0);
  builder.create<mlir::TF::DTensorSend>(
      if_region.getLoc(), if_region.cond(),
      builder.getStringAttr(send_recv_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create new cluster op that contains cloned if operation.
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{});
  new_cluster.getBody().push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  auto return_op = builder.create<mlir::tf_device::ReturnOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Value, 4>{});

  // Add DTensorRecv op inside new cluster that receives the predicate.
  builder.setInsertionPoint(return_op);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      if_region.getLoc(), predicate_tensor_type,
      builder.getStringAttr(send_recv_key),
      mlir::TF::ShapeAttr::get(context, predicate_tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Clone tf.IfRegion op inside newly created cluster and make sure
  // that the predicate tensor is from DTensorRecv op created above.
  auto host_side_if = builder.create<mlir::TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, recv_op.output(),
      if_region.is_stateless(),
      GetUniqueControlflowFnName("cloned_if_then", builder),
      GetUniqueControlflowFnName("cloned_if_else", builder));
  *cloned_if_region_op = host_side_if;

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});
  // Tag the new cluster with the target mesh so later passes place it there.
  new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));
}
326 | |
327 | // Verifies that send/recv ops are used for input output of cluster. That is, |
328 | // cluster should not have any input/output edges. |
329 | mlir::LogicalResult VerifyClusterInputOutput( |
330 | mlir::tf_device::ClusterOp cluster) { |
331 | if (cluster.getNumResults() > 0) |
332 | return cluster->emitOpError( |
333 | "found nested tf_device.Cluster op with outputs. Nested cluster must " |
334 | "use send/recv instead." ); |
335 | |
336 | mlir::LogicalResult result = mlir::success(); |
337 | mlir::visitUsedValuesDefinedAbove( |
338 | cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* input) { |
339 | if (!input->get().isa<mlir::BlockArgument>()) { |
340 | result = cluster.emitOpError( |
341 | "found nested tf_device.Cluster op with inputs. Nested cluster " |
342 | "must use send/recv instead." ); |
343 | return; |
344 | } |
345 | }); |
346 | return result; |
347 | } |
348 | |
349 | // Returns whether `cluster` is inside then branch of `if_op`. |
350 | bool IsInsideIfThenBranch(mlir::TF::IfRegionOp if_op, |
351 | mlir::tf_device::ClusterOp cluster) { |
352 | assert(if_op->isProperAncestor(cluster)); |
353 | return if_op.then_branch().isAncestor(cluster->getParentRegion()); |
354 | } |
355 | |
// Decomposes multi-mesh computation nested inside tf_if operations. See
// comments for `DecomposeControlflow()` function for details.
//
// For every tf_device.Cluster nested inside `if_op`, clones an empty
// tf.IfRegion (with the predicate transferred via DTensorSend/DTensorRecv)
// into a new cluster on the nested cluster's mesh, then moves the nested
// cluster's body into the matching then/else branch of the clone. Nested
// clusters must have no direct input/output edges (checked by
// `VerifyClusterInputOutput`).
mlir::LogicalResult DecomposeIf(mlir::TF::IfRegionOp if_op,
                                mlir::MLIRContext* context,
                                int* num_control_flow_send_recvs) {
  auto nested_clusters = FindAllDeviceClusters(if_op);
  if (nested_clusters.empty()) return mlir::success();

  for (mlir::tf_device::ClusterOp nested_cluster : nested_clusters) {
    // Nested clusters may only communicate through send/recv.
    if (mlir::failed(VerifyClusterInputOutput(nested_cluster)))
      return mlir::failure();

    Mesh nested_mesh;
    if (mlir::failed(ExtractMeshFromCluster(nested_cluster, &nested_mesh)))
      return mlir::failure();

    mlir::OpBuilder builder(if_op);
    mlir::TF::IfRegionOp cloned_if;
    CloneEmptyIfWithPredicate(if_op, nested_mesh, builder,
                              num_control_flow_send_recvs, context, &cloned_if);

    // Find nested clusters in then/else branch of original `if_op` and
    // move all inner ops inside nested cluster to `cloned_if` in
    // corresponding branch.
    if (IsInsideIfThenBranch(if_op, nested_cluster)) {
      mlir::Operation* then_branch_terminator =
          cloned_if.then_branch().begin()->getTerminator();
      auto& nested_cluster_operations =
          nested_cluster.GetBody().getOperations();
      // Splice every op except the tf_device.return terminator into the
      // cloned then-branch, just before that branch's terminator.
      cloned_if.then_branch().begin()->getOperations().splice(
          then_branch_terminator->getIterator(), nested_cluster_operations,
          nested_cluster_operations.begin(),
          std::prev(nested_cluster_operations.end()));
    } else {
      mlir::Operation* else_branch_terminator =
          cloned_if.else_branch().begin()->getTerminator();
      auto& nested_cluster_operations =
          nested_cluster.GetBody().getOperations();
      // Same splice as above, targeting the cloned else-branch.
      cloned_if.else_branch().begin()->getOperations().splice(
          else_branch_terminator->getIterator(), nested_cluster_operations,
          nested_cluster_operations.begin(),
          std::prev(nested_cluster_operations.end()));
    }
    // The nested cluster now holds only its terminator; safe to erase.
    nested_cluster.erase();
  }
  return mlir::success();
}
403 | |
404 | // Decomposes controlflows with nested mesh computations. When multi-mesh |
405 | // computation exists inside control flow operations like tf.If, then |
406 | // the control flow operations should be replicated to ensure correct execution |
407 | // semantics. |
408 | // For example: |
409 | // |
410 | // "tf_device.cluster"() ( { |
411 | // %1 = "tf.G"() : () -> (tensor<i1>) |
412 | // "tf.IfRegion"(%1) ({ |
413 | // "tf_device.cluster"() ( { |
414 | // "tf.D"() {} : () -> () |
415 | // tf_device.return |
416 | // }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> () |
417 | // |
418 | // "tf.Yield"() : () -> () |
419 | // }, { |
420 | // }) {is_stateless = false} : (tensor<i1>) -> () |
421 | // tf_device.return |
422 | // }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> () |
423 | // |
424 | // Above computation includes TPU device computation that exists inside |
425 | // tf.If op in CPU mesh. In this case, tf.If op should be replicated to TPU |
426 | // device computation so that `tf.D` op is executed in sync with CPU side |
427 | // computation. After transformation in this function, above IR is changed to: |
428 | // |
429 | // "tf_device.cluster"() ( { |
430 | // %1 = "tf.DTensorRecv"() : () -> tensor<i1> |
431 | // "tf.IfRegion"(%1) ( { |
432 | // "tf.D"() : () -> () |
433 | // "tf.Yield"() : () -> () |
434 | // }, { |
435 | // "tf.Yield"() : () -> () |
436 | // }) {is_stateless = false} : (tensor<i1>) -> () |
437 | // tf_device.return |
438 | // }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> () |
439 | // |
440 | // "tf_device.cluster"() ( { |
441 | // %1 = "tf.G"() : () -> tensor<i1> |
442 | // "tf.DTensorSend"(%1) : (tensor<i1>) -> () |
443 | // "tf.IfRegion"(%1) ( { |
444 | // "tf.Yield"() : () -> () |
445 | // }, { |
446 | // "tf.Yield"() : () -> () |
447 | // }) {is_stateless = false} : (tensor<i1>) -> () |
448 | // tf_device.return |
449 | // }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> () |
450 | // |
451 | // Note that: |
452 | // 1) Control flow is replicated. |
453 | // 2) DTensorSend/Recv ops are added to transfer predicate tensors for |
454 | // control flow operations |
455 | mlir::LogicalResult DecomposeControlflow(mlir::MLIRContext* context, |
456 | int* num_control_flow_send_recvs, |
457 | mlir::ModuleOp module) { |
458 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters; |
459 | // Identify all clusters in topological order. |
460 | module.walk([&](mlir::tf_device::ClusterOp cluster) { |
461 | clusters.emplace_back(cluster); |
462 | }); |
463 | |
464 | for (mlir::tf_device::ClusterOp cluster : clusters) { |
465 | mlir::WalkResult walk_result = cluster->walk([&](mlir::Operation* op) { |
466 | if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) { |
467 | if (mlir::failed( |
468 | DecomposeIf(if_op, context, num_control_flow_send_recvs))) |
469 | return mlir::WalkResult::interrupt(); |
470 | } |
471 | return mlir::WalkResult::advance(); |
472 | }); |
473 | if (walk_result.wasInterrupted()) return mlir::failure(); |
474 | } |
475 | |
476 | return mlir::success(); |
477 | } |
478 | |
// Merges multiple tf_device.clusters with same mesh specification to a single
// mesh cluster.
//
// For each mesh found in the `main` function, a fresh tf_device.Cluster is
// created at the top of the function, the bodies of all same-mesh clusters
// are moved into it, a single tf_device.return is emitted for the combined
// results, users of the old cluster results are rewired to the new cluster's
// results, metadata is merged, and the old clusters are erased.
mlir::LogicalResult MergeClusters(mlir::ModuleOp module) {
  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");

  // Create global cluster for each mesh in entire computation.
  auto clusters = FindAllDeviceClusters(main_func);
  mlir::Block& func_block = *main_func.getBody().begin();
  mlir::OpBuilder builder(&func_block.front());
  // `cluster_map` groups clusters by mesh; `meshes` records each mesh once.
  std::map<Mesh, llvm::SmallVector<mlir::tf_device::ClusterOp, 4>> cluster_map;
  std::vector<Mesh> meshes;
  for (mlir::tf_device::ClusterOp cluster : clusters) {
    Mesh mesh;
    if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
      return mlir::failure();

    if (cluster_map.find(mesh) != cluster_map.end()) {
      cluster_map[mesh].emplace_back(cluster);
    } else {
      cluster_map[mesh] =
          llvm::SmallVector<mlir::tf_device::ClusterOp, 4>{cluster};
      meshes.push_back(std::move(mesh));
    }
  }

  // Reevaluate if this sort is necessary after b/186804270 is closed.
  // Sorting (device type first) makes merged-cluster emission order
  // deterministic across runs.
  std::sort(meshes.begin(), meshes.end(), [](const Mesh& a, const Mesh& b) {
    if (a.device_type() != b.device_type()) {
      return a.device_type() < b.device_type();
    }
    return a < b;
  });
  for (const Mesh& mesh : meshes) {
    const auto& mesh_cluster_list = cluster_map[mesh];
    llvm::SmallVector<mlir::Value, 4> merged_cluster_outputs;
    llvm::SmallVector<mlir::Value, 4> merged_return_values;
    llvm::SmallVector<mlir::Type, 4> merged_return_types;

    // Gather results, terminator operands, and result types from every
    // cluster on this mesh, preserving their relative order.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      merged_cluster_outputs.insert(merged_cluster_outputs.end(),
                                    cluster.getResults().begin(),
                                    cluster.getResults().end());

      auto return_values = cluster.GetBody().getTerminator()->getOperands();
      merged_return_values.insert(merged_return_values.end(),
                                  return_values.begin(), return_values.end());

      auto return_type = cluster->getResultTypes();
      merged_return_types.insert(merged_return_types.end(), return_type.begin(),
                                 return_type.end());
    }

    // Create a single cluster op contains merged computations for `mesh`.
    builder.setInsertionPoint(&func_block.front());
    auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
        module.getLoc(), merged_return_types);
    new_cluster.getBody().push_back(new mlir::Block);
    new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));

    // Move all ops inside clusters in cluster mesh to `new_cluster`.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      mlir::Block& cluster_body = cluster.GetBody();
      for (mlir::Operation& op_to_move :
           llvm::make_early_inc_range(cluster_body.without_terminator())) {
        // When `op_to_move`'s result feeds the old cluster's terminator,
        // rewire users of the corresponding cluster result that live inside
        // any cluster to use the value directly — those users end up inside
        // the merged cluster alongside the moved op.
        for (mlir::OpOperand& use : op_to_move.getUses()) {
          auto return_op =
              llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
          if (!return_op) continue;

          mlir::Value output = cluster.getResult(use.getOperandNumber());
          output.replaceUsesWithIf(use.get(), [](mlir::OpOperand& operand) {
            return operand.getOwner()
                       ->getParentOfType<mlir::tf_device::ClusterOp>() !=
                   nullptr;
          });
        }
        op_to_move.moveBefore(new_cluster.SingleBlock::getBody(),
                              new_cluster.SingleBlock::getBody()->end());
      }
    }

    builder.setInsertionPointToEnd(&new_cluster.GetBody());
    builder.create<mlir::tf_device::ReturnOp>(new_cluster.getLoc(),
                                              merged_return_values);

    // Replace return value usages.
    for (auto it :
         llvm::zip(merged_cluster_outputs, new_cluster.getResults())) {
      mlir::Value value_to_replace = std::get<0>(it);
      mlir::Value new_result_value = std::get<1>(it);
      value_to_replace.replaceAllUsesWith(new_result_value);
    }

    // Erase clusters in cluster_map now that all ops are moved.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      if (mlir::failed(MergeClusterMetadata(cluster, new_cluster)))
        return mlir::failure();

      cluster.erase();
    }
  }

  return mlir::success();
}
584 | |
585 | // Pass that merges multiple tf_device.Cluster ops for multi-mesh computation |
586 | // into a single cluster. After this pass, exactly one tf_device.Cluster op |
587 | // exists for each device mesh. |
588 | struct DTensorMergeClusters |
589 | : public impl::DTensorMergeClustersBase<DTensorMergeClusters> { |
590 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
591 | registry.insert<mlir::dtensor::DTensorDialect>(); |
592 | } |
593 | |
594 | void runOnOperation() override { |
595 | mlir::MLIRContext& context = getContext(); |
596 | mlir::OpBuilder op_builder(&context); |
597 | auto module = getOperation(); |
598 | if (mlir::failed(InlineNestedDeviceClusters(module))) |
599 | return signalPassFailure(); |
600 | |
601 | int num_controlflow_send_recv = 0; |
602 | if (mlir::failed( |
603 | DecomposeControlflow(&context, &num_controlflow_send_recv, module))) |
604 | return signalPassFailure(); |
605 | |
606 | if (mlir::failed(MergeClusters(module))) return signalPassFailure(); |
607 | |
608 | llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters; |
609 | module.walk([&](mlir::tf_device::ClusterOp cluster) { |
610 | clusters.emplace_back(cluster); |
611 | }); |
612 | |
613 | for (mlir::tf_device::ClusterOp cluster : clusters) { |
614 | RemoveUnusedClusterResults(cluster); |
615 | } |
616 | }; |
617 | }; |
618 | |
619 | } // namespace |
620 | |
621 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
622 | CreateDTensorMergeClustersPass() { |
623 | return std::make_unique<DTensorMergeClusters>(); |
624 | } |
625 | |
626 | } // namespace dtensor |
627 | } // namespace tensorflow |
628 | |