1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <string>
18#include <utility>
19
20#include "absl/strings/str_cat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/DenseMap.h"
23#include "llvm/ADT/DenseSet.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallVector.h"
26#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27#include "mlir/IR/Attributes.h" // from @llvm-project
28#include "mlir/IR/Builders.h" // from @llvm-project
29#include "mlir/IR/BuiltinOps.h" // from @llvm-project
30#include "mlir/IR/Diagnostics.h" // from @llvm-project
31#include "mlir/IR/Operation.h" // from @llvm-project
32#include "mlir/IR/Types.h" // from @llvm-project
33#include "mlir/IR/Value.h" // from @llvm-project
34#include "mlir/IR/Visitors.h" // from @llvm-project
35#include "mlir/Pass/Pass.h" // from @llvm-project
36#include "mlir/Pass/PassManager.h" // from @llvm-project
37#include "mlir/Support/LogicalResult.h" // from @llvm-project
38#include "mlir/Transforms/Passes.h" // from @llvm-project
39#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
40#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
41#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
43#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
44#include "tensorflow/dtensor/cc/constants.h"
45#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
46#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
47#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
48#include "tensorflow/dtensor/mlir/layout_parsing.h"
49#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
50
51namespace tensorflow {
52namespace dtensor {
53
54namespace {
55#define GEN_PASS_DEF_DTENSORMERGECLUSTERS
56#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
57
// Error emitted when a tf_device.ClusterOp lacks the mesh attribute this pass
// requires on every cluster.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorMergeCluster pass. "
    "All clusters must have specified mesh.";

// Prefix used to build unique DTensorSend/DTensorRecv keys when transferring
// control-flow predicate tensors across meshes.
constexpr char kSendRecvKeyPrefix[] = "SendRecvKeyForControlflow_";
63
64// Extracts mesh from `cluster`.
65mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
66 Mesh* mesh_output) {
67 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
68 if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);
69
70 const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
71 if (!mesh_or_null.has_value())
72 return cluster.emitOpError(kMissingMeshErrorMsg);
73
74 *mesh_output = mesh_or_null.value();
75 return mlir::success();
76}
77
78// Returns all tf_device.ClusterOps nested inside `op`.
79llvm::SmallVector<mlir::tf_device::ClusterOp, 4> FindAllDeviceClusters(
80 mlir::Operation* op) {
81 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> nested_clusters;
82 op->walk([&](mlir::tf_device::ClusterOp nested_cluster) {
83 nested_clusters.emplace_back(nested_cluster);
84 });
85 return nested_clusters;
86}
87
88mlir::LogicalResult MergeAttributes(
89 mlir::Operation* op, mlir::DenseIntElementsAttr indices_attr,
90 mlir::ArrayAttr layout_attr, mlir::DenseIntElementsAttr indices_attr2,
91 mlir::ArrayAttr layout_attr2, llvm::SmallVector<int, 4>* merged_indices,
92 llvm::SmallVector<mlir::Attribute, 4>* merged_layout) {
93 llvm::SmallDenseMap<llvm::APInt, mlir::Attribute> attr_map;
94 attr_map.reserve(indices_attr.size() + indices_attr2.size());
95 for (const auto& data : llvm::zip(indices_attr, layout_attr))
96 attr_map.try_emplace(std::get<0>(data), std::get<1>(data));
97
98 for (const auto& data : llvm::zip(indices_attr2, layout_attr2)) {
99 const auto& index = std::get<0>(data);
100 const auto& layout = std::get<1>(data);
101 auto result = attr_map.try_emplace(index, layout);
102
103 if (!result.second && layout != result.first->getSecond()) {
104 return op->emitOpError(
105 "Found conflicting metadata attributes while merging clusters");
106 }
107 }
108
109 merged_indices->reserve(attr_map.size());
110 merged_layout->reserve(attr_map.size());
111 for (const auto& it : attr_map) {
112 merged_indices->emplace_back(it.first.getSExtValue());
113 merged_layout->emplace_back(it.second);
114 }
115 return mlir::success();
116}
117
118// Merges metadata attribute from `src_cluster` to `target_cluster`. If metadata
119// attribute exists for both clusters, merge the attributes and verify that
120// there are no conflicing attributes.
121mlir::LogicalResult MergeClusterMetadata(
122 mlir::tf_device::ClusterOp src_cluster,
123 mlir::tf_device::ClusterOp target_cluster) {
124 if (mlir::failed(ValidateMetadataAttributes(src_cluster)) ||
125 mlir::failed(ValidateMetadataAttributes(target_cluster)))
126 return mlir::failure();
127
128 mlir::OpBuilder builder(target_cluster);
129
130 // Extract resource metadata from src/target clusters.
131 auto src_resource_handle_indices_metadata =
132 src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
133 kNewResourceLayoutIndices);
134 auto src_inferred_resource_handle_layouts_metadata =
135 src_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
136
137 auto target_resource_handle_indices_metadata =
138 target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
139 kNewResourceLayoutIndices);
140 auto target_inferred_resource_handle_layouts_metadata =
141 target_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
142 const bool should_merge_resource_metadata =
143 (src_inferred_resource_handle_layouts_metadata &&
144 src_resource_handle_indices_metadata &&
145 target_inferred_resource_handle_layouts_metadata &&
146 target_resource_handle_indices_metadata);
147 // If only source cluster has metadata, then simply copy the metadata to
148 // target cluster.
149 if (src_inferred_resource_handle_layouts_metadata &&
150 !target_inferred_resource_handle_layouts_metadata) {
151 target_cluster->setAttr(kNewResourceLayoutIndices,
152 src_resource_handle_indices_metadata);
153 target_cluster->setAttr(kNewResourceArgLayouts,
154 src_inferred_resource_handle_layouts_metadata);
155 } else if (should_merge_resource_metadata) {
156 // If both src cluster and target cluster has metadata, merge the metadata
157 // and check if there are no conflicts.
158 llvm::SmallVector<int, 4> merged_resource_indices;
159 llvm::SmallVector<mlir::Attribute, 4> merged_resource_layouts;
160 if (mlir::failed(MergeAttributes(
161 src_cluster, src_resource_handle_indices_metadata,
162 src_inferred_resource_handle_layouts_metadata,
163 target_resource_handle_indices_metadata,
164 target_inferred_resource_handle_layouts_metadata,
165 &merged_resource_indices, &merged_resource_layouts)))
166 return mlir::failure();
167
168 target_cluster->setAttr(
169 kNewResourceArgLayouts,
170 builder.getArrayAttr(
171 llvm::ArrayRef<mlir::Attribute>(merged_resource_layouts)));
172
173 target_cluster->setAttr(
174 kNewResourceLayoutIndices,
175 builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_resource_indices)));
176 }
177
178 // Extract shape metadata from src/target clusters.
179 auto src_shape_layouts =
180 src_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
181 auto src_shape_op_indices =
182 src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
183 kShapeOpInputLayoutIndices);
184 auto target_shape_layouts =
185 target_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
186 auto target_shape_op_indices =
187 target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
188 kShapeOpInputLayoutIndices);
189
190 const bool should_merge_shape_metadata =
191 (src_shape_layouts && src_shape_op_indices && target_shape_layouts &&
192 target_shape_op_indices);
193
194 // If only src cluster has shape metadata, copy shape metadata to target
195 // cluster.
196 if (src_shape_layouts && !target_shape_layouts) {
197 target_cluster->setAttr(kShapeOpInputLayoutIndices, src_shape_op_indices);
198 target_cluster->setAttr(kShapeOpInputLayout, src_shape_layouts);
199 } else if (should_merge_shape_metadata) {
200 // If both src/target clusters have shape metadata, merge the shape metadata
201 // and set the merged metadata to target cluster.
202 llvm::SmallVector<int, 4> merged_shape_indices;
203 llvm::SmallVector<mlir::Attribute, 4> merged_shape_layouts;
204 if (mlir::failed(MergeAttributes(
205 src_cluster, src_shape_op_indices, src_shape_layouts,
206 target_shape_op_indices, target_shape_layouts,
207 &merged_shape_indices, &merged_shape_layouts)))
208 return mlir::failure();
209
210 target_cluster->setAttr(
211 kShapeOpInputLayout,
212 builder.getArrayAttr(
213 llvm::ArrayRef<mlir::Attribute>(merged_shape_layouts)));
214
215 target_cluster->setAttr(
216 kShapeOpInputLayoutIndices,
217 builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_shape_indices)));
218 }
219
220 return mlir::success();
221}
222
223// Removes tf_device.Cluster ops if tf_device.Cluster is nested inside another
224// cluster and it has same mesh specification as parent cluster.
225mlir::LogicalResult InlineNestedDeviceClusters(mlir::ModuleOp module) {
226 auto clusters = FindAllDeviceClusters(module);
227 for (mlir::tf_device::ClusterOp cluster : clusters) {
228 auto parent_cluster =
229 cluster->getParentOfType<mlir::tf_device::ClusterOp>();
230 if (!parent_cluster) continue;
231
232 Mesh cluster_mesh;
233 if (mlir::failed(ExtractMeshFromCluster(cluster, &cluster_mesh)))
234 return mlir::failure();
235
236 Mesh parent_cluster_mesh;
237 if (mlir::failed(
238 ExtractMeshFromCluster(parent_cluster, &parent_cluster_mesh)))
239 return mlir::failure();
240
241 if (parent_cluster_mesh != cluster_mesh) continue;
242
243 // Found a tf_device.cluster that has same mesh specification as parent
244 // enclosing cluster. Remove the child cluster and move all ops to parent
245 // cluster instead.
246 for (auto it : llvm::zip(cluster.GetBody().getTerminator()->getOperands(),
247 cluster.getResults())) {
248 mlir::Value new_value = std::get<0>(it);
249 mlir::Value value_to_replace = std::get<1>(it);
250 value_to_replace.replaceAllUsesWith(new_value);
251 }
252 for (mlir::Operation& op :
253 llvm::make_early_inc_range(cluster.GetBody().without_terminator())) {
254 op.moveBefore(cluster);
255 }
256
257 if (mlir::failed(MergeClusterMetadata(cluster, parent_cluster)))
258 return mlir::failure();
259
260 cluster.erase();
261 }
262 return mlir::success();
263}
264
// Clones an IfRegionOp `if_region` (and its attributes) into a brand-new
// tf_device.ClusterOp on `mesh`, creating empty then/else regions that each
// contain only a yield op. The predicate is transferred to the new cluster via
// a DTensorSend/DTensorRecv pair keyed by `*num_send_recvs`, which is
// incremented. The cloned if op is returned through `cloned_if_region_op`.
void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh,
                               mlir::OpBuilder& builder, int* num_send_recvs,
                               mlir::MLIRContext* context,
                               mlir::TF::IfRegionOp* cloned_if_region_op) {
  // Create DTensorSend just before tf.If op before creating new cluster. The
  // DTensorSend op sends the predicate to `mesh` cluster with replicated
  // layout.
  mlir::TensorType predicate_tensor_type =
      if_region.cond().getType().cast<mlir::TensorType>();
  // Each send/recv pair must share a unique key; `num_send_recvs` is a
  // counter owned by the pass used to generate one per cloned controlflow.
  const std::string send_recv_key =
      absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs);
  *num_send_recvs += 1;

  const Layout target_layout = Layout::ReplicatedOnMesh(mesh, 0);
  builder.create<mlir::TF::DTensorSend>(
      if_region.getLoc(), if_region.cond(),
      builder.getStringAttr(send_recv_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create new cluster op that contains cloned if operation. The cluster has
  // no results; nested clusters communicate only through send/recv.
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{});
  new_cluster.getBody().push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  auto return_op = builder.create<mlir::tf_device::ReturnOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Value, 4>{});

  // Add DTensorRecv op inside new cluster that receives the predicate sent
  // above, with the same key and layout.
  builder.setInsertionPoint(return_op);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      if_region.getLoc(), predicate_tensor_type,
      builder.getStringAttr(send_recv_key),
      mlir::TF::ShapeAttr::get(context, predicate_tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Clone tf.IfRegion op inside newly created cluster and make sure
  // that the predicate tensor is from DTensorRecv op created above.
  auto host_side_if = builder.create<mlir::TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, recv_op.output(),
      if_region.is_stateless(),
      GetUniqueControlflowFnName("cloned_if_then", builder),
      GetUniqueControlflowFnName("cloned_if_else", builder));
  *cloned_if_region_op = host_side_if;

  // Create empty then branch region (a single block holding only a yield).
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});

  // Create empty else branch region (a single block holding only a yield).
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});
  // Tag the new cluster with the target mesh so later passes/merging treat it
  // as belonging to `mesh`.
  new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));
}
326
327// Verifies that send/recv ops are used for input output of cluster. That is,
328// cluster should not have any input/output edges.
329mlir::LogicalResult VerifyClusterInputOutput(
330 mlir::tf_device::ClusterOp cluster) {
331 if (cluster.getNumResults() > 0)
332 return cluster->emitOpError(
333 "found nested tf_device.Cluster op with outputs. Nested cluster must "
334 "use send/recv instead.");
335
336 mlir::LogicalResult result = mlir::success();
337 mlir::visitUsedValuesDefinedAbove(
338 cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* input) {
339 if (!input->get().isa<mlir::BlockArgument>()) {
340 result = cluster.emitOpError(
341 "found nested tf_device.Cluster op with inputs. Nested cluster "
342 "must use send/recv instead.");
343 return;
344 }
345 });
346 return result;
347}
348
// Returns true if `cluster` lives inside the then-branch of `if_op`; `if_op`
// must be a proper ancestor of `cluster`.
bool IsInsideIfThenBranch(mlir::TF::IfRegionOp if_op,
                          mlir::tf_device::ClusterOp cluster) {
  assert(if_op->isProperAncestor(cluster));
  mlir::Region* cluster_region = cluster->getParentRegion();
  return if_op.then_branch().isAncestor(cluster_region);
}
355
// Decomposes multi-mesh computation nested inside tf_if operations. See
// comments for `DecomposeControlflow()` function for details.
//
// For every tf_device.Cluster nested inside `if_op`, this clones an empty
// replica of the if op (predicate transferred via send/recv keyed by
// `num_control_flow_send_recvs`) onto the nested cluster's mesh, then moves
// the nested cluster's body into the matching branch of the clone.
mlir::LogicalResult DecomposeIf(mlir::TF::IfRegionOp if_op,
                                mlir::MLIRContext* context,
                                int* num_control_flow_send_recvs) {
  auto nested_clusters = FindAllDeviceClusters(if_op);
  if (nested_clusters.empty()) return mlir::success();

  for (mlir::tf_device::ClusterOp nested_cluster : nested_clusters) {
    // Nested clusters must not have direct input/output edges.
    if (mlir::failed(VerifyClusterInputOutput(nested_cluster)))
      return mlir::failure();

    Mesh nested_mesh;
    if (mlir::failed(ExtractMeshFromCluster(nested_cluster, &nested_mesh)))
      return mlir::failure();

    mlir::OpBuilder builder(if_op);
    mlir::TF::IfRegionOp cloned_if;
    CloneEmptyIfWithPredicate(if_op, nested_mesh, builder,
                              num_control_flow_send_recvs, context, &cloned_if);

    // Select the branch of `cloned_if` corresponding to the branch of `if_op`
    // that contains `nested_cluster`, then splice all ops of the nested
    // cluster (except its terminator) into that branch, just before the
    // branch's yield terminator. (Previously this logic was duplicated for
    // then/else; a single code path avoids the copy-paste divergence risk.)
    mlir::Region& target_branch = IsInsideIfThenBranch(if_op, nested_cluster)
                                      ? cloned_if.then_branch()
                                      : cloned_if.else_branch();
    mlir::Block& target_block = *target_branch.begin();
    auto& nested_cluster_operations = nested_cluster.GetBody().getOperations();
    target_block.getOperations().splice(
        target_block.getTerminator()->getIterator(), nested_cluster_operations,
        nested_cluster_operations.begin(),
        std::prev(nested_cluster_operations.end()));

    nested_cluster.erase();
  }
  return mlir::success();
}
403
404// Decomposes controlflows with nested mesh computations. When multi-mesh
405// computation exists inside control flow operations like tf.If, then
406// the control flow operations should be replicated to ensure correct execution
407// semantics.
408// For example:
409//
410// "tf_device.cluster"() ( {
411// %1 = "tf.G"() : () -> (tensor<i1>)
412// "tf.IfRegion"(%1) ({
413// "tf_device.cluster"() ( {
414// "tf.D"() {} : () -> ()
415// tf_device.return
416// }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
417//
418// "tf.Yield"() : () -> ()
419// }, {
420// }) {is_stateless = false} : (tensor<i1>) -> ()
421// tf_device.return
422// }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
423//
424// Above computation includes TPU device computation that exists inside
425// tf.If op in CPU mesh. In this case, tf.If op should be replicated to TPU
426// device computation so that `tf.D` op is executed in sync with CPU side
427// computation. After transformation in this function, above IR is changed to:
428//
429// "tf_device.cluster"() ( {
430// %1 = "tf.DTensorRecv"() : () -> tensor<i1>
431// "tf.IfRegion"(%1) ( {
432// "tf.D"() : () -> ()
433// "tf.Yield"() : () -> ()
434// }, {
435// "tf.Yield"() : () -> ()
436// }) {is_stateless = false} : (tensor<i1>) -> ()
437// tf_device.return
438// }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
439//
440// "tf_device.cluster"() ( {
441// %1 = "tf.G"() : () -> tensor<i1>
442// "tf.DTensorSend"(%1) : (tensor<i1>) -> ()
443// "tf.IfRegion"(%1) ( {
444// "tf.Yield"() : () -> ()
445// }, {
446// "tf.Yield"() : () -> ()
447// }) {is_stateless = false} : (tensor<i1>) -> ()
448// tf_device.return
449// }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
450//
451// Note that:
452// 1) Control flow is replicated.
453// 2) DTensorSend/Recv ops are added to transfer predicate tensors for
454// control flow operations
455mlir::LogicalResult DecomposeControlflow(mlir::MLIRContext* context,
456 int* num_control_flow_send_recvs,
457 mlir::ModuleOp module) {
458 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
459 // Identify all clusters in topological order.
460 module.walk([&](mlir::tf_device::ClusterOp cluster) {
461 clusters.emplace_back(cluster);
462 });
463
464 for (mlir::tf_device::ClusterOp cluster : clusters) {
465 mlir::WalkResult walk_result = cluster->walk([&](mlir::Operation* op) {
466 if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
467 if (mlir::failed(
468 DecomposeIf(if_op, context, num_control_flow_send_recvs)))
469 return mlir::WalkResult::interrupt();
470 }
471 return mlir::WalkResult::advance();
472 });
473 if (walk_result.wasInterrupted()) return mlir::failure();
474 }
475
476 return mlir::success();
477}
478
// Merges multiple tf_device.clusters with same mesh specification to a single
// mesh cluster. After this runs, the main function holds one cluster per
// mesh; all result values of the old clusters are rewired to the merged
// cluster's results.
mlir::LogicalResult MergeClusters(mlir::ModuleOp module) {
  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");

  // Create global cluster for each mesh in entire computation.
  auto clusters = FindAllDeviceClusters(main_func);
  mlir::Block& func_block = *main_func.getBody().begin();
  mlir::OpBuilder builder(&func_block.front());
  // Group clusters by mesh. `meshes` records first-seen order; it is
  // re-sorted deterministically below before merging.
  std::map<Mesh, llvm::SmallVector<mlir::tf_device::ClusterOp, 4>> cluster_map;
  std::vector<Mesh> meshes;
  for (mlir::tf_device::ClusterOp cluster : clusters) {
    Mesh mesh;
    if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
      return mlir::failure();

    if (cluster_map.find(mesh) != cluster_map.end()) {
      cluster_map[mesh].emplace_back(cluster);
    } else {
      cluster_map[mesh] =
          llvm::SmallVector<mlir::tf_device::ClusterOp, 4>{cluster};
      meshes.push_back(std::move(mesh));
    }
  }

  // Sort meshes by device type first so merged clusters are emitted in a
  // deterministic order.
  // Reevaluate if this sort is necessary after b/186804270 is closed.
  std::sort(meshes.begin(), meshes.end(), [](const Mesh& a, const Mesh& b) {
    if (a.device_type() != b.device_type()) {
      return a.device_type() < b.device_type();
    }
    return a < b;
  });
  for (const Mesh& mesh : meshes) {
    const auto& mesh_cluster_list = cluster_map[mesh];
    // Concatenated results / returned values / result types across all
    // clusters on this mesh, in cluster order.
    llvm::SmallVector<mlir::Value, 4> merged_cluster_outputs;
    llvm::SmallVector<mlir::Value, 4> merged_return_values;
    llvm::SmallVector<mlir::Type, 4> merged_return_types;

    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      merged_cluster_outputs.insert(merged_cluster_outputs.end(),
                                    cluster.getResults().begin(),
                                    cluster.getResults().end());

      auto return_values = cluster.GetBody().getTerminator()->getOperands();
      merged_return_values.insert(merged_return_values.end(),
                                  return_values.begin(), return_values.end());

      auto return_type = cluster->getResultTypes();
      merged_return_types.insert(merged_return_types.end(), return_type.begin(),
                                 return_type.end());
    }

    // Create a single cluster op that contains merged computations for `mesh`,
    // inserted at the top of the function body.
    builder.setInsertionPoint(&func_block.front());
    auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
        module.getLoc(), merged_return_types);
    new_cluster.getBody().push_back(new mlir::Block);
    new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));

    // Move all ops inside clusters in cluster mesh to `new_cluster`.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      mlir::Block& cluster_body = cluster.GetBody();
      for (mlir::Operation& op_to_move :
           llvm::make_early_inc_range(cluster_body.without_terminator())) {
        for (mlir::OpOperand& use : op_to_move.getUses()) {
          // Only uses feeding the old cluster's terminator need special
          // handling: their corresponding cluster results are rewritten.
          auto return_op =
              llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
          if (!return_op) continue;

          // Replace uses of the old cluster result with the moved value, but
          // only for uses that sit inside some tf_device.ClusterOp; uses
          // outside any cluster are patched via `merged_cluster_outputs`
          // below.
          mlir::Value output = cluster.getResult(use.getOperandNumber());
          output.replaceUsesWithIf(use.get(), [](mlir::OpOperand& operand) {
            return operand.getOwner()
                       ->getParentOfType<mlir::tf_device::ClusterOp>() !=
                   nullptr;
          });
        }
        op_to_move.moveBefore(new_cluster.SingleBlock::getBody(),
                              new_cluster.SingleBlock::getBody()->end());
      }
    }

    // Terminate the merged cluster, returning all values the old clusters
    // returned, in the same concatenated order as the merged result types.
    builder.setInsertionPointToEnd(&new_cluster.GetBody());
    builder.create<mlir::tf_device::ReturnOp>(new_cluster.getLoc(),
                                              merged_return_values);

    // Replace remaining (outside-cluster) usages of the old cluster results
    // with the merged cluster's results.
    for (auto it :
         llvm::zip(merged_cluster_outputs, new_cluster.getResults())) {
      mlir::Value value_to_replace = std::get<0>(it);
      mlir::Value new_result_value = std::get<1>(it);
      value_to_replace.replaceAllUsesWith(new_result_value);
    }

    // Erase clusters in cluster_map now that all ops are moved, merging each
    // old cluster's metadata into the new cluster first.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      if (mlir::failed(MergeClusterMetadata(cluster, new_cluster)))
        return mlir::failure();

      cluster.erase();
    }
  }

  return mlir::success();
}
584
585// Pass that merges multiple tf_device.Cluster ops for multi-mesh computation
586// into a single cluster. After this pass, exactly one tf_device.Cluster op
587// exists for each device mesh.
588struct DTensorMergeClusters
589 : public impl::DTensorMergeClustersBase<DTensorMergeClusters> {
590 void getDependentDialects(mlir::DialectRegistry& registry) const override {
591 registry.insert<mlir::dtensor::DTensorDialect>();
592 }
593
594 void runOnOperation() override {
595 mlir::MLIRContext& context = getContext();
596 mlir::OpBuilder op_builder(&context);
597 auto module = getOperation();
598 if (mlir::failed(InlineNestedDeviceClusters(module)))
599 return signalPassFailure();
600
601 int num_controlflow_send_recv = 0;
602 if (mlir::failed(
603 DecomposeControlflow(&context, &num_controlflow_send_recv, module)))
604 return signalPassFailure();
605
606 if (mlir::failed(MergeClusters(module))) return signalPassFailure();
607
608 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
609 module.walk([&](mlir::tf_device::ClusterOp cluster) {
610 clusters.emplace_back(cluster);
611 });
612
613 for (mlir::tf_device::ClusterOp cluster : clusters) {
614 RemoveUnusedClusterResults(cluster);
615 }
616 };
617};
618
619} // namespace
620
// Factory for the DTensorMergeClusters module-level pass.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMergeClustersPass() {
  auto pass = std::make_unique<DTensorMergeClusters>();
  return pass;
}
625
626} // namespace dtensor
627} // namespace tensorflow
628